#!/usr/bin/env python
# encoding: utf-8
"""
tempogram generator
"""

import argparse

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import ScalarFormatter

from tempocnn.classifier import TempoClassifier
from tempocnn.feature import read_features
from tempocnn.version import __version__ as package_version


# Maps the lower-cased --norm-frame value to the ``ord`` argument of
# ``np.linalg.norm``.
_NORM_ORDERS = {'max': np.inf, 'l1': 1, 'l2': 2}


def _build_parser():
    """Create the argument parser for the 'tempogram' command line tool."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter, description='''
    The program 'tempogram' estimates local tempi for a given file and displays
    their probability distributions in a graph. 
    The underlying algorithm is described in detail in:

    Hendrik Schreiber, Meinard Müller,
    "A single-step approach to musical tempo estimation using a
    convolutional neural network"
    Proceedings of the 19th International Society for Music Information
    Retrieval Conference (ISMIR), Paris, France, Sept. 2018.
    
    License: GNU Affero General Public License v3
    ''')

    parser.add_argument('-v', '--version', action='version', version=f'tempogram {package_version}')
    parser.add_argument('-p', '--png',
                        help='write the tempogram to a file, '
                             'adding the file extension .png to the input file name',
                        action="store_true")
    parser.add_argument('-c', '--csv',
                        help='write the tempogram data to a csv file, '
                             'adding the file extension .csv to the input file name',
                        action="store_true")
    parser.add_argument('-s', '--sharpen',
                        help='sharpen the image to a one-hot representation',
                        action="store_true")
    parser.add_argument('-n', '--norm-frame',
                        help='enable framewise normalization using (max|l1|l2)')
    parser.add_argument('--hop-length',
                        help='hop length between predictions, 1 hop = 0.0464399093s',
                        default=32, type=int)
    # const makes a bare '-m' (no value) fall back to the default model
    # instead of yielding None, which would break the model-name checks later.
    parser.add_argument('-m', '--model', nargs='?', const='fcn', default='fcn',
                        help='model name (ismir2018|fma2018|cnn|fcn|mazurka|deeptemp|deepsquare|shallowtemp), defaults to fcn')
    parser.add_argument('audio_file', nargs='+', help='audio file to process')
    return parser


def _normalize_frames(predictions, norm_name):
    """Divide each row of ``predictions`` by its norm.

    ``norm_name`` is matched case-insensitively against max, l1 and l2.
    Any other value prints a notice and falls back to the max norm.
    """
    norm_order = _NORM_ORDERS.get(norm_name.lower())
    if norm_order is None:
        print('Unknown norm. Using max norm.', end='', flush=True)
        norm_order = np.inf
    frame_norms = np.linalg.norm(predictions, ord=norm_order, axis=1, keepdims=True)
    return predictions / frame_norms


def _sharpen(predictions):
    """Keep only the per-row maxima (set to 1) and zero everything else."""
    scaled = predictions / np.max(predictions, axis=1, keepdims=True)
    return np.where(scaled != 1, 0, scaled)


def _csv_header(model_name, min_bpm, max_bpm, log_scale, frame_length, sharpened):
    """Build the header line written at the top of the CSV output.

    ``sharpened`` selects the wording for the single-column argmax output
    instead of the full distribution with one column per tempo class.
    """
    if sharpened:
        columns = f'argmax of tempo distribution {min_bpm}-{max_bpm} BPM (column) '
    else:
        columns = f'{min_bpm}-{max_bpm} BPM (columns) '
    return (f"Predictions using model '{model_name}' "
            f'{columns}'
            f'log_scale={log_scale} '
            f'feature frequency={1. / frame_length} Hz '
            f'i.e. {frame_length * 1000.} ms/feature (rows)')


def main():
    """Run the 'tempogram' command line tool.

    Parses the command line, loads the requested model and, for every given
    audio file, estimates the local tempo distributions. The distributions
    are optionally normalized per frame and/or sharpened, then plotted.
    Depending on the flags, the data is also written to ``<file>.csv`` and
    the plot is either saved to ``<file>.png`` or shown in a window.
    """
    parser = _build_parser()
    args = parser.parse_args()

    hop_length = args.hop_length
    if hop_length < 1:
        # A zero hop length would lead to a division by zero further down;
        # negative values make no sense either.
        parser.error('--hop-length must be a positive integer')

    # load model
    print(f"Loading model '{args.model}'...")
    classifier = TempoClassifier(args.model)
    print(f'Loaded model with {classifier.model.count_params()} parameters.')

    # fft_hop_length / sr is the duration of one hop in seconds
    # (512 / 11025 = 0.0464399093s, as stated in the --hop-length help).
    sr = 11025.0
    fft_hop_length = 512.0
    seconds_per_hop = fft_hop_length / sr

    # models with 'fma' in their name are displayed on a logarithmic BPM axis
    log_scale = 'fma' in args.model
    if log_scale:
        min_bpm, max_bpm, max_ylim = 50, 500, 510
    else:
        min_bpm, max_bpm, max_ylim = 30, 286, 300

    print('Processing file(s)', end='', flush=True)
    for audio_file in args.audio_file:
        print('.', end='', flush=True)
        features = read_features(audio_file, hop_length=hop_length, zero_pad=True)
        predictions = classifier.estimate(features)

        if args.norm_frame is not None:
            predictions = _normalize_frames(predictions, args.norm_frame)

        if args.sharpen:
            predictions = _sharpen(predictions)

        max_windows = predictions.shape[0] * hop_length
        max_length_in_s = max_windows * seconds_per_hop
        frame_length = seconds_per_hop * hop_length

        fig = plt.figure()
        # FigureCanvas.set_window_title was deprecated in Matplotlib 3.4 and
        # later removed; the figure manager provides the supported API.
        manager = fig.canvas.manager
        if manager is not None:
            manager.set_window_title('tempogram: ' + audio_file)
        if args.png:
            fig.set_size_inches(5, 2)
        ax = fig.add_subplot(111)
        ax.set_ylim((0, max_ylim))
        ax.imshow(predictions.T, origin='lower', cmap='Greys', aspect='auto',
                  extent=(0, max_length_in_s, min_bpm, max_bpm))
        if log_scale:
            ax.set_yscale('log')
            ax.yaxis.set_major_formatter(ScalarFormatter())
            ax.yaxis.set_minor_formatter(ScalarFormatter())

        ax.set_xlabel('Time (seconds)')
        ax.set_ylabel('Tempo (BPM)')

        if args.csv:
            header = _csv_header(classifier.model_name, min_bpm, max_bpm,
                                 log_scale, frame_length, args.sharpen)
            if args.sharpen:
                # for now simple argmax, we could use quad interpolation instead
                index = np.argmax(predictions, axis=1)
                csv_data = classifier.to_bpm(index)
                csv_format = '%1.2f'
            else:
                csv_data = predictions
                csv_format = '%1.6f'
            np.savetxt(audio_file + '.csv',
                       csv_data,
                       fmt=csv_format,
                       delimiter=",",
                       header=header)

        if args.png:
            plt.tight_layout()
            fig.savefig(audio_file + '.png', dpi=300)
            # Release the figure; otherwise one open figure per input file
            # accumulates when processing many files.
            plt.close(fig)
        else:
            plt.show()

    print('\nDone')


# Run the command line tool only when executed as a script, not on import.
if __name__ == "__main__":
    main()
