#!/usr/bin/env python
# encoding: utf-8
"""
tempo estimator
"""

import argparse
import sys
from os.path import splitext, dirname, basename, join

import jams
import librosa

from tempocnn.classifier import TempoClassifier
from tempocnn.feature import read_features
from tempocnn.version import __version__ as package_version


def _build_parser():
    """Create the argument parser for the ``tempo`` command line program."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter, description='''
    The program 'tempo' estimates a global tempo for a given file.
    The underlying algorithm is described in detail in:

    Hendrik Schreiber, Meinard Müller,
    "A single-step approach to musical tempo estimation using a
    convolutional neural network"
    Proceedings of the 19th International Society for Music Information
    Retrieval Conference (ISMIR), Paris, France, Sept. 2018.

    License: GNU Affero General Public License v3
    ''')

    parser.add_argument('-v', '--version',
                        action='version',
                        version=f'tempo {package_version}')
    parser.add_argument('--interpolate',
                        help='interpolate between tempo classes',
                        action='store_true')
    # const makes a bare "-m" (no value) fall back to the default model
    # instead of yielding None.
    parser.add_argument('-m', '--model',
                        nargs='?',
                        const='fcn',
                        default='fcn',
                        help='model name [ismir2018|fma2018|cnn|fcn|mazurka|deeptemp|deepsquare|shallowtemp], defaults to fcn')

    # MIREX and JAMS are alternative output formats.
    output_format = parser.add_mutually_exclusive_group()
    output_format.add_argument('--mirex',
                               help='use MIREX format for output',
                               action="store_true")
    output_format.add_argument('--jams',
                               help='use JAMS format for output',
                               action="store_true")
    parser.add_argument('-i', '--input',
                        nargs='+',
                        help='input audio file(s) to process')

    # Explicit output files and an output directory exclude each other.
    output_location = parser.add_mutually_exclusive_group()
    output_location.add_argument('-o', '--output',
                                 nargs='*',
                                 help='output file(s)')
    output_location.add_argument('-d', '--outputdir',
                                 help='output directory')
    # --extension and --cont are unrelated, so they are ordinary options
    # and may be combined freely.
    parser.add_argument('-e', '--extension',
                        help='append given extension to original file name for results')
    parser.add_argument('-c', '--cont',
                        help='continue after error, if multiple files are processed',
                        action='store_true')
    return parser


def _create_jams(input_file, model, t1, t2, s1):
    """Build a JAMS object holding a tempo annotation for ``input_file``.

    ``t1``, ``t2`` and ``s1`` are the three values returned by
    ``classifier.estimate_mirex``. ``t1`` is stored with confidence ``s1``
    and ``t2`` with confidence ``1 - s1``.
    """
    # The audio is loaded only to determine the track duration.
    y, sr = librosa.load(input_file)
    track_duration = librosa.get_duration(y=y, sr=sr)

    result = jams.JAMS()
    result.file_metadata.duration = track_duration
    result.file_metadata.identifiers = {'file': basename(input_file)}

    tempo_a = jams.Annotation(namespace='tempo', time=0, duration=track_duration)
    tempo_a.annotation_metadata = jams.AnnotationMetadata(
        version='0.0.4',
        annotation_tools=f'schreiber tempo-cnn (model={model}), https://github.com/hendriks73/tempo-cnn',
        data_source='Hendrik Schreiber, Meinard Müller. A Single-Step Approach to Musical Tempo Estimation Using a Convolutional Neural Network. In Proceedings of the 19th International Society for Music Information Retrieval Conference (ISMIR), Paris, France, Sept. 2018.')
    tempo_a.append(time=0.0,
                   duration=track_duration,
                   value=t1,
                   confidence=s1)
    tempo_a.append(time=0.0,
                   duration=track_duration,
                   value=t2,
                   confidence=(1 - s1))
    result.annotations.append(tempo_a)
    return result


def _output_path(args, index, input_file):
    """Return the output file for ``input_file`` or ``None`` for stdout.

    Precedence: JAMS output (input base name + ``.jams``), then
    ``--extension`` (input file name + extension), then the ``index``-th
    entry of ``--output``. ``--outputdir`` replaces the input file's
    directory for the first two cases.
    """
    file_dir = dirname(input_file) if args.outputdir is None else args.outputdir
    file_name = basename(input_file)
    if args.jams:
        base, _ = splitext(file_name)
        return join(file_dir, base + '.jams')
    if args.extension is not None:
        return join(file_dir, file_name + args.extension)
    if args.output is not None and index < len(args.output):
        return args.output[index]
    return None


def _write_result(args, result, output_file):
    """Save a JAMS result, or print/write a textual result."""
    if args.jams:
        result.save(output_file)
    elif output_file is None:
        print('\n' + result)
    else:
        with open(output_file, mode='w', encoding='utf-8') as out:
            out.write(result + '\n')


def main():
    """Run the ``tempo`` command line program.

    Parses the command line, loads the requested model and estimates the
    tempo of every input file. Results are printed to stdout or written to
    files, as plain text, in MIREX format or as JAMS.

    Exits with status 1 if no input files are given or if the number of
    output files does not match the number of input files. Unless
    ``--cont`` is given, the first error raised while processing a file is
    reported on stderr and re-raised.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Check for missing input first: the output count check below needs
    # len(args.input) and would otherwise fail with a TypeError.
    if args.input is None:
        print('No input files given.', file=sys.stderr)
        parser.print_help(file=sys.stderr)
        sys.exit(1)

    # An empty "-o" is tolerated; results then go to stdout.
    if args.output is not None and 0 < len(args.output) != len(args.input):
        print(f'Number of input files ({len(args.input)}) must match number of output files ({len(args.output)}).'
              f'\nInput={args.input}\nOutput={args.output}', file=sys.stderr)
        parser.print_help(file=sys.stderr)
        sys.exit(1)

    # load model
    print(f"Loading model '{args.model}'...")
    classifier = TempoClassifier(args.model)
    print(f'Loaded model with {classifier.model.count_params()} parameters.')

    print('Processing file(s)', end='', flush=True)
    for index, input_file in enumerate(args.input):
        try:
            # one progress dot per file
            print('.', end='', flush=True)
            features = read_features(input_file)

            if args.mirex or args.jams:
                t1, t2, s1 = classifier.estimate_mirex(features, interpolate=args.interpolate)
                if args.mirex:
                    result = f'{t1}\t{t2}\t{s1}'
                else:
                    result = _create_jams(input_file, args.model, t1, t2, s1)
            else:
                tempo = classifier.estimate_tempo(features, interpolate=args.interpolate)
                result = str(tempo)

            output_file = _output_path(args, index, input_file)
            _write_result(args, result, output_file)
        except Exception as e:
            if not args.cont:
                print(f"\nAn error occurred while processing '{input_file}':\n{e}\n", file=sys.stderr)
                # bare raise keeps the original traceback
                raise
            # best effort: mark the failed file and continue with the next one
            print(f'E({input_file})', end='', flush=True)
    print('\nDone')


# Run the command line program only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
