#!/usr/bin/env python

import h5py
import sys
import numpy as np
import argparse

# Command-line interface: one or more input files followed by exactly one
# output file, plus two optional switches.
parser = argparse.ArgumentParser(description='concath5')

# Positional: every file to be joined (at least one is required).
parser.add_argument(
    'input',
    nargs='+',
    type=str,
    metavar='input_files',
    help='files to be joined',
)

# Positional: the destination file. nargs=1 stores it as a one-element list.
parser.add_argument(
    'output',
    nargs=1,
    type=str,
    metavar='output_file',
    help='joined file',
)

# Optional: number of leading steps to skip. Like 'output', nargs=1 makes the
# parsed value a one-element list, which is why the default is [0].
parser.add_argument(
    '-f',
    nargs=1,
    type=int,
    default=[0],
    metavar='N',
    help='first steps to be skipped (default:0)',
)

# Optional flag: False unless given on the command line.
parser.add_argument(
    '--velocity_only',
    action='store_true',
    help='Do not save trajectory in output file',
)

args = parser.parse_args()

#print(args.f)
#print(args.input)
#print(args.output)

def read_file_hdf5(file_name):
    """Read the datasets of an HDF5 file into memory.

    The 'time' and 'super_cell' datasets are mandatory. 'trajectory',
    'velocity' and 'vc' are optional and returned as None when absent.
    'reduced_q_vector' is only read when 'vc' is present; in that case it
    must exist in the file as well.

    Args:
        file_name: path of the HDF5 file to read.

    Returns:
        Tuple (vc, velocity, trajectory, time, supercell, reduced_q_vector).

    Raises:
        KeyError: if 'time' or 'super_cell' is missing, or if 'vc' is
            present without 'reduced_q_vector'.
    """
    trajectory = None
    velocity = None
    vc = None
    reduced_q_vector = None

    # The context manager closes the file even when a dataset read fails.
    with h5py.File(file_name, "r") as hdf5_file:
        time = hdf5_file['time'][:]
        supercell = hdf5_file['super_cell'][:]

        if 'trajectory' in hdf5_file:
            trajectory = hdf5_file['trajectory'][:]
        if 'velocity' in hdf5_file:
            velocity = hdf5_file['velocity'][:]
        if 'vc' in hdf5_file:
            vc = hdf5_file['vc'][:]
            # Read together with 'vc'; ignored when 'vc' is absent.
            reduced_q_vector = hdf5_file['reduced_q_vector'][:]

    return vc, velocity, trajectory, time, supercell, reduced_q_vector

def save_data_hdf5(file_name, time, supercell, trajectory=None, velocity=None, vc=None, reduced_q_vector=None):
    """Write the given datasets to a new HDF5 file and report the step count.

    An existing file with the same name is overwritten (mode "w"). 'time'
    and 'super_cell' are always written; every other dataset is written only
    when the corresponding argument is not None.

    Args:
        file_name: path of the HDF5 file to create.
        time: data stored as dataset 'time'.
        supercell: data stored as dataset 'super_cell'.
        trajectory: optional data stored as dataset 'trajectory'.
        velocity: optional data stored as dataset 'velocity'.
        vc: optional data stored as dataset 'vc'.
        reduced_q_vector: optional data stored as dataset 'reduced_q_vector'.
    """
    steps = 0

    # The context manager closes the file even when writing a dataset fails.
    with h5py.File(file_name, "w") as hdf5_file:
        hdf5_file.create_dataset('time', data=time)
        hdf5_file.create_dataset('super_cell', data=supercell)

        if reduced_q_vector is not None:
            hdf5_file.create_dataset('reduced_q_vector', data=reduced_q_vector)

        # The reported step count is the length of the first axis of the last
        # dataset written below, so 'vc' wins over 'velocity', which wins
        # over 'trajectory'. Counting 'trajectory' too avoids reporting
        # "saved 0 steps" when it is the only per-step dataset.
        if trajectory is not None:
            hdf5_file.create_dataset('trajectory', data=trajectory)
            steps = np.shape(trajectory)[0]

        if velocity is not None:
            hdf5_file.create_dataset('velocity', data=velocity)
            steps = np.shape(velocity)[0]

        if vc is not None:
            hdf5_file.create_dataset('vc', data=vc)
            steps = np.shape(vc)[0]

    print ('saved {0} steps'.format(steps))

# Usage: concath5 file.h5 file2.h5 ...  concatenated.h5
#
# Concatenate the per-step datasets of all input files along their first axis
# and write the result to the output file.
#
# No argument-count check is needed here: argparse already reports missing
# positional arguments and exits when `args` is parsed, so an empty command
# line never reaches this point.

first = args.f[0]
print("skipping first {0} steps".format(first))
vc, velocity, trajectory, time, supercell, r_q_vector = read_file_hdf5(args.input[0])

if args.velocity_only:
    # Keep only the velocity in the output.
    trajectory = None
    vc = None

# The datasets present in the first file (after --velocity_only) decide what
# is concatenated. Each one collects a list of chunks, one per input file,
# with the first `first` steps of every file removed.
dataset_names = ('vc', 'velocity', 'trajectory')
chunks = {}
for name, data in zip(dataset_names, (vc, velocity, trajectory)):
    if data is not None:
        chunks[name] = [data[first:]]

print(args.input[0])

for arg in args.input[1:]:
    print(arg)
    new_data = dict(zip(dataset_names, read_file_hdf5(arg)[0:3]))

    for name, parts in chunks.items():
        if new_data[name] is None:
            # Fail with the offending file name instead of an opaque
            # "'NoneType' object is not subscriptable".
            raise ValueError(
                "file '{0}' has no '{1}' dataset, but '{2}' does".format(
                    arg, name, args.input[0]))

        chunk = new_data[name][first:]
        if chunk.shape[1:] != parts[0].shape[1:]:
            # np.concatenate needs matching shapes on every axis but the
            # first; check per file so the error names the culprit.
            raise ValueError(
                "dataset '{0}' in file '{1}' has shape {2}, which cannot be "
                "joined with shape {3} from '{4}'".format(
                    name, arg, chunk.shape, parts[0].shape, args.input[0]))
        parts.append(chunk)

# Join every dataset once at the end. Concatenating inside the loop would
# copy the whole accumulated array again for each additional file.
joined = {}
for name, parts in chunks.items():
    joined[name] = parts[0] if len(parts) == 1 else np.concatenate(parts, axis=0)

vc = joined.get('vc')
velocity = joined.get('velocity')
trajectory = joined.get('trajectory')

print ('Final: ', args.output[0])

# time, supercell and r_q_vector come from the first input file only and are
# written unchanged: time is neither trimmed by `first` nor extended with the
# times of the later files.
save_data_hdf5(args.output[0],
               time,
               supercell,
               trajectory=trajectory,
               velocity=velocity,
               vc=vc,
               reduced_q_vector=r_q_vector)
