#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Track eddies using the identification files produced by EddyIdentification
"""
from py_eddy_tracker import EddyParser
from yaml import load as yaml_load
from py_eddy_tracker.tracking import Correspondances
from os.path import exists, dirname, basename
from os import mkdir
from re import compile as re_compile
from os.path import join as join_path
from numpy import bytes_, empty, unique
from netCDF4 import Dataset
from datetime import datetime
from glob import glob
import logging
import datetime as dt

logger = logging.getLogger("pet")


def browse_dataset_in(
    data_dir,
    files_model,
    date_regexp,
    date_model,
    start_date=None,
    end_date=None,
    sub_sampling_step=1,
    files=None,
):
    """Build a date-sorted table of the files to process.

    The files are either given explicitly with ``files`` or searched with
    ``glob`` on ``join_path(data_dir, files_model)``. A date is attached to
    each file, read either from the file name or from a netCDF attribute.

    :param data_dir: directory used for the glob search (ignored when
        ``files`` is given)
    :param files_model: glob pattern appended to ``data_dir`` (ignored when
        ``files`` is given)
    :param date_regexp: if it contains ``"("`` it is a regular expression
        whose first group is the date string; it is prefixed with ``".*/"``
        and matched against the file path. Otherwise it names a netCDF
        attribute: ``"attr"`` for a global attribute or ``"variable:attr"``
        for a variable attribute
    :param date_model: ``strptime`` format used to parse the date string
    :param start_date: optional inclusive lower bound on the file date
    :param end_date: optional inclusive upper bound on the file date
    :param sub_sampling_step: keep one file every ``sub_sampling_step``
        (applied after sorting and before the time filtering)
    :param files: optional explicit list of file paths
    :return: structured array with fields ``filename`` (``S500``) and
        ``date`` (``datetime64[D]``), sorted by date then file name
    :raises Exception: if consecutive files are not separated by a single
        constant time step
    """
    # Local import: the file-level numpy import does not expose datetime64.
    from numpy import datetime64

    if files is not None:
        filenames = bytes_(files)
    else:
        full_path = join_path(data_dir, files_model)
        logger.info("Search files : %s", full_path)
        filenames = bytes_(glob(full_path))

    dataset_list = empty(
        len(filenames), dtype=[("filename", "S500"), ("date", "datetime64[D]")]
    )
    dataset_list["filename"] = filenames
    # `empty` leaves memory uninitialised: mark every date as missing so a
    # file without a readable date never carries a random value.
    dataset_list["date"] = datetime64("NaT")

    logger.info("%s grids available", dataset_list.shape[0])
    if "(" not in date_regexp:
        # No capture group: the date is stored in a netCDF attribute.
        logger.debug("Attrs date : %s", date_regexp)
        attr_path = date_regexp.strip().split(":")
        pattern_regexp = None
    else:
        logger.debug("Pattern date : %s", date_regexp)
        attr_path = None
        pattern_regexp = re_compile(".*/" + date_regexp)

    for index, raw_name in enumerate(dataset_list["filename"]):
        # Work on the decoded path, not on the "b'...'" repr of the bytes.
        filename = raw_name.decode("utf-8")
        str_date = None
        if attr_path is not None:
            with Dataset(filename) as h:
                if len(attr_path) == 1:
                    str_date = getattr(h, attr_path[0])
                else:
                    str_date = getattr(h.variables[attr_path[0]], attr_path[1])
        else:
            result = pattern_regexp.match(filename)
            if result:
                str_date = result.groups()[0]

        if str_date is None:
            logger.warning("No date found for file %s", filename)
            continue
        dataset_list["date"][index] = datetime.strptime(str_date, date_model).date()

    dataset_list.sort(order=["date", "filename"])

    # Every pair of consecutive files must be separated by the same step.
    steps = unique(dataset_list["date"][1:] - dataset_list["date"][:-1])
    if len(steps) > 1:
        raise Exception("Several days steps in grid dataset %s" % steps)

    if sub_sampling_step != 1:
        logger.info("Grid subsampling %d", sub_sampling_step)
        dataset_list = dataset_list[::sub_sampling_step]

    if start_date is not None or end_date is not None:
        if len(dataset_list):
            logger.info(
                "Available grid from %s to %s",
                dataset_list[0]["date"],
                dataset_list[-1]["date"],
            )
        logger.info("Filtering grid by time %s, %s", start_date, end_date)
        # Apply each bound only when it is provided: comparing dates with
        # None does not give a usable mask.
        mask = None
        if start_date is not None:
            mask = dataset_list["date"] >= start_date
        if end_date is not None:
            upper = dataset_list["date"] <= end_date
            mask = upper if mask is None else mask & upper

        dataset_list = dataset_list[mask]
    return dataset_list


def usage():
    """Parse the command line and load the YAML configuration file.

    :return: tuple ``(config, save_correspondance_and_stop,
        correspondance_in, correspondance_out, blank_period, zarr, raw)``
        where ``config`` is the content of the YAML file, ``raw`` is the
        negation of the ``--unraw`` flag and ``correspondance_in`` is reset
        to ``None`` when the given file does not exist
    """
    # Local import: `yaml.load` without an explicit Loader can build
    # arbitrary Python objects and is rejected by recent PyYAML releases;
    # `safe_load` only builds plain YAML types, enough for this config.
    from yaml import safe_load

    parser = EddyParser("Tool to use identification step to compute tracking")
    parser.add_argument("yaml_file", help="Yaml file to configure py-eddy-tracker")
    parser.add_argument("--correspondance_in", help="Filename of saved correspondance")
    parser.add_argument("--correspondance_out", help="Filename to save correspondance")
    parser.add_argument(
        "--save_correspondance_and_stop",
        action="store_true",
        help="Stop tracking after correspondance computation,"
        " merging can be done with EddyFinalTracking",
    )
    parser.add_argument(
        "--zarr", action="store_true", help="Output will be wrote in zarr"
    )
    parser.add_argument("--unraw", action="store_true", help="Load unraw data")
    parser.add_argument(
        "--blank_period",
        type=int,
        default=0,
        help="Nb of detection which will not use at the end of the period",
    )
    args = parser.parse_args()

    # Read yaml configuration file
    with open(args.yaml_file, "r") as stream:
        config = safe_load(stream)

    # A previous correspondance file which does not exist is ignored, but
    # say so instead of dropping the option silently.
    if args.correspondance_in is not None and not exists(args.correspondance_in):
        logger.warning(
            "Correspondance file %s does not exist, option ignored",
            args.correspondance_in,
        )
        args.correspondance_in = None
    return (
        config,
        args.save_correspondance_and_stop,
        args.correspondance_in,
        args.correspondance_out,
        args.blank_period,
        args.zarr,
        not args.unraw,
    )


if __name__ == "__main__":
    # Command line options and YAML configuration
    (
        CONFIG,
        SAVE_STOP,
        CORRESPONDANCES_IN,
        CORRESPONDANCES_OUT,
        BLANK_PERIOD,
        ZARR,
        RAW,
    ) = usage()
    # Create output directory
    SAVE_DIR = CONFIG["PATHS"].get("SAVE_DIR", None)
    if SAVE_DIR is not None and not exists(SAVE_DIR):
        mkdir(SAVE_DIR)

    # Command line values take precedence over the YAML ones
    YAML_CORRESPONDANCES_IN = CONFIG["PATHS"].get("CORRESPONDANCES_IN", None)
    YAML_CORRESPONDANCES_OUT = CONFIG["PATHS"].get("CORRESPONDANCES_OUT", None)
    if CORRESPONDANCES_IN is None:
        CORRESPONDANCES_IN = YAML_CORRESPONDANCES_IN
    if CORRESPONDANCES_OUT is None:
        CORRESPONDANCES_OUT = YAML_CORRESPONDANCES_OUT
    # Still None here means neither the command line nor the YAML file set
    # it: fall back on a template completed with DICT_COMPLETION at save time
    if CORRESPONDANCES_OUT is None:
        CORRESPONDANCES_OUT = "{path}/{sign_type}_correspondances.nc"

    # Optional class, loaded dynamically from the module named in the config
    if "CLASS" in CONFIG:
        CLASS = getattr(
            __import__(
                CONFIG["CLASS"]["MODULE"], globals(), locals(), CONFIG["CLASS"]["CLASS"]
            ),
            CONFIG["CLASS"]["CLASS"],
        )
    else:
        CLASS = None

    NB_VIRTUAL_OBS_MAX_BY_SEGMENT = int(CONFIG.get("VIRTUAL_LENGTH_MAX", 0))

    # FILES_PATTERN is either an explicit list of files or a glob pattern;
    # in both cases the date is the run of digits captured in the file name
    if isinstance(CONFIG["PATHS"]["FILES_PATTERN"], list):
        DATASET_LIST = browse_dataset_in(
            data_dir=None,
            files_model=None,
            files=CONFIG["PATHS"]["FILES_PATTERN"],
            date_regexp=".*c_([0-9]*?).[nz].*",
            date_model="%Y%m%d",
        )
    else:
        DATASET_LIST = browse_dataset_in(
            data_dir=dirname(CONFIG["PATHS"]["FILES_PATTERN"]),
            files_model=basename(CONFIG["PATHS"]["FILES_PATTERN"]),
            date_regexp=".*c_([0-9]*?).[nz].*",
            date_model="%Y%m%d",
        )

    # Drop the last BLANK_PERIOD files of the period
    if BLANK_PERIOD > 0:
        DATASET_LIST = DATASET_LIST[:-BLANK_PERIOD]
        logger.info("Last %d files will be pop", BLANK_PERIOD)

    START_TIME = dt.datetime.now()
    logger.info("Start tracking on %d files", len(DATASET_LIST))

    NB_OBS_MIN = int(CONFIG.get("TRACK_DURATION_MIN", 14))
    if NB_OBS_MIN > len(DATASET_LIST):
        raise Exception(
            "Input file number (%s) is shorter than TRACK_DURATION_MIN (%s)."
            % (len(DATASET_LIST), NB_OBS_MIN)
        )

    CORRESPONDANCES = Correspondances(
        datasets=DATASET_LIST["filename"],
        virtual=NB_VIRTUAL_OBS_MAX_BY_SEGMENT,
        class_method=CLASS,
        previous_correspondance=CORRESPONDANCES_IN,
    )

    CORRESPONDANCES.track()
    logger.info("Track finish")
    logger.info("Start merging")

    # Values used to complete the output file name templates
    DATE_START, DATE_STOP = CORRESPONDANCES.period
    DICT_COMPLETION = dict(
        date_start=DATE_START,
        date_stop=DATE_STOP,
        date_prod=START_TIME,
        path=SAVE_DIR,
        sign_type=CORRESPONDANCES.current_obs.sign_legend,
    )

    CORRESPONDANCES.save(CORRESPONDANCES_OUT, DICT_COMPLETION)
    if SAVE_STOP:
        # Raise SystemExit directly: the `exit` helper is installed by the
        # `site` module and is not guaranteed to exist in every run mode.
        raise SystemExit

    # Merge the correspondances into tracks; this part is skipped when
    # --save_correspondance_and_stop is given (the option help says merging
    # can then be done with EddyFinalTracking)
    CORRESPONDANCES.prepare_merging()

    logger.info(
        "Longer track saved have %d obs", CORRESPONDANCES.nb_obs_by_tracks.max()
    )
    logger.info(
        "The mean length is %d observations before filtering",
        CORRESPONDANCES.nb_obs_by_tracks.mean(),
    )

    CORRESPONDANCES.get_unused_data(raw_data=RAW).write_file(
        path=SAVE_DIR, filename="%(path)s/%(sign_type)s_untracked.nc", zarr_flag=ZARR
    )

    # Split on NB_OBS_MIN: the copy is filtered with `shorter_than`, the
    # original with `longer_than`
    SHORT_CORRESPONDANCES = CORRESPONDANCES._copy()
    SHORT_CORRESPONDANCES.shorter_than(size_max=NB_OBS_MIN)

    CORRESPONDANCES.longer_than(size_min=NB_OBS_MIN)

    FINAL_EDDIES = CORRESPONDANCES.merge(raw_data=RAW)
    SHORT_TRACK = SHORT_CORRESPONDANCES.merge(raw_data=RAW)

    # We flag obs: observations with a time equal to 0 are marked as virtual
    # and then filled by interpolation
    if CORRESPONDANCES.virtual:
        FINAL_EDDIES["virtual"][:] = FINAL_EDDIES["time"] == 0
        FINAL_EDDIES.filled_by_interpolation(FINAL_EDDIES["virtual"] == 1)
        SHORT_TRACK["virtual"][:] = SHORT_TRACK["time"] == 0
        SHORT_TRACK.filled_by_interpolation(SHORT_TRACK["virtual"] == 1)

    # Total running time
    FULL_TIME = dt.datetime.now() - START_TIME
    # One loop per pair of consecutive files; keep the divisor at 1 or more
    # so that a single input file does not raise ZeroDivisionError
    NB_LOOPS = max(len(DATASET_LIST) - 1, 1)
    logger.info("Mean duration by loop : %s", FULL_TIME / NB_LOOPS)
    logger.info("Duration : %s", FULL_TIME)

    logger.info(
        "Longer track saved have %d obs", CORRESPONDANCES.nb_obs_by_tracks.max()
    )
    logger.info(
        "The mean length is %d observations after filtering",
        CORRESPONDANCES.nb_obs_by_tracks.mean(),
    )

    FINAL_EDDIES.write_file(path=SAVE_DIR, zarr_flag=ZARR)
    SHORT_TRACK.write_file(
        filename="%(path)s/%(sign_type)s_track_too_short.nc",
        path=SAVE_DIR,
        zarr_flag=ZARR,
    )

