#!/usr/bin/env python
import argparse
import logging
import os
from os import path

import nibabel as nib
import numpy as np

import react
from react.utils import check_can_write_file

# File names of the two masks that main() writes into the output directory.
# Stage 2: voxels with non-zero signal variation in every subject, restricted
# to the grey matter mask.
# Stage 1: the stage 2 mask further restricted to voxels where every PET atlas
# volume is positive.
OUT_MASK_STAGE_1 = 'mask_stage1.nii.gz'
OUT_MASK_STAGE_2 = 'mask_stage2.nii.gz'

# Description text handed to argparse for the command-line help; it embeds
# react.__version__ so that the help output reports the package version.
DESCRIPTION = f"""
v{react.__version__}
Receptor-Enriched Analysis of Functional Connectivity by Targets. All files 
must be in the same space.
"""
# Program name passed to argparse as ``prog``.
PROG = 'react_mask'

# Citation passed to argparse as ``epilog``. The pieces are joined by implicit
# string concatenation inside the parentheses.
EPILOG = (
    'REFERENCE: https://doi.org/10.1016/j.neuroimage.2019.04.007 - '
    'Dipasquale, O., Selvaggi, P., Veronese, M., Gabay, A. S., '
    'Turkheimer, F., & Mehta, M. A. (2019). Receptor-Enriched Analysis '
    'of functional connectivity by targets (REACT): A novel, multimodal '
    'analytical approach informed by PET to study the pharmacodynamic '
    'response of the brain under MDMA. Neuroimage, 195, 252-260.'
)


def get_parsed_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which is
            the behaviour callers got before this parameter existed.

    Returns:
        argparse.Namespace: with the attributes ``subject_list``,
        ``pet_atlas``, ``gm_mask``, ``out_masks`` (all ``str``) and the
        boolean flags ``verbose`` and ``force``.
    """
    parser = argparse.ArgumentParser(prog=PROG, epilog=EPILOG,
                                     description=DESCRIPTION)

    parser.add_argument(
        'subject_list',
        type=str,
        help="Txt file reporting the subjects' data to be included in the "
             "mask. "
             "E.g., `/home/study/data/subject_list.txt` ."
    )

    parser.add_argument(
        'pet_atlas',
        type=str,
        help='3D or 4D file containing the PET atlases '
             'to be used in the REACT analysis. '
             'E.g., `/home/study/data/PETatlas.nii.gz` .'
    )

    parser.add_argument(
        'gm_mask',
        type=str,
        help='Grey matter mask. '
             'E.g., `/home/study/data/GMmask.nii.gz` .'
    )

    parser.add_argument(
        'out_masks',
        type=str,
        help='Directory where the output masks will be saved. '
             'E.g., `/home/study/REACT/` .'
    )

    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help='Set verbose output.'
    )

    parser.add_argument(
        '--force',
        action='store_true',
        help="If set, allows to overwrite existing files."
    )

    # parse_args(None) falls back to sys.argv[1:], so existing callers that
    # pass nothing are unaffected.
    return parser.parse_args(argv)


def _read_subject_list(filename):
    """Return the non-blank lines of ``filename`` with whitespace stripped.

    Both ends of every line are stripped, so Windows line endings or stray
    trailing spaces do not end up inside the returned file names.
    """
    with open(filename, 'rt') as f:
        stripped_lines = (line.strip() for line in f)
        return [line for line in stripped_lines if line]


def main():
    """Create the stage 2 and stage 1 masks and save them as NIfTI files.

    Steps:
        1. Parse the command line and make sure the output directory exists.
        2. Count, for every voxel, in how many subject volumes the standard
           deviation along the fourth axis is greater than zero; keep the
           voxels where this holds for all subjects.
        3. Intersect with the grey matter mask (values > 0) and save the
           result as the stage 2 mask.
        4. Intersect the stage 2 mask with the voxels where every PET atlas
           volume is greater than zero and save it as the stage 1 mask.

    Raises:
        ValueError: if the subject list is empty, a subject file is not 4D,
            or the spatial shapes of the inputs do not match.
    """
    args = get_parsed_args()

    if args.verbose:
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )

    # exist_ok avoids the race between checking for the directory and
    # creating it.
    os.makedirs(args.out_masks, exist_ok=True)

    # Validate both output paths up front, before any expensive loading;
    # check_can_write_file receives the --force flag.
    for fname in (OUT_MASK_STAGE_1, OUT_MASK_STAGE_2):
        check_can_write_file(path.join(args.out_masks, fname), args.force)

    logging.info(f'Reading subject list from {args.subject_list}')
    subject_list = _read_subject_list(args.subject_list)
    logging.info(f'Found {len(subject_list)} subjects')
    if not subject_list:
        # Without this guard the code below fails with a bare IndexError.
        raise ValueError(f'No subjects found in {args.subject_list}.')

    # The first subject defines the reference spatial shape and the affine
    # used for the stage 2 mask.
    volume = nib.load(subject_list[0])
    counts = np.zeros(volume.shape[:3])
    for subject in subject_list:
        logging.info(f'Loading {subject}')
        fmri_subject = nib.load(subject).get_fdata()
        logging.info(f'\tshape={fmri_subject.shape}')

        if fmri_subject.ndim != 4:
            raise ValueError(f'Subject volume {subject} must be 4D, got '
                             f'{fmri_subject.ndim}D.')
        has_signal = np.std(fmri_subject, axis=3) > 0
        if has_signal.shape != counts.shape:
            raise ValueError('Shape of input volume is incompatible.')
        counts[has_signal] += 1
    # the mask is the set of voxels where all subjects have std > 0
    mask = counts == len(subject_list)

    gm_mask = nib.load(args.gm_mask).get_fdata() > 0
    if gm_mask.shape != mask.shape:
        raise ValueError('Shape of input volume is incompatible.')
    mask_st2 = np.logical_and(mask, gm_mask)
    logging.info(f'Mask stage 2 contains {np.count_nonzero(mask_st2)} '
                 f'non-zero voxels')

    fout = path.join(args.out_masks, OUT_MASK_STAGE_2)
    logging.info(f'Saving mask stage 2 in {fout}')
    nib.save(nib.Nifti1Image(mask_st2.astype(np.int8), affine=volume.affine),
             fout)

    logging.info(f'Loading PET atlas from {args.pet_atlas}')
    pet_img = nib.load(args.pet_atlas)
    affine = pet_img.affine.copy()
    pet_data = pet_img.get_fdata()
    logging.info(f'\tshape={pet_data.shape}')
    if pet_data.ndim == 3:
        # Treat a single 3D atlas as a 4D stack with one volume.
        pet_data = pet_data[..., np.newaxis]
    if mask_st2.shape != pet_data.shape[:3]:
        raise ValueError('Shape of PET atlas must be compatible with '
                         'the one of fMRI.')
    # keep only voxels where all PET atlases have values greater than zero
    pet_all_positive = np.sum(pet_data > 0, axis=3) == pet_data.shape[3]
    mask_st1 = np.logical_and(mask_st2, pet_all_positive)
    logging.info(f'Mask stage 1 contains {np.count_nonzero(mask_st1)} '
                 f'non-zero voxels')

    fout = path.join(args.out_masks, OUT_MASK_STAGE_1)
    logging.info(f'Saving mask stage 1 in {fout}')
    nib.save(nib.Nifti1Image(mask_st1.astype(np.int8), affine=affine), fout)

    logging.info('react_masks: done :)')


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
