#!/usr/bin/env python

'''
     Usage:  
        j2sqlgen --config <configfile> --source <schemafile> [(--schema <db_schema>)] --table <tablename>
        j2sqlgen --config <configfile> --source <schemafile> [(--schema <db_schema>)] --tables <comma_delim_tbl_list>
        j2sqlgen --config <configfile> --source <schemafile> [(--schema <db_schema>)] --all-tables
        j2sqlgen --config <configfile> --source <schemafile> --list_tables

'''

import os, sys
import json
from snap import common
import docopt


# Source column type names that is_string_type() reports as string types.
BIGQUERY_STRING_TYPES = ['STRING']

# Skeleton of a CREATE TABLE statement. It is filled in with str.format()
# using the fields schema, tablename, table_body and table_suffix.
TABLE_DDL_TEMPLATE = '''
CREATE TABLE {schema}.{tablename} (
{table_body}
) {table_suffix}
'''

# Setting names grouped by the configuration level they are named for.
# NOTE(review): neither list is referenced elsewhere in the visible source.
#
# Every entry carries its own trailing comma: Python joins adjacent string
# literals, so a missing comma silently merges two names into one entry.
TABLE_LEVEL_SETTINGS = [
    'autocreate_pk',
    'pk_name',
    'pk_type',
    'autocreated_pk_default',
    'varchar_length',
    'column_type_map',
    'column_name_map',
]

COLUMN_LEVEL_SETTINGS = [
    'varchar_length',
    'column_type_map',
    'column_name_map',
]


def is_string_type(type_name):
    '''Return True if type_name is listed in BIGQUERY_STRING_TYPES, otherwise False.'''
    return type_name in BIGQUERY_STRING_TYPES


class DuplicateColumn(Exception):
    '''Signals that a TableSpec already holds a column with the given name.'''

    def __init__(self, column_name):
        message = 'TableSpec already contains column "%s".' % column_name
        super(DuplicateColumn, self).__init__(message)


class ColumnSpec(object):
    '''Describes a single table column and renders it as one line of a CREATE TABLE body.'''

    def __init__(self, table_name, column_name, column_type, is_nullable=True, is_pk=False, **kwargs):
        '''
        table_name   -- name of the owning table; handed to the creation context on every lookup
        column_name  -- name of the column
        column_type  -- type name of the column
        is_nullable  -- when False, the rendered line carries NOT NULL
        is_pk        -- when True, the rendered line ends with PRIMARY KEY
        kwargs       -- optional 'length' (rendered in parentheses directly after the type)
                        and 'default' (rendered in a DEFAULT clause when it is not None)
        '''
        self.tablename = table_name
        self.name = column_name
        self.datatype = column_type
        self.nullable = is_nullable
        self.length = kwargs.get('length', '') # can be null; usually only applies to string types
        self.is_primary_key = is_pk
        self.default_value = kwargs.get('default')

    def sql(self, table_creation_context):
        '''Return the DDL fragment for this column.

        table_creation_context must provide get_target_column_name(table, column),
        get_target_column_type(table, column, type) and get_column_suffix(table, column).

        Empty clauses keep their separating spaces inside the line; only leading and
        trailing whitespace is stripped from the result.
        '''
        null_clause = '' if self.nullable else 'NOT NULL'
        length_clause = '(%s)' % self.length if self.length else ''

        # Resolve everything once, with the same three-argument type lookup for
        # both templates (the DEFAULT variant used to omit the column name and
        # failed with a TypeError).
        target_name = table_creation_context.get_target_column_name(self.tablename, self.name)
        target_type = table_creation_context.get_target_column_type(self.tablename,
                                                                    self.name,
                                                                    self.datatype)
        column_suffix = table_creation_context.get_column_suffix(self.tablename, self.name)

        if self.default_value is None:
            template = '{name} {datatype}{length} {null} {column_suffix}'
        else:
            template = '{name} {datatype}{length} {null} DEFAULT {default} {column_suffix}'

        # str.format() ignores keyword arguments the template does not mention,
        # so 'default' can be supplied unconditionally.
        field_line = template.format(name=target_name,
                                     datatype=target_type,
                                     length=length_clause,
                                     null=null_clause,
                                     default=self.default_value,
                                     column_suffix=column_suffix)
        if self.is_primary_key:
            field_line = field_line + ' PRIMARY KEY'
        return field_line.strip()


class TableSpec(object):
    '''A table name, its schema name, and an ordered list of column specs.'''

    def __init__(self, table_name, schema_name, **kwargs):
        '''Create an empty table spec. Extra keyword arguments are accepted but not used.'''
        self.name = table_name
        self.schema = schema_name
        self.columns = []


    def has_pk(self):
        '''Return True if any column is flagged as a primary key.'''
        return any(columnspec.is_primary_key for columnspec in self.columns)


    def generate_pk_column(self, table_creation_context):
        '''Build (without adding it) a non-nullable primary-key ColumnSpec from the
        context's pk_name, pk_type and pk_value attributes.'''
        return ColumnSpec(self.name,
                          table_creation_context.pk_name,
                          table_creation_context.pk_type,
                          False,
                          True,
                          default=table_creation_context.pk_value)


    def insert_pk(self, columnspec):
        '''Place columnspec at the front of the column list.'''
        self.columns.insert(0, columnspec)


    def add_columnspec(self, columnspec):
        '''Append an already-built column spec; no duplicate check is performed.'''
        self.columns.append(columnspec)


    def add_column(self, source_column_name, source_column_type, creation_context, **kwargs):
        '''Map the source name and type through creation_context and append the column.

        kwargs may carry 'is_nullable' (default True); everything else is forwarded
        to ColumnSpec.

        Raises DuplicateColumn if a column with the mapped name is already present.
        '''
        target_column_type = creation_context.get_target_column_type(self.name, source_column_name, source_column_type)
        target_column_name = creation_context.get_target_column_name(self.name, source_column_name)

        if self.get_column(target_column_name):
            raise DuplicateColumn(target_column_name)

        # pop (not get): ColumnSpec receives is_nullable positionally below, so
        # leaving it in kwargs would raise "multiple values for argument".
        nullable = kwargs.pop('is_nullable', True)
        is_pk = False
        self.columns.append(ColumnSpec(self.name, target_column_name, target_column_type, nullable, is_pk, **kwargs))


    def get_column(self, column_name):
        '''Return the first column whose name equals column_name, or None.'''
        for c in self.columns:
            if c.name == column_name:
                return c
        return None


    def remove_column(self, column_name):
        '''Remove the named column. Return True if one was removed, False if none matched.'''
        c = self.get_column(column_name)
        if c:
            self.columns.remove(c)
            return True
        return False


class TableCreationContext(object):
    '''Answers naming, typing, length and suffix questions for tables and columns.

    Built from a config mapping with two top-level keys:
      'defaults' -- required; must contain 'varchar_length'. Optional keys read here:
                    'column_type_map', 'pk_value', 'pk_name', 'pk_type',
                    'autocreate_pk_if_missing', 'table_suffix', 'column_suffix'.
      'tables'   -- optional; per-table overrides keyed by table name. Keys read here:
                    'table_suffix', 'column_suffix', 'rename_to',
                    'autocreate_pk_if_missing', 'column_name_map', 'varchar_length',
                    and 'column_settings' (keyed by column name, with 'column_suffix'
                    and 'length').
    '''

    def __init__(self, yaml_config):
        self.defaults = yaml_config['defaults']
        self.pk_value = self.defaults.get('pk_value')
        self.default_varchar_length = self.defaults['varchar_length']
        self.type_map = {}
        self.type_map.update(self.defaults.get('column_type_map') or {})

        self.overrides = yaml_config.get('tables', {})

        if self.defaults.get('autocreate_pk_if_missing', False):
            # Auto-creation is on by default, so the PK definition is mandatory.
            self.pk_type = self.defaults['pk_type']
            self.pk_name = self.defaults['pk_name']
        else:
            # Still pick the definition up when present: a single table may
            # switch auto-creation on through its own override.
            self.pk_type = self.defaults.get('pk_type')
            self.pk_name = self.defaults.get('pk_name')

    def _table_settings(self, table_name):
        '''Return the override mapping for table_name, or {} when there is none.'''
        if not self.overrides:
            return {}
        return self.overrides.get(table_name) or {}

    def get_table_suffix(self, table_name):
        '''Return the text appended after the closing parenthesis of the table body.

        A table with overrides uses its own 'table_suffix' ('' when it has none);
        a table without overrides uses the default one. Never returns None.
        '''
        table_settings = self._table_settings(table_name)
        if not table_settings:
            return self.defaults.get('table_suffix', '') or ''
        return table_settings.get('table_suffix', '') or ''

    def get_column_suffix(self, table_name, column_name):
        '''Return the text appended to a column definition.

        A table without overrides uses the default 'column_suffix'. Otherwise the
        table-level 'column_suffix' is used, and the column's own setting only
        applies when the table-level one is empty. Never returns None, so the
        result can be formatted straight into a DDL line.
        '''
        table_settings = self._table_settings(table_name)
        if not table_settings:
            return self.defaults.get('column_suffix', '') or ''

        suffix = table_settings.get('column_suffix', '')

        column_settings = (table_settings.get('column_settings') or {}).get(column_name)
        if column_settings:
            suffix = suffix or column_settings.get('column_suffix')

        return suffix or ''

    def get_mapped_table_name(self, table_name):
        '''Return the table's 'rename_to' override, or table_name unchanged.'''
        return self._table_settings(table_name).get('rename_to', table_name)


    def should_autocreate_pk(self, table_name):
        '''Return whether a primary key should be generated for a table lacking one.

        A table-level 'autocreate_pk_if_missing' that is not None wins over the
        default; a missing default counts as False.
        '''
        setting = self.defaults.get('autocreate_pk_if_missing', False)

        table_level_setting = self._table_settings(table_name).get('autocreate_pk_if_missing')
        if table_level_setting is not None:
            setting = table_level_setting

        return setting


    def map_column_type(self, tablename, source_typename, target_typename):
        '''Register a source-to-target type mapping. tablename is not used: the map is shared by all tables.'''
        self.type_map[source_typename] = target_typename


    def get_target_column_type(self, table_name, column_name, source_column_type):
        '''Return the mapped type for source_column_type, or the source type itself
        when unmapped. table_name and column_name are not used.'''
        return self.type_map.get(source_column_type, source_column_type)


    def get_target_column_name(self, table_name, column_name):
        '''Return the column's entry in the table's 'column_name_map', or column_name unchanged.'''
        name_map = self._table_settings(table_name).get('column_name_map') or {}
        return name_map.get(column_name, column_name)


    def get_varchar_length(self, table_name, column_name):
        '''Return the length for a string column.

        Precedence, most specific first: the column's 'length' under
        'column_settings', the table's 'varchar_length', the default 'varchar_length'.
        '''
        table_settings = self._table_settings(table_name)
        varchar_length = table_settings.get('varchar_length', self.default_varchar_length)

        column_settings = (table_settings.get('column_settings') or {}).get(column_name) or {}
        return column_settings.get('length', varchar_length)


def create_tablespec_from_json_config(tablename, json_config, dbschema, creation_context, **kwargs):
    '''Build a TableSpec for one table entry of the JSON schema.

    tablename         -- source table name; used for the spec and for all context lookups
    json_config       -- the table's schema entry; its 'columns' list holds dicts with
                         'column_name' and 'column_type'
    dbschema          -- schema name recorded on the TableSpec
    creation_context  -- supplies varchar lengths, name/type mapping and the
                         primary-key policy
    kwargs            -- accepted but not used

    String-typed columns get a 'length' from the context. If no column is a
    primary key and the context asks for one, a generated PK column is placed
    first. Returns the populated TableSpec.
    '''
    tspec = TableSpec(tablename, dbschema)

    for column_config in json_config['columns']:
        column_name = column_config['column_name']
        column_type = column_config['column_type']

        settings = {}
        if is_string_type(column_type):
            settings['length'] = creation_context.get_varchar_length(tablename, column_name)

        tspec.add_column(column_name, column_type, creation_context, **settings)

    if not tspec.has_pk() and creation_context.should_autocreate_pk(tablename):
        tspec.insert_pk(tspec.generate_pk_column(creation_context))

    return tspec


def generate_sql(tablespec, table_creation_context):
    '''Render tablespec as a CREATE TABLE statement terminated by a semicolon.

    The table name in the statement is the context's mapped (possibly renamed)
    name; column lines are joined with a comma and a newline.
    '''
    field_lines = [column.sql(table_creation_context) for column in tablespec.columns]

    # Overrides are keyed by the source table name, so the suffix must be looked
    # up with tablespec.name. Using the renamed name would miss the table's own
    # entry whenever 'rename_to' is set.
    ddl_stmt = TABLE_DDL_TEMPLATE.format(schema=tablespec.schema,
                                         tablename=table_creation_context.get_mapped_table_name(tablespec.name),
                                         table_body=',\n'.join(field_lines),
                                         table_suffix=table_creation_context.get_table_suffix(tablespec.name))

    return ddl_stmt.strip() + ';'

    
def find_table_config(tablename, json_dbschema):
    '''Return the first entry of json_dbschema['tables'] whose 'table_name'
    equals tablename, or None when no entry matches.'''
    matches = (candidate for candidate in json_dbschema['tables']
               if candidate['table_name'] == tablename)
    return next(matches, None)


def _print_table_ddl(table_name, json_dbschema, schema_filename, project_schema, table_creation_context):
    '''Print the CREATE TABLE statement for one table of the schema file.

    If the schema file has no such table, a message goes to stderr and the
    process exits with status 1.
    '''
    tablecfg = find_table_config(table_name, json_dbschema)
    if not tablecfg:
        print('No table "%s" defined in schema file %s.' % (table_name, schema_filename), file=sys.stderr)
        sys.exit(1)

    tablespec = create_tablespec_from_json_config(table_name,
                                                  tablecfg,
                                                  project_schema,
                                                  table_creation_context)
    print(generate_sql(tablespec, table_creation_context))


def main(args):
    '''Run the command described by the docopt argument dictionary.

    Reads the config file and the JSON schema file, then either lists the table
    names (--list_tables) or prints DDL for one table (--table), a comma-delimited
    selection (--tables), or every table in the schema file (--all-tables).
    --schema replaces the schema name taken from the schema file.

    An unknown table name is reported on stderr and ends the process with exit
    status 1, regardless of which option selected it.
    '''
    yaml_config = common.read_config_file(args['<configfile>'])

    schema_filename = args['<schemafile>']
    with open(schema_filename) as f:
        json_dbschema = json.load(f)

    project_schema = json_dbschema['schema_name']
    tablenames = [entry['table_name'] for entry in json_dbschema['tables']]

    if args.get('--list_tables'):
        print('\n'.join(tablenames))
        return

    if args.get('--schema'):
        project_schema = args['<db_schema>']
    table_creation_context = TableCreationContext(yaml_config)

    if args.get('--table'):
        selected_table_names = [args['<tablename>']]
        separate_statements = False
    elif args.get('--tables'):
        selected_table_names = [t.strip() for t in args['<comma_delim_tbl_list>'].split(',')]
        separate_statements = True
    elif args.get('--all-tables'):
        # generate all the tables specified in the metadata file
        selected_table_names = tablenames
        separate_statements = True
    else:
        return

    for table_name in selected_table_names:
        _print_table_ddl(table_name,
                         json_dbschema,
                         schema_filename,
                         project_schema,
                         table_creation_context)
        if separate_statements:
            # Multi-table output keeps blank lines after each statement.
            print('\n')


if __name__ == '__main__':
    # Script entry point: parse the command line against the usage text in the
    # module docstring and hand the resulting dictionary to main().
    main(docopt.docopt(__doc__))
