#!/usr/bin/env python

import numpy as np
import matplotlib.pyplot as plt
import argparse

import dynaphopy.interface.phonopy_link as pho_interface
import dynaphopy.interface.iofile as reading


# Command-line interface of the script.
parser = argparse.ArgumentParser(description='rfc_calc options')
parser.add_argument('input_file', metavar='data_file', type=str, nargs=1,
                    help='input file containing structure related data')

# Exactly one of the two must be supplied: the -cp flag or a frequency file.
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('-cp', action='store_true',
                   help='get commensurate points')
# Distinct metavar so that usage/help can tell this positional apart from the
# input file above (which is displayed as "data_file").
group.add_argument('frequency_file', metavar='frequency_file', type=str, nargs='?',
                   help='input file containing renormalized phonon frequencies')

parser.add_argument('-p', action='store_true',
                    help='plot phonon band structure')

parser.add_argument('-s', metavar='FORCE_CONSTANTS', type=str, nargs=1,
                    help='save force_constants into a file')

parser.add_argument('--fcsymm', action='store_true',
                    help='symmetrize force constants')

# Three integers; they are used as the diagonal of the supercell matrix.
parser.add_argument('--dim', metavar='N', type=int, nargs=3, default=None,
                    help='define custom supercell')

args = parser.parse_args()


def get_renormalized_constants(freq_filename, structure, symmetrize=False, degenerate=True):
    """Build renormalized force constants from a file of renormalized frequencies.

    Uses the module-level ``input_parameters`` (parsed input file) and ``args``
    (parsed command line).

    The supercell is taken from ``--dim`` when given (as a diagonal matrix, the
    same way the ``-cp`` mode does), otherwise from ``supercell_phonon`` in the
    input file.

    :param freq_filename: YAML file whose top level is a mapping; each value
        must provide ``reduced_wave_vector`` and ``frequencies``
    :param structure: structure object; the harmonic force sets / force
        constants named in ``input_parameters`` are loaded into it
    :param symmetrize: forwarded to
        ``pho_interface.get_renormalized_force_constants``
    :param degenerate: unused, kept for backward compatibility
    :return: result of ``pho_interface.get_renormalized_force_constants``
    :raises ValueError: a commensurate point has no entry in ``freq_filename``
    :raises SystemExit: no supercell is available (status 1)
    """
    # Local import, as in the original: yaml is only needed by this function.
    import yaml

    if 'force_sets_file_name' in input_parameters:
        structure.set_force_set(pho_interface.get_force_sets_from_file(
            file_name=input_parameters['force_sets_file_name'],
            fs_supercell=input_parameters['supercell_phonon']))
    if 'force_constants_file_name' in input_parameters:
        structure.set_force_constants(pho_interface.get_force_constants_from_file(
            file_name=input_parameters['force_constants_file_name'],
            fc_supercell=input_parameters['supercell_phonon']))

    structure.get_data_from_dict(input_parameters)

    if args.dim is not None:
        # Honour --dim exactly like the -cp branch of the script does, so the
        # commensurate points printed there are the ones looked up here.
        fc_supercell = np.diag(args.dim)
    elif 'supercell_phonon' in input_parameters:
        fc_supercell = input_parameters['supercell_phonon']
    else:
        print('please define supercell phonon in inputfile')
        raise SystemExit(1)

    com_points = pho_interface.get_commensurate_points(structure, fc_supercell)

    # safe_load: plain yaml.load() without a Loader is rejected by PyYAML >= 6
    # and can construct arbitrary objects on older releases.
    with open(freq_filename, 'r') as stream:
        data_dict = yaml.safe_load(stream) or {}
    q_point_entries = list(data_dict.values())

    # The frequencies must follow the order of the commensurate points, because
    # the eigenvectors below are computed in that same order.
    tolerance = 1e-6  # absorbs rounding of the wave vectors stored in the file
    renormalized_frequencies = []
    for wv in com_points:
        for q_point in q_point_entries:
            if np.allclose(q_point['reduced_wave_vector'], wv, rtol=0, atol=tolerance):
                renormalized_frequencies.append(q_point['frequencies'])
                break  # one entry per commensurate point
        else:
            # Without this check a missing point would leave the frequency
            # array shorter than the eigenvector list.
            raise ValueError('commensurate point {} not found in {}'.format(wv, freq_filename))

    renormalized_frequencies = np.array(renormalized_frequencies)

    # Element [0] of the helper's result is used as the eigenvectors.
    eigenvectors = [pho_interface.obtain_eigenvectors_and_frequencies(structure, point, print_data=False)[0]
                    for point in com_points]

    renormalized_force_constants = pho_interface.get_renormalized_force_constants(renormalized_frequencies,
                                                                                  eigenvectors,
                                                                                  structure,
                                                                                  fc_supercell,
                                                                                  symmetrize=symmetrize)

    return renormalized_force_constants


def plot_renormalized_phonon_bands(structure, renormalized_force_constants):
    """Plot the harmonic and the renormalized phonon bands in one figure.

    The band path is ``input_parameters['_band_ranges']`` when present,
    otherwise the ``'ranges'`` entry of ``structure.get_path_using_seek_path()``.
    Harmonic bands are drawn in blue, renormalized bands in red; the figure is
    shown with ``plt.show()``.

    :param structure: structure object handed to
        ``pho_interface.obtain_phonon_dispersion_bands``
    :param renormalized_force_constants: force constants used for the second
        (renormalized) band calculation
    """
    if '_band_ranges' in input_parameters:
        band_ranges = input_parameters['_band_ranges']
    else:
        band_ranges = structure.get_path_using_seek_path()['ranges']

    bands = pho_interface.obtain_phonon_dispersion_bands(structure,
                                                         band_ranges,
                                                         NAC=False)

    renormalized_bands = pho_interface.obtain_phonon_dispersion_bands(structure,
                                                                      band_ranges,
                                                                      force_constants=renormalized_force_constants,
                                                                      NAC=False)

    # Element [1] of each result is plotted as x, element [2] as y, one path
    # segment at a time.
    segments = zip(bands[1], bands[2], renormalized_bands[1], renormalized_bands[2])
    for x_harmonic, y_harmonic, x_renormalized, y_renormalized in segments:
        plt.plot(x_harmonic, y_harmonic, color='b', label='Harmonic (0K)')
        plt.plot(x_renormalized, y_renormalized, color='r', label='Renormalized')

    # plt.gca() instead of plt.axes(): on current Matplotlib plt.axes() adds a
    # new axes on top of the plot rather than returning the existing one.
    # The x axis is left visible (only its ticks are removed) so that the
    # 'Wave vector' label set below is actually drawn.
    plt.gca().get_xaxis().set_ticks([])
    plt.ylabel('Frequency [THz]')
    plt.xlabel('Wave vector')
    plt.xlim([0, renormalized_bands[1][-1][-1]])
    plt.axhline(y=0, color='k', ls='dashed')
    plt.suptitle('Renormalized phonon dispersion')

    # Every segment carries a label; keep one handle per curve family
    # (first handle is harmonic, last handle is renormalized).
    handles, labels = plt.gca().get_legend_handles_labels()
    plt.legend([handles[0], handles[-1]], ['Harmonic', 'Renormalized'])
    plt.show()


# Get data from input file & process parameters
input_parameters = reading.read_parameters_from_input_file(args.input_file[0])

# The structure comes from an OUTCAR file when one is given, else from a POSCAR file.
if 'structure_file_name_outcar' in input_parameters:
    structure = reading.read_from_file_structure_outcar(input_parameters['structure_file_name_outcar'])
elif 'structure_file_name_poscar' in input_parameters:
    structure = reading.read_from_file_structure_poscar(input_parameters['structure_file_name_poscar'])
else:
    # Explicit message instead of a bare KeyError traceback.
    raise SystemExit('please define structure file in inputfile')

structure.get_data_from_dict(input_parameters)

# Disabled: storing --dim in the structure object.
# if args.dim is not None:
#    structure.set_supercell_phonon_renormalized(np.diag(np.array(args.dim, dtype=int)))

# Process options
if args.cp:
    # Supercell priority: --dim, then the input file, then the unit cell.
    if args.dim is not None:
        fc_supercell = np.diag(args.dim)
    elif 'supercell_phonon' in input_parameters:
        fc_supercell = input_parameters['supercell_phonon']
    else:
        print('Warning: No supercell matrix defined! using unitcell as supercell')
        fc_supercell = np.identity(3)

    com_points = pho_interface.get_commensurate_points(structure, fc_supercell)
    for q_point in com_points:
        print(q_point)

    # -cp only lists the points. SystemExit with no argument ends the script
    # with status 0 without relying on the site-provided exit() helper.
    raise SystemExit

renormalized_force_constants = get_renormalized_constants(args.frequency_file, structure, symmetrize=args.fcsymm)

# -s takes nargs=1, so the file name is the single element of the list.
if args.s:
    pho_interface.save_force_constants_to_file(renormalized_force_constants, filename=args.s[0])

if args.p:
    plot_renormalized_phonon_bands(structure, renormalized_force_constants)