#!/usr/bin/env python
import os
import sys
import parsnip
import argparse
import time
import astropy.table
from tqdm import tqdm


def ensure_parent_directory(path):
    """Create the directory that will contain `path` if it does not exist yet.

    Parameters
    ----------
    path : str
        Path of a file that is about to be written.

    Notes
    -----
    A bare filename such as ``predictions.h5`` has an empty directory component.
    ``os.makedirs('')`` raises ``FileNotFoundError``, so in that case the file is
    destined for the current working directory and nothing needs to be created.
    """
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)


def main():
    """Generate predictions from a ParSNIP model for a dataset and save them.

    Command-line arguments are read from ``sys.argv``:

    - ``predictions_path``: where the table of predictions is written.
    - ``model_path``: model passed to ``parsnip.load_model``.
    - ``dataset_path``: dataset passed to ``parsnip.load_dataset``.
    - ``--overwrite``: replace an existing predictions file instead of skipping.
    - ``--chunk_size``: chunk size used when iterating over the dataset.
    - ``--augments``: if nonzero, use ``predict_dataset_augmented`` with this many
      augmentations instead of ``predict_dataset``.
    - ``--device``, ``--threads``: forwarded to ``parsnip.load_model``.

    The process exits early (via ``sys.exit()``) if the predictions file already
    exists and ``--overwrite`` was not given.
    """
    start_time = time.time()

    parser = argparse.ArgumentParser(
        description='Generate predictions from a ParSNIP model for a dataset.'
    )
    parser.add_argument('predictions_path')
    parser.add_argument('model_path')
    parser.add_argument('dataset_path')

    parser.add_argument('--overwrite', action='store_true')
    parser.add_argument('--chunk_size', default=10000, type=int)
    parser.add_argument('--augments', default=0, type=int)

    parser.add_argument('--device', default='cuda')
    parser.add_argument('--threads', default=8, type=int)

    args = parser.parse_args()

    # Check for existing output before doing any expensive work.
    predictions_path = args.predictions_path
    if os.path.exists(predictions_path):
        if not args.overwrite:
            print(f"Predictions '{predictions_path}' already exist, skipping!")
            sys.exit()
        print(f"Predictions '{predictions_path}' already exist, overwriting!")

    # Load the model
    model = parsnip.load_model(
        args.model_path,
        device=args.device,
        threads=args.threads,
    )

    # Open the dataset with in_memory=False and process it chunk by chunk, since
    # large datasets will not necessarily fit in memory all at once.
    dataset = parsnip.load_dataset(args.dataset_path, in_memory=False)
    chunk_size = args.chunk_size
    num_chunks = dataset.count_chunks(chunk_size)

    # Optionally, the dataset can be augmented a given number of times.
    augments = args.augments

    predictions = []

    for chunk in tqdm(dataset.iterate_chunks(chunk_size), total=num_chunks,
                      file=sys.stdout):
        # Preprocess the light curves
        chunk = model.preprocess(chunk, verbose=False)

        # Generate the prediction
        if augments == 0:
            chunk_predictions = model.predict_dataset(chunk)
        else:
            chunk_predictions = model.predict_dataset_augmented(chunk,
                                                                augments=augments)
        predictions.append(chunk_predictions)

    # 'exact' requires every chunk to have produced the same set of columns.
    predictions = astropy.table.vstack(predictions, join_type='exact')

    # Save the predictions. The output may be a bare filename with no directory
    # component, which the helper handles without raising.
    ensure_parent_directory(predictions_path)

    # By default, assume that we are writing to HDF5 format. In this case, we
    # serialize the table to preserve masked columns and data types. Note that the
    # output will only be able to be read by astropy.table.Table.
    try:
        predictions.write(predictions_path, overwrite=True, serialize_meta=True,
                          path='/predictions')
    except TypeError:
        # Writing to some other format that doesn't support serialize_meta.
        print(f"WARNING: filetype given by '{predictions_path}' may not handle masked "
              "columns correctly. HDF5 format (extension .h5) is recommended.")
        predictions.write(predictions_path, overwrite=True)

    # Calculate time taken in minutes
    elapsed_time = (time.time() - start_time) / 60.
    print(f"Total time: {elapsed_time:.2f} min")


if __name__ == '__main__':
    main()
