#! /usr/bin/env python
"""
zort-initialize :
Create an objects file and rcip map for all lightcurve files in a directory.
These initialization files speed up quick reading for zort.
"""

import argparse
import sys
import subprocess
import numpy as np

from zort.initialize import \
    gather_lightcurve_files, generate_objects_file, generate_rcid_map


def main(argv=None):
    """Parse command-line arguments and build zort initialization files.

    For every lightcurve file returned by ``gather_lightcurve_files`` for
    ``--lightcurve-file-directory``, call ``generate_objects_file`` and
    ``generate_rcid_map``. With ``--parallel`` the file list is divided
    into ``size`` contiguous chunks with ``np.array_split`` and each MPI
    rank processes the chunk matching its rank.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse instead of ``sys.argv[1:]``. The default (None)
        keeps the normal command-line behavior.

    Raises
    ------
    SystemExit
        Status 2 if argparse rejects the arguments. Status 1 if
        ``--parallel`` is requested but ``mpi4py`` cannot be imported.
    """
    # Get arguments
    parser = argparse.ArgumentParser(description=__doc__)
    arguments = parser.add_argument_group('arguments')
    arguments.add_argument('--lightcurve-file-directory', type=str,
                           help='Directory containing lightcurve files.',
                           required=True)

    parallelgroup = parser.add_mutually_exclusive_group()
    parallelgroup.add_argument('--single', dest='parallelFlag',
                               action='store_false',
                               help='Run in single mode. DEFAULT.')
    parallelgroup.add_argument('--parallel', dest='parallelFlag',
                               action='store_true',
                               help='Run in parallel mode. Requires mpi4py.')
    parser.set_defaults(parallelFlag=False)

    args = parser.parse_args(argv)

    if args.parallelFlag:
        # Attempt the import directly instead of parsing `pip freeze`.
        # That approach crashed when pip was unavailable and missed editable
        # installs (which are not listed as `name==version`). It also
        # accepted an installed mpi4py that still fails to import.
        try:
            from mpi4py import MPI
        except ImportError:
            print('mpi4py must be installed to use --parallel mode.',
                  file=sys.stderr)
            # A missing requirement is an error: exit non-zero so shells
            # and job schedulers see the failure.
            sys.exit(1)
        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()
        size = comm.Get_size()
    else:
        rank = 0
        size = 1

    # Every rank gathers the full list independently.
    lightcurve_files = gather_lightcurve_files(args.lightcurve_file_directory)

    if rank == 0:
        print('Generating object files and RCID maps '
              'for %i lightcurve files' % len(lightcurve_files))

    # Contiguous chunks: rank r gets the r-th of `size` nearly equal slices.
    # NOTE(review): chunks are disjoint only if gather_lightcurve_files
    # returns the same order on every rank -- confirm (e.g. it sorts).
    my_lightcurve_files = np.array_split(lightcurve_files, size)[rank]
    for lightcurve_file in my_lightcurve_files:
        generate_objects_file(lightcurve_file)
        generate_rcid_map(lightcurve_file)


# Run the command-line interface only when executed as a script, not when
# this module is imported.
if __name__ == "__main__":
    main()
