#!/usr/bin/env python

'''
Usage:  
    j2spectrum --config <configfile> --source <schemafile> --schema <ext_schema> --table <tablename>
    j2spectrum --config <configfile> --source <schemafile> --schema <ext_schema> --tables <comma_delim_tbl_list>
    j2spectrum --config <configfile> --source <schemafile> --schema <ext_schema> --all-tables
    j2spectrum --config <configfile> --source <schemafile> --list-tables
'''

import os, sys
import copy
import json
from snap import common
import docopt

# Names of column compression encodings.
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably kept for validating encoding settings -- confirm before removing.
COLUMN_ENCODING_TYPES = [
    'delta',
    'delta32k',
    'raw',
    'text255',
    'bytedict'
]

# Names of table distribution styles.
# NOTE(review): also unreferenced in the visible part of this file.
DIST_STYLES = [
    'even',
    'key',
    'all',
    'auto(all)',
    'auto(even)'
]

# Skeleton of the generated statement. generate_sql() fills in:
#   {schema}     - external schema name
#   {tablename}  - target (possibly renamed) table name
#   {table_body} - comma/newline-separated column definitions
#   {properties} - ROW FORMAT / STORED AS clauses (see the two templates below)
#   {s3_uri}     - base S3 location; the data path is <s3_uri>/<schema>/<tablename>/
TABLE_DDL_TEMPLATE = '''
CREATE EXTERNAL TABLE {schema}.{tablename} (
{table_body}
) 
{properties}
LOCATION '{s3_uri}/{schema}/{tablename}/'
'''

# TODO: add a TABLE PROPERTIES clause to TABLE_DDL_TEMPLATE, e.g.
# TABLE PROPERTIES ('{table_props}')

# Single-clause templates joined by
# SpectrumTableCreationContext.generate_properties_clause().
ROW_FORMAT_CLAUSE = 'ROW FORMAT {row_format}'
STORAGE_CLAUSE = 'STORED AS {storage_type}'


# Setting names recognised at table level.
# NOTE(review): not referenced in the visible part of this file.
TABLE_LEVEL_SETTINGS = [
    'varchar_length'
]

# Setting names recognised at column level.
# NOTE(review): not referenced in the visible part of this file; note that
# SpectrumTableCreationContext.get_varchar_length() reads the per-column key
# 'length', not 'varchar_length'.
COLUMN_LEVEL_SETTINGS = [
    'varchar_length'
]

# Lower-case type names that is_string_type() treats as string types
# (these receive a length clause when the table spec is built).
RS_STRING_TYPES = [
    'varchar'
]

def is_string_type(type_name):
    """Tell whether `type_name` is one of the types in RS_STRING_TYPES.

    The comparison is case-insensitive: the name is lower-cased before the
    membership test.
    """
    return type_name.lower() in RS_STRING_TYPES

class DuplicateColumn(Exception):
    """Signals that a column with the given name is already part of a TableSpec."""

    def __init__(self, column_name):
        message = 'TableSpec already contains column "%s".' % column_name
        super(DuplicateColumn, self).__init__(message)


class ColumnSpec(object):
    """One column of a table, able to render itself as a DDL column definition.

    Attributes:
        tablename: name of the owning table (used for per-table lookups in the
            creation context).
        name: column name.
        datatype: column type name.
        nullable: when False, a NOT NULL clause is emitted.
        length: optional size for the type; falsy means no "(n)" clause.
        is_primary_key: when True, PRIMARY KEY is appended.
        default_value: optional default; None means no DEFAULT clause.
    """

    def __init__(self, table_name, column_name, column_type, is_nullable=True, is_pk=False, **kwargs):
        """Create a column spec.

        Recognised keyword arguments are `length` and `default`; any other
        keyword is accepted and ignored.
        """
        self.tablename = table_name
        self.name = column_name
        self.datatype = column_type
        self.nullable = is_nullable
        # Can be empty; usually only applies to string types.
        self.length = kwargs.get('length', '')

        self.is_primary_key = is_pk
        self.default_value = kwargs.get('default')

    def sql(self, table_creation_context):
        """Render this column as a single DDL column-definition string.

        The creation context supplies the target column name, the target
        column type and an optional column suffix. Clauses are emitted in the
        order: name, type(length), NOT NULL, DEFAULT, suffix, PRIMARY KEY.
        Empty clauses are skipped, so the result never contains doubled
        spaces.

        Args:
            table_creation_context: object providing get_target_column_name,
                get_target_column_type and get_column_suffix.

        Returns:
            The column definition, without a trailing comma.
        """
        ctx = table_creation_context
        target_name = ctx.get_target_column_name(self.tablename, self.name)
        # The column name must be passed here in both the DEFAULT and the
        # non-DEFAULT case; the context method takes (table, column, type).
        target_type = ctx.get_target_column_type(self.tablename, self.name, self.datatype)
        suffix = ctx.get_column_suffix(self.tablename, self.name)

        length_clause = '(%s)' % self.length if self.length else ''
        pieces = ['%s' % target_name, '%s%s' % (target_type, length_clause)]

        if not self.nullable:
            pieces.append('NOT NULL')
        # A default of 0 or '' is still a default; only None means "absent".
        if self.default_value is not None:
            pieces.append('DEFAULT %s' % self.default_value)
        # Skip a missing suffix instead of rendering it as text.
        if suffix:
            pieces.append('%s' % suffix)
        if self.is_primary_key:
            pieces.append('PRIMARY KEY')

        return ' '.join(pieces).strip()


class TableSpec(object):
    """An ordered collection of ColumnSpec objects describing one table.

    Attributes:
        name: table name as given by the caller.
        schema: name of the schema the table belongs to.
        columns: list of ColumnSpec objects, in output order.
    """

    def __init__(self, table_name, schema_name, **kwargs):
        """Create an empty table spec. Extra keyword arguments are ignored."""
        self.name = table_name
        self.schema = schema_name
        self.columns = []

    def has_pk(self):
        """Return True if any column is flagged as a primary key."""
        return any(column.is_primary_key for column in self.columns)

    def generate_pk_column(self, table_creation_context):
        """Build (but do not add) a non-nullable primary-key ColumnSpec.

        Reads `pk_name`, `pk_type` and `pk_value` from the context.
        NOTE(review): SpectrumTableCreationContext in this file does not
        define those attributes, so this only works with a context that does.
        """
        return ColumnSpec(self.name,
                          table_creation_context.pk_name,
                          table_creation_context.pk_type,
                          False,
                          True,
                          default=table_creation_context.pk_value)

    def insert_pk(self, columnspec):
        """Place `columnspec` at the front of the column list."""
        self.columns.insert(0, columnspec)

    def add_columnspec(self, columnspec):
        """Append an already-built ColumnSpec without a duplicate check."""
        self.columns.append(columnspec)

    def add_column(self, source_column_name, source_column_type, creation_context, **kwargs):
        """Append a column built from a source name and type.

        The name and type are first mapped through the creation context, and
        the mapped values are what the new ColumnSpec stores.

        Keyword arguments:
            is_nullable: whether the column accepts NULL (default True).
            is_pk: whether the column is a primary key (default False).
            Anything else (e.g. `length`, `default`) is forwarded to
            ColumnSpec.

        Raises:
            DuplicateColumn: if a column with the mapped name already exists.
        """
        target_column_type = creation_context.get_target_column_type(self.name, source_column_name, source_column_type)
        target_column_name = creation_context.get_target_column_name(self.name, source_column_name)

        if self.get_column(target_column_name):
            raise DuplicateColumn(target_column_name)

        # Pull the positional ColumnSpec arguments out of the forwarded
        # keywords; passing them both ways raises "got multiple values for
        # argument" in ColumnSpec.__init__.
        column_kwargs = dict(kwargs)
        nullable = column_kwargs.pop('is_nullable', True)
        is_pk = column_kwargs.pop('is_pk', False)
        self.columns.append(ColumnSpec(self.name,
                                       target_column_name,
                                       target_column_type,
                                       nullable,
                                       is_pk,
                                       **column_kwargs))

    def get_column(self, column_name):
        """Return the first column whose name equals `column_name`, else None."""
        for column in self.columns:
            if column.name == column_name:
                return column
        return None

    def remove_column(self, column_name):
        """Remove the named column; return True if one was removed."""
        column = self.get_column(column_name)
        if column:
            self.columns.remove(column)
            return True
        return False


class SpectrumTableCreationContext(object):
    """Settings that drive external-table DDL generation.

    Built from a config mapping with two sections:

    * `spectrum_defaults` (required): must contain `varchar_length`,
      `s3_uri`, `row_format` and `storage_type`; `delimiter` is required for
      the "lines" and "fields" row formats and `serde_properties` for
      "serde".
    * `tables` (optional): per-table settings. Each table's effective
      settings are a deep copy of the defaults updated with that table's
      entries, stored in `self.overrides`.

    Per-table keys read by this class: `table_suffix`, `column_suffix`,
    `rename_to`, `column_name_map`, `varchar_length` and `column_settings`
    (a mapping of column name to a dict that may hold `column_suffix` and
    `length`).
    """

    def __init__(self, yaml_config):
        """Validate the defaults and precompute the row-format clause.

        Raises:
            KeyError: if a required default is missing.
            Exception: if the row format is unsupported or lacks its
                companion setting.
        """
        self.defaults = yaml_config['spectrum_defaults']
        self.default_varchar_length = self.defaults['varchar_length']
        self.s3_uri = self.defaults['s3_uri']

        row_format = self.defaults['row_format']

        if row_format in ('lines', 'fields'):
            if self.defaults.get('delimiter') is None:
                raise Exception('Specifying a row format of "lines" or "fields" requires you to set the "delimiter" field.')

        if row_format == 'serde':
            if not self.defaults.get('serde_properties'):
                raise Exception('Specifying a row format of "serde" requires you to set the "serde_properties" field.')

            # NOTE(review): the emitted clause carries no SerDe class name and
            # does not quote keys/values -- check against the Redshift
            # CREATE EXTERNAL TABLE grammar before relying on it.
            serde_props = ['%s=%s' % (key, value)
                           for key, value in self.defaults['serde_properties'].items()]
            self.row_format = 'SERDE WITH SERDEPROPERTIES(%s)' % ','.join(serde_props)

        elif row_format == 'lines':
            self.row_format = "DELIMITED LINES TERMINATED BY '%s'" % self.defaults['delimiter']

        elif row_format == 'fields':
            self.row_format = "DELIMITED FIELDS TERMINATED BY '%s'" % self.defaults['delimiter']
        else:
            raise Exception('Unsupported row format %s.' % row_format)

        self.storage_type = self.defaults['storage_type']

        # table name -> effective settings (defaults overlaid with overrides)
        self.overrides = {}
        for tablename, settings_dict in yaml_config.get('tables', {}).items():
            self.overrides[tablename] = copy.deepcopy(self.defaults)
            self.overrides[tablename].update(settings_dict)

    def _table_settings(self, table_name):
        """Return the effective settings for `table_name`, or None if it has no overrides."""
        return self.overrides.get(table_name) or None

    def generate_properties_clause(self, table_name):
        """Return the ROW FORMAT and STORED AS clauses, one per line.

        `table_name` is currently unused: the clauses come from the defaults.
        """
        # TODO: allow per-table property overrides of default settings
        props = [
            ROW_FORMAT_CLAUSE.format(row_format=self.row_format),
            STORAGE_CLAUSE.format(storage_type=self.storage_type),
        ]
        return '\n'.join(props)

    def get_table_suffix(self, table_name):
        """Return the `table_suffix` setting for the table ('' if unset)."""
        settings = self._table_settings(table_name) or self.defaults
        return settings.get('table_suffix', '')

    def get_column_suffix(self, table_name, column_name):
        """Return the text appended to a column definition ('' if none).

        A non-empty table-level `column_suffix` takes precedence; the
        per-column `column_suffix` is only used when the table-level one is
        empty. Always returns a string, never None.
        NOTE(review): this precedence is the opposite of get_varchar_length,
        where the per-column value wins -- confirm it is intended.
        """
        settings = self._table_settings(table_name)
        if settings is None:
            return self.defaults.get('column_suffix') or ''

        suffix = settings.get('column_suffix', '')

        column_settings = (settings.get('column_settings') or {}).get(column_name)
        if column_settings:
            suffix = suffix or column_settings.get('column_suffix', '')

        # Guard against a missing/null value leaking into the DDL as "None".
        return suffix or ''

    def get_mapped_table_name(self, table_name):
        """Return the table's `rename_to` setting, or `table_name` if unset."""
        settings = self._table_settings(table_name)
        if settings is None:
            return table_name
        return settings.get('rename_to', table_name)

    def get_target_column_type(self, table_name, column_name, source_column_type):
        """Return the target type for a column; currently the source type unchanged."""
        return source_column_type

    def get_target_column_name(self, table_name, column_name):
        """Return the column's entry in the table's `column_name_map`, or the name itself."""
        settings = self._table_settings(table_name)
        if settings is None:
            return column_name

        name_map = settings.get('column_name_map')
        if name_map:
            return name_map.get(column_name, column_name)
        return column_name

    def get_varchar_length(self, table_name, column_name):
        """Return the varchar length for a column.

        Resolution order, most specific first: the column's `length` under
        the table's `column_settings`, the table's `varchar_length`, then the
        default `varchar_length`.
        """
        varchar_length = self.default_varchar_length

        settings = self._table_settings(table_name)
        if settings is None:
            return varchar_length

        varchar_length = settings.get('varchar_length', varchar_length)

        column_settings = (settings.get('column_settings') or {}).get(column_name)
        if column_settings:
            varchar_length = column_settings.get('length', varchar_length)

        return varchar_length


def create_tablespec_from_json_config(tablename, json_config, dbschema, creation_context, **kwargs):
    """Build a TableSpec from one table entry of the JSON schema file.

    Args:
        tablename: source table name; used both as the TableSpec name and as
            the key for per-table lookups in the creation context.
        json_config: table entry; must contain a `columns` list whose items
            each have `column_name` and `column_type`.
        dbschema: schema name stored on the TableSpec.
        creation_context: supplies name/type mapping and varchar lengths.
        **kwargs: accepted for call compatibility; unused.

    Returns:
        A TableSpec with one column per entry in `json_config['columns']`.
        String-typed columns (per is_string_type) get a `length` taken from
        the creation context.

    Raises:
        DuplicateColumn: if two entries map to the same target column name.
    """
    # The spec keeps the *source* table name; renaming to the target name is
    # done later by generate_sql().
    tspec = TableSpec(tablename, dbschema)

    for column_config in json_config['columns']:
        column_name = column_config['column_name']
        column_type = column_config['column_type']

        settings = {}
        if is_string_type(column_type):
            settings['length'] = creation_context.get_varchar_length(tablename, column_name)

        tspec.add_column(column_name, column_type, creation_context, **settings)

    return tspec


def generate_sql(tablespec, table_creation_context):
    """Render the complete DDL statement for `tablespec`.

    Each column renders itself through the creation context; the results are
    joined with ",\\n" and substituted into TABLE_DDL_TEMPLATE together with
    the schema, the mapped table name, the S3 base URI and the properties
    clause. The statement is stripped of surrounding whitespace and
    terminated with a semicolon.
    """
    ctx = table_creation_context
    column_defs = [column.sql(ctx) for column in tablespec.columns]
    target_name = ctx.get_mapped_table_name(tablespec.name)

    statement = TABLE_DDL_TEMPLATE.format(
        schema=tablespec.schema,
        tablename=target_name,
        table_body=',\n'.join(column_defs),
        s3_uri=ctx.s3_uri,
        properties=ctx.generate_properties_clause(target_name))

    return statement.strip() + ';'

    
def find_table_config(tablename, json_dbschema):
    """Look up a table entry by name in the parsed schema file.

    Returns the first item of `json_dbschema['tables']` whose `table_name`
    equals `tablename`, or None when there is no such item.
    """
    matches = (candidate for candidate in json_dbschema['tables']
               if candidate['table_name'] == tablename)
    return next(matches, None)


def _render_table_ddl(table_name, json_dbschema, ext_schema, creation_context):
    """Return the DDL for `table_name`, or None if the schema file does not define it."""
    tablecfg = find_table_config(table_name, json_dbschema)
    if not tablecfg:
        return None
    tablespec = create_tablespec_from_json_config(table_name,
                                                  tablecfg,
                                                  ext_schema,
                                                  creation_context)
    return generate_sql(tablespec, creation_context)


def main(args):
    """Run the command described by the docopt argument dict `args`.

    Modes:
        --list-tables  print the table names found in the schema file.
        --table        print the DDL for one table.
        --tables       print the DDL for each table in a comma-separated
                       list, each followed by blank lines.
        --all-tables   same, for every table in the schema file.

    If a requested table is not defined in the schema file, a message is
    printed and processing stops.
    """
    yaml_config = common.read_config_file(args['<configfile>'])

    schema_filename = args['<schemafile>']
    with open(schema_filename) as f:
        json_dbschema = json.load(f)

    tablenames = [entry['table_name'] for entry in json_dbschema['tables']]

    if args.get('--list-tables'):
        print('\n'.join(tablenames))
        return

    # Generated tables always go into the external schema given on the
    # command line; the schema file's own schema name is not used.
    ext_schema = args['<ext_schema>']
    table_creation_context = SpectrumTableCreationContext(yaml_config)

    if args.get('--table'):
        selected_table_names = [args['<tablename>']]
        separate_statements = False
    elif args.get('--tables'):
        # Ignore blank entries such as the one produced by a trailing comma.
        selected_table_names = [name.strip()
                                for name in args['<comma_delim_tbl_list>'].split(',')
                                if name.strip()]
        separate_statements = True
    elif args.get('--all-tables'):
        selected_table_names = tablenames
        separate_statements = True
    else:
        return

    for table_name in selected_table_names:
        ddl = _render_table_ddl(table_name, json_dbschema, ext_schema, table_creation_context)
        if ddl is None:
            print('No table "%s" defined in schema file %s.' % (table_name, schema_filename))
            return

        print(ddl)
        if separate_statements:
            print('\n')


# Script entry point: docopt parses the command line against the usage text
# in the module docstring, and the resulting dict is handed to main().
if __name__ == '__main__':
    args = docopt.docopt(__doc__)
    main(args)
