#!/usr/bin/env python

'''Usage:
            seesv --xform=<transform_file> --xmap=<transform_map>  <datafile>
            seesv (-t | -f) --schema=<schema_file> --rtype=<record_type> <datafile>
            seesv -i

   Options:
            -t --test          Test the records in the target file for schema compliance
            -f --filter        Send the compliant records in the target file to stdout
            -i --interactive   Start up in interactive mode
            -l --lookup        Run seesv in lookup mode; attempt to look up missing data
'''

#
# seesv: command line utility for examining CSV files and evaluating them
# against a schema of required fields
#


import docopt
from docopt import docopt as docopt_func
from docopt import DocoptExit
import os, sys
from snap import snap, common
from mercury import datamap as dmap
import yaml
import logging


class TransformProcessor(dmap.DataProcessor):
    '''Pipeline stage that hands each record to a record transformer.'''

    def __init__(self, transformer, data_processor):
        '''Set up the stage.

        transformer    -- object exposing a transform(record_dict) method
        data_processor -- processor passed to the dmap.DataProcessor initializer
        '''
        dmap.DataProcessor.__init__(self, data_processor)
        self._records = []
        self._transformer = transformer


    def _process(self, data_dict):
        '''Return whatever the transformer produces for data_dict.'''
        return self._transformer.transform(data_dict)



class Dictionary2CSVProcessor(dmap.DataProcessor):
    '''Pipeline stage that prints records to stdout as delimited text lines.'''

    def __init__(self, header_fields, delimiter, data_processor=None, **kwargs):
        '''Set up the stage.

        header_fields  -- ordered field names; used for the header line and
                          for the column order of every data line
        delimiter      -- string placed between column values
        data_processor -- processor passed to the dmap.DataProcessor initializer
        kwargs         -- accepted but not used
        '''
        dmap.DataProcessor.__init__(self, data_processor)
        self._header_fields = header_fields
        self._delimiter = delimiter
        self._record_count = 0


    def _process(self, data_dict):
        '''Print one line for this call and return data_dict unchanged.

        The first call prints the header line INSTEAD of the record it was
        given; later calls print the record's values in header order, with
        None rendered as an empty string.
        '''
        # NOTE(review): the first record's values are never printed. That is
        # presumably because the extractor delivers the file's own header row
        # as the first record -- confirm against dmap.CSVFileDataExtractor.
        if self._record_count == 0:
            line = self._delimiter.join(self._header_fields)
        else:
            values = [data_dict.get(name) for name in self._header_fields]
            line = self._delimiter.join('' if value is None else str(value)
                                        for value in values)
        print(line)

        self._record_count += 1
        return data_dict
        


def build_transformer(map_file_path, mapname):
    '''Build a record transformer for the map `mapname` in `map_file_path`.

    Returns the result of dmap.RecordTransformerBuilder(...).build().
    '''
    return dmap.RecordTransformerBuilder(map_file_path, map_name=mapname).build()


def transform_data(source_datafile, src_header_fields, target_header_fields, transformer):
    '''Transform the records of a pipe-delimited file and print them to stdout.

    source_datafile      -- path handed to the extractor's extract() method
    src_header_fields    -- field names given to the CSV extractor
    target_header_fields -- field names of the printed output, in column order
    transformer          -- object exposing transform(record_dict)

    The chain is assembled from three stages: whitespace cleanup, transform,
    and dictionary-to-CSV printing.
    '''
    csv_settings = dict(delimiter='|', quotechar='"', header_fields=src_header_fields)

    cleanup_stage = dmap.WhitespaceCleanupProcessor()
    transform_stage = TransformProcessor(transformer, cleanup_stage)
    output_stage = Dictionary2CSVProcessor(target_header_fields,
                                           csv_settings['delimiter'],
                                           transform_stage)

    dmap.CSVFileDataExtractor(output_stage, **csv_settings).extract(source_datafile)


class FilteringDisplayProcessor(dmap.DataProcessor):
    '''Pipeline stage that prints only the records matching a compliance filter.

    A record is "good" when every required field is present and is neither
    None nor the empty string. Depending on should_pass_good_records, either
    the good or the bad records are printed as delimited text.
    '''

    def __init__(self, required_record_fields, header_fields, delimiter, should_pass_good_records, data_processor=None):
        '''Set up the stage.

        required_record_fields   -- iterable of field names that must be populated
        header_fields            -- ordered field names for the printed output
        delimiter                -- string placed between column values
        should_pass_good_records -- True to print good records, False to print bad ones
        data_processor           -- processor passed to the dmap.DataProcessor initializer
        '''
        dmap.DataProcessor.__init__(self, data_processor)

        self._required_fields = required_record_fields
        self._pass_good = should_pass_good_records
        # The printer is used standalone. data_processor is already attached to
        # this stage through the base initializer; the former `processor=`
        # keyword here was swallowed by Dictionary2CSVProcessor's **kwargs and
        # never had any effect, so it has been dropped.
        self._dict2csvproc = Dictionary2CSVProcessor(header_fields, delimiter)


    def _process(self, record_dict):
        '''Print record_dict if it matches the filter; always return it.'''

        def is_missing(value):
            return value is None or str(value) == ''

        record_is_good = not any(is_missing(record_dict.get(field))
                                 for field in self._required_fields)

        if self._pass_good == record_is_good:
            self._dict2csvproc.process(record_dict)

        # TODO: the filter only affects what is PRINTED. Every record, matching
        # or not, is still returned from this method, so this stage cannot be
        # used to feed a filtered stream to further processing until that is
        # fixed. --binarymachineshop@gmail.com
        return record_dict



class ComplianceStatsProcessor(dmap.DataProcessor):
    '''Pipeline stage that counts records passing or failing a required-field check.

    A record is invalid when a required field is missing, None or the empty
    string, or when match_format() rejects its value. Only the first failing
    field of each record is recorded.
    '''

    def __init__(self, required_record_fields, processor=None):
        '''Set up the stage.

        required_record_fields -- dict mapping required field name -> type name
        processor              -- processor passed to the dmap.DataProcessor initializer
        '''
        dmap.DataProcessor.__init__(self, processor)
        self._required_fields = required_record_fields
        self._valid_record_count = 0
        self._invalid_record_count = 0
        # 1-based record index -> (field name, 'null' | 'invalid_type')
        self._error_table = {}
        self._record_index = 0


    @property
    def total_records(self):
        '''Number of records seen so far (valid plus invalid).'''
        return self._valid_record_count + self._invalid_record_count


    @property
    def valid_records(self):
        '''Number of records that passed every check.'''
        return self._valid_record_count


    @property
    def invalid_records(self):
        '''Number of records that failed a check.'''
        return self._invalid_record_count


    def match_format(self, obj, type_name='string'):
        '''Report whether obj is acceptable for the given type name.

        Currently a stub that accepts everything.
        '''
        # TODO: use a lookup table of regexes
        return True


    def _process(self, record_dict):
        '''Classify record_dict as valid or invalid and return it unchanged.'''
        error = False
        self._record_index += 1
        # .items() rather than .iteritems(): the latter exists only on
        # Python 2 dicts, while .items() works on both Python 2 and 3.
        for name, datatype in self._required_fields.items():
            value = record_dict.get(name)
            if value is None or value == '':
                error = True
                self._error_table[self._record_index] = (name, 'null')
                break
            elif not self.match_format(record_dict[name]):
                error = True
                self._error_table[self._record_index] = (name, 'invalid_type')
                break

        if error:
            self._invalid_record_count += 1
        else:
            self._valid_record_count += 1

        return record_dict


    def get_stats(self):
        '''Return the counters and the per-record error table as a dict.'''
        validity_stats = {
        'invalid_records': self.invalid_records,
        'valid_records': self.valid_records,
        'total_records': self.total_records,
        'errors_by_record': self._error_table
        }
        return validity_stats


class RecordRepairProcessor(dmap.DataProcessor):
    '''Pipeline stage that fills in empty required fields from a data supplier.'''

    def __init__(self, required_record_fields, data_supplier, processor=None):
        '''Set up the stage.

        required_record_fields -- collection of field names eligible for repair
        data_supplier          -- object exposing supply(field_name, record)
        processor              -- processor passed to the dmap.DataProcessor initializer
        '''
        dmap.DataProcessor.__init__(self, processor)
        self._required_record_fields = required_record_fields
        self._data_supplier = data_supplier


    def _process(self, record):
        '''Return a copy of record with empty required fields repaired.

        Only fields already present in record are examined. A required field
        whose value is None or the empty string is replaced by the supplier's
        value; if the supplier returns None the original value is kept, so an
        empty string is never turned into None.
        '''
        output_record = {}
        # .items() rather than .iteritems(): the latter exists only on
        # Python 2 dicts, while .items() works on both Python 2 and 3.
        for field_name, field_value in record.items():
            is_required = field_name in self._required_record_fields
            is_empty = field_value is None or str(field_value) == ''
            if is_required and is_empty:
                supplied = self._data_supplier.supply(field_name, record)
                if supplied is not None:
                    field_value = supplied
            output_record[field_name] = field_value
        return output_record
                
            

def get_schema_compliance_stats(source_datafile, schema_config, record_type, pre_processor):
    '''Run a compliance check over source_datafile and return the statistics.

    source_datafile -- path handed to the extractor's extract() method
    schema_config   -- dict mapping field name -> dict with 'required' and 'type'
    record_type     -- accepted but not used by this function
    pre_processor   -- processor chained in front of the compliance stage

    Returns the dict produced by ComplianceStatsProcessor.get_stats().
    '''
    # `== True` is kept deliberately so exactly the same schema entries are
    # selected as before (plain truthiness would also accept e.g. 'yes').
    required_fields_dict = {
        field_name: schema_config[field_name]['type']
        for field_name in schema_config
        if schema_config[field_name]['required'] == True
    }

    stats_stage = ComplianceStatsProcessor(required_fields_dict, processor=pre_processor)
    csv_reader = dmap.CSVFileDataExtractor(stats_stage,
                                           delimiter='|',
                                           quotechar='"',
                                           header_fields=required_fields_dict.keys())
    csv_reader.extract(source_datafile)
    return stats_stage.get_stats()


def get_schema_compliance_with_lookup(source_datafile, schema_config_file, record_type, pre_processor):
    '''Check compliance while printing every processed record to stdout.

    source_datafile    -- path handed to the extractor's extract() method
    schema_config_file -- YAML file with a 'record_types' section
    record_type        -- key of the record type inside 'record_types'
    pre_processor      -- processor chained in front of the compliance stage

    Returns the dict produced by ComplianceStatsProcessor.get_stats().
    Raises KeyError if record_type is not in the schema file.
    '''
    # NOTE(review): yaml.load() without an explicit Loader can construct
    # arbitrary Python objects and is rejected by recent PyYAML releases;
    # consider yaml.safe_load() if the schema files are plain YAML.
    with open(schema_config_file) as f:
        record_config = yaml.load(f)
    schema_config = record_config['record_types'][record_type]

    # `== True` is kept deliberately to select exactly the same entries.
    required_fields_dict = {
        field_name: schema_config[field_name]['type']
        for field_name in schema_config
        if schema_config[field_name]['required'] == True
    }

    stats_stage = ComplianceStatsProcessor(required_fields_dict, processor=pre_processor)

    all_fields = get_all_fields(record_type, schema_config_file)
    output_stage = Dictionary2CSVProcessor(all_fields, '|', data_processor=stats_stage)

    csv_reader = dmap.CSVFileDataExtractor(output_stage,
                                           delimiter='|',
                                           quotechar='"',
                                           header_fields=required_fields_dict.keys())
    csv_reader.extract(source_datafile)
    return stats_stage.get_stats()


def load_data_supplier(record_type, schema_config_file, record_id_field):
    '''Instantiate the data supplier class named in the schema config file.

    record_type        -- accepted but not used by this function
    schema_config_file -- YAML file whose 'globals' section names the supplier
                          module and class and, optionally, its init params
    record_id_field    -- passed to the supplier as the 'record_id_field' keyword

    Returns supplier_class(service_object_registry, **init_params).
    '''
    # NOTE(review): yaml.load() without an explicit Loader can construct
    # arbitrary Python objects and is rejected by recent PyYAML releases;
    # consider yaml.safe_load() if the config files are plain YAML.
    with open(schema_config_file) as f:
        config = yaml.load(f)

    global_settings = config['globals']
    module_name = global_settings['supplier_module']
    class_name = global_settings['data_supplier']
    supplier_class = common.load_class(class_name, module_name)

    service_objects = snap.initialize_services(config, logging.getLogger())
    svc_object_registry = common.ServiceObjectRegistry(service_objects)

    # Configured params are applied after record_id_field, so a configured
    # param with the same name takes precedence.
    supplier_init_params = {'record_id_field': record_id_field}
    for param in global_settings.get('supplier_init_params') or []:
        supplier_init_params[param['name']] = common.load_config_var(param['value'])

    return supplier_class(svc_object_registry, **supplier_init_params)


def get_required_fields(record_type, schema_config_file):
    '''Return the names of the required fields of a record type.

    record_type        -- key inside the schema file's 'record_types' section
    schema_config_file -- path of the YAML schema file

    Raises Exception if the record type is missing or empty.
    '''
    # NOTE(review): yaml.load() without an explicit Loader can construct
    # arbitrary Python objects and is rejected by recent PyYAML releases;
    # consider yaml.safe_load() if the schema files are plain YAML.
    with open(schema_config_file) as f:
        config = yaml.load(f)

    schema_config = config['record_types'].get(record_type)
    if not schema_config:
        raise Exception('No record type "%s" found in schema config file %s.' % (record_type, schema_config_file))

    # `== True` is kept deliberately to select exactly the same entries.
    return [field_name for field_name in schema_config
            if schema_config[field_name]['required'] == True]


def get_all_fields(record_type, schema_config_file):
    '''Return the names of every field of a record type, in schema order.

    record_type        -- key inside the schema file's 'record_types' section
    schema_config_file -- path of the YAML schema file

    Raises Exception if the record type is missing or empty.
    '''
    # NOTE(review): yaml.load() without an explicit Loader can construct
    # arbitrary Python objects and is rejected by recent PyYAML releases;
    # consider yaml.safe_load() if the schema files are plain YAML.
    with open(schema_config_file) as f:
        record_config = yaml.load(f)

    schema_config = record_config['record_types'].get(record_type)
    if not schema_config:
        raise Exception('No record type "%s" found in schema config file %s.' % (record_type, schema_config_file))

    return list(schema_config)


def get_transform_target_header(transform_config_file, map_name):
    '''Return the output field names declared by a transform map.

    transform_config_file -- path of the YAML transform config file
    map_name              -- key inside the file's 'maps' section

    Raises Exception if the map is missing or empty.
    '''
    # NOTE(review): yaml.load() without an explicit Loader can construct
    # arbitrary Python objects and is rejected by recent PyYAML releases;
    # consider yaml.safe_load() if the config files are plain YAML.
    with open(transform_config_file) as f:
        transform_config = yaml.load(f)

    transform_map = transform_config['maps'].get(map_name)
    if not transform_map:
        raise Exception('No transform map "%s" found in transform config file %s.' % (map_name, transform_config_file))

    return list(transform_map['fields'])



def find_env_vars(arg_dict):
    '''Return the names of the $VARIABLES referenced in the argument values.

    arg_dict -- dict of argument name -> value, e.g. the dict docopt returns

    Each string value is split on os.path.sep; every path component that
    starts with '$' contributes its name (without the '$') to the result.
    Falsy values and non-string values (such as the booleans docopt uses for
    flags) are ignored.
    '''
    var_names = []
    for value in arg_dict.values():
        # Boolean flags and other non-strings have no .split(); skipping them
        # keeps a docopt dict containing `True` from raising AttributeError.
        if not value or not hasattr(value, 'split'):
            continue
        for token in value.split(os.path.sep):
            if token.startswith('$'):
                var_names.append(token[1:])
    return var_names
            

def resolve_initfile_args(arg_dict):
    '''Return a copy of arg_dict with each path component resolved.

    arg_dict -- dict of argument name -> value, e.g. the dict docopt returns

    Every string value is split on os.path.sep, each component is passed
    through common.load_config_var(), and the results are re-joined.
    Falsy values are omitted from the result; truthy non-string values (such
    as boolean flags) are copied through unchanged by this function.
    '''
    # presumably load_config_var expands '$NAME' references using the local
    # environment initialised below -- verify against snap.common.
    env_vars = find_env_vars(arg_dict)
    localenv = common.LocalEnvironment(*env_vars)
    localenv.init()

    output_args = {}
    # .items() rather than .iteritems(): the latter exists only on
    # Python 2 dicts, while .items() works on both Python 2 and 3.
    for key, value in arg_dict.items():
        if not value:
            continue
        if not hasattr(value, 'split'):
            # Boolean flags cannot be split into path components.
            output_args[key] = value
            continue
        cooked_value_tokens = [common.load_config_var(tok)
                               for tok in value.split(os.path.sep)]
        output_args[key] = os.path.sep.join(cooked_value_tokens)
    return output_args



def main(args):
    '''Run seesv in the mode selected by the parsed command-line arguments.

    args -- dict of option/argument name -> value, as produced by docopt (or
            by resolve_initfile_args() when re-entered from an init file)

    Modes are checked in this order and only the first match runs:
    transform, lookup, interactive, initfile, test, filter.
    '''
    logging.basicConfig(filename='seesv.log', level=logging.DEBUG)

    # Bind every mode flag up front with .get(), so that a flag missing from
    # args simply reads as "off" instead of leaving a name undefined (which
    # raised NameError for every non-transform invocation).
    transform_mode = bool(args.get('--xform'))
    lookup_mode = args.get('--lookup')
    interactive_mode = args.get('--interactive')
    initfile_mode = args.get('--initfile')
    test_mode = args.get('--test')
    filter_mode = args.get('--filter')
    src_datafile = args.get('<datafile>')

    if transform_mode:
        transform_config_file = args.get('--xform')
        transform_map = args.get('--xmap')
        schema_config_file = args.get('--schema')
        record_type = args.get('--rtype')

        src_header = get_required_fields(record_type, schema_config_file)
        target_header = get_transform_target_header(transform_config_file, transform_map)
        xformer = build_transformer(transform_config_file, transform_map)
        transform_data(src_datafile, src_header, target_header, xformer)

    elif lookup_mode:
        record_type = args.get('--rtype')
        schema_config_file = args.get('--schema')
        required_fields = get_required_fields(record_type, schema_config_file)
        data_supplier = load_data_supplier(record_type, schema_config_file, 'INDIR_SALES_KEY')
        # NOTE(review): this table listing goes to stdout ahead of the record
        # output and hard-codes the 'mssql' service -- looks like leftover
        # debugging; confirm before removing.
        print(data_supplier.get_service_object('mssql').db.list_tables())
        repair_processor = RecordRepairProcessor(required_fields, data_supplier, processor=dmap.WhitespaceCleanupProcessor())
        get_schema_compliance_with_lookup(src_datafile, schema_config_file, record_type, repair_processor)

    elif interactive_mode:
        print('PLACEHOLDER: seesv interactive mode')
        return

    elif initfile_mode:
        initfile_name = args['--initfile']
        with open(initfile_name) as f:
            cmd_line = f.readline()
        # split() with no argument drops the trailing newline and collapses
        # repeated spaces; split(' ') left '\n' glued to the last argument.
        raw_args = docopt_func(__doc__, cmd_line.split())
        cooked_args = resolve_initfile_args(raw_args)
        cooked_args['--initfile'] = None # just to make sure we don't recurse infinitely
        main(cooked_args)

    elif test_mode:
        schema_config_file = args.get('--schema')
        record_type = args.get('--rtype')
        # NOTE(review): yaml.load() without an explicit Loader can construct
        # arbitrary Python objects and is rejected by recent PyYAML releases;
        # consider yaml.safe_load() if the schema files are plain YAML.
        with open(schema_config_file) as f:
            record_config = yaml.load(f)
        schema_config = record_config['record_types'][record_type]
        print(get_schema_compliance_stats(src_datafile, schema_config, record_type, dmap.WhitespaceCleanupProcessor()))

    elif filter_mode:
        # Filter source records for schema compliance
        schema_config_file = args.get('--schema')
        record_type = args.get('--rtype')
        required_fields = get_required_fields(record_type, schema_config_file)
        header_fields = get_all_fields(record_type, schema_config_file)

        # NOTE(review): False makes the filter print the records that FAIL
        # the required-field check, whereas the usage text describes --filter
        # as emitting the compliant records -- confirm which is intended.
        filter_proc = FilteringDisplayProcessor(required_fields, header_fields, '|', False)
        extractor = dmap.CSVFileDataExtractor(filter_proc,
                                              delimiter='|',
                                              quotechar='"',
                                              header_fields=header_fields)
        extractor.extract(src_datafile)
        

if __name__ == '__main__':
    # Parse the command line against the usage text in the module docstring,
    # then dispatch to the selected mode.
    args = docopt_func(__doc__)
    main(args)
