#!/usr/bin/env python
import argparse
import os
from os import path

import nibabel as nib
import numpy as np

import react
from react.utils import check_can_write_file

# File names of the two masks saved inside the output directory.
OUT_MASK_STAGE_1 = 'mask_stage1.nii.gz'
OUT_MASK_STAGE_2 = 'mask_stage2.nii.gz'

# Program name handed to argparse.
PROG = 'react_mask'

# Help-screen description; the first line reports the package version.
# Line breaks and the space before the second one are spelled out
# explicitly so the text does not depend on invisible trailing whitespace.
DESCRIPTION = (
    f'\nv{react.__version__}\n'
    'Receptor-Enriched Analysis of Functional Connectivity by Targets. '
    'All files \n'
    'must be in the same space.\n'
)

# Citation handed to argparse as the help epilog.
EPILOG = (
    'REFERENCE: https://doi.org/10.1016/j.neuroimage.2019.04.007 - '
    'Dipasquale, O., Selvaggi, P., Veronese, M., Gabay, A. S., '
    'Turkheimer, F., & Mehta, M. A. (2019). Receptor-Enriched Analysis '
    'of functional connectivity by targets (REACT): A novel, multimodal '
    'analytical approach informed by PET to study the pharmacodynamic '
    'response of the brain under MDMA. Neuroimage, 195, 252-260.'
)


def get_parsed_args():
    """Parse the command line of the mask-building tool.

    Four positional string arguments are required, in this order:
    ``subject_list``, ``pet_atlas``, ``gm_mask`` and ``out_masks``. The
    optional ``--force`` flag is stored as a boolean.

    Returns:
        argparse.Namespace: The parsed arguments, one attribute per option.
    """
    # (name, help text) for every positional argument, in command-line order.
    positionals = (
        ('subject_list',
         'Txt file reporting the subjects` data to be included in the '
         'mask. '
         'E.g., `home/study/data/subject_list.txt` .'),
        ('pet_atlas',
         '3D or 4D file containing the PET atlases '
         'to be used in the REACT analysis. '
         'E.g., `/home/study/data/PETatlas.nii.gz` .'),
        ('gm_mask',
         'Grey matter mask. '
         'E.g., `/home/study/data/GMmask.nii.gz` .'),
        ('out_masks',
         'Directory where the output masks will be saved. '
         'E.g., `/home/study/REACT/` .'),
    )

    cli = argparse.ArgumentParser(description=DESCRIPTION, prog=PROG,
                                  epilog=EPILOG)
    for arg_name, arg_help in positionals:
        cli.add_argument(arg_name, type=str, help=arg_help)

    cli.add_argument('--force', action='store_true',
                     help="If set, allows to overwrite existing files.")

    return cli.parse_args()


def _read_subject_list(list_file):
    """Return the entries of the text file *list_file*, one per line.

    Blank and whitespace-only lines are skipped. Surrounding whitespace is
    removed from every entry so that stray spaces do not become part of a
    file name.
    """
    with open(list_file, 'rt') as f:
        stripped = (line.strip() for line in f)
        return [entry for entry in stripped if entry]


def main():
    """Build the stage 1 and stage 2 masks and save them as NIfTI files.

    The stage 2 mask keeps the voxels whose time series has a non-zero
    standard deviation in every listed subject and whose grey matter mask
    value is positive. The stage 1 mask further restricts the stage 2 mask
    to voxels that are positive in every volume of the PET atlas file.

    All inputs are validated before anything is saved, so a failure does
    not leave only one of the two masks in the output directory.

    Raises:
        ValueError: If the subject list is empty, a subject file is not a
            4D volume, or the spatial shape of any input differs from that
            of the first subject.
    """
    args = get_parsed_args()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.out_masks, exist_ok=True)

    for fname in (OUT_MASK_STAGE_1, OUT_MASK_STAGE_2):
        check_can_write_file(path.join(args.out_masks, fname), args.force)

    sublist = _read_subject_list(args.subject_list)
    if not sublist:
        raise ValueError(
            f'No subject files are listed in {args.subject_list}.')

    # The first subject defines the spatial grid and the stage 2 affine.
    reference = nib.load(sublist[0])
    grid = tuple(reference.shape[:3])

    # Running AND: a voxel survives only if its signal varies in all subjects.
    in_all_subjects = np.ones(grid, dtype=bool)
    for subject in sublist:
        img = nib.load(subject)
        img_shape = tuple(img.shape)
        # Checked on the image shape, i.e. before the voxel data is loaded.
        if len(img_shape) != 4 or img_shape[:3] != grid:
            raise ValueError(
                'Shape of input volume is incompatible: '
                f'{subject} has shape {img_shape}, expected a 4D volume '
                f'on a {grid} grid.')
        in_all_subjects &= np.std(img.get_fdata(), axis=3) > 0

    gm_mask = nib.load(args.gm_mask).get_fdata() > 0
    if gm_mask.shape != grid:
        raise ValueError(
            'Shape of input volume is incompatible: '
            f'{args.gm_mask} has shape {gm_mask.shape}, expected {grid}.')
    mask_st2 = np.logical_and(in_all_subjects, gm_mask)

    atlas_img = nib.load(args.pet_atlas)
    atlas = atlas_img.get_fdata()
    if atlas.ndim == 3:
        # Treat a single 3D atlas as a 4D file holding one volume.
        atlas = atlas[..., np.newaxis]
    in_all_atlases = np.all(atlas > 0, axis=3)
    if in_all_atlases.shape != grid:
        raise ValueError(
            'Shape of input volume is incompatible: '
            f'{args.pet_atlas} has shape {atlas_img.shape}, expected a 3D '
            f'or 4D volume on a {grid} grid.')
    mask_st1 = np.logical_and(mask_st2, in_all_atlases)

    # Stage 2 is saved with the first subject's affine and stage 1 with the
    # PET atlas affine.
    nib.save(
        nib.Nifti1Image(mask_st2.astype(np.int8), affine=reference.affine),
        path.join(args.out_masks, OUT_MASK_STAGE_2))
    nib.save(
        nib.Nifti1Image(mask_st1.astype(np.int8), affine=atlas_img.affine),
        path.join(args.out_masks, OUT_MASK_STAGE_1))


# Run the command-line tool only when this file is executed as a script.
if __name__ == '__main__':
    main()
