#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Track eddy with Identification file produce with EddyIdentification
"""
import logging
from datetime import datetime
from glob import glob
from os import mkdir
from os.path import basename, dirname, exists
from os.path import join as join_path
from re import compile as re_compile

from netCDF4 import Dataset
from numpy import bytes_, empty, ones, unique
from yaml import safe_load as yaml_load

from py_eddy_tracker import EddyParser
from py_eddy_tracker.tracking import Correspondances

logger = logging.getLogger("pet")
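
# Sketch of a minimal YAML configuration for this script (illustrative
# values, hypothetical paths; only keys actually read below are shown):
#
#   PATHS:
#     FILES_PATTERN: /data/identification/anticyclonic_*.nc
#     SAVE_DIR: /data/tracking
#     # CORRESPONDANCES_OUT: "{path}/{sign_type}_correspondances.nc"
#   # Maximum number of virtual observations used to bridge missing detections
#   VIRTUAL_LENGTH_MAX: 3
#   # Minimum number of observations for a track to be kept
#   TRACK_DURATION_MIN: 14
#   # Optional: custom observation class, imported dynamically
#   # CLASS:
#   #   MODULE: my_module
#   #   CLASS: MyObsClass
#   #   OPTIONS: {}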


def browse_dataset_in(
    data_dir,
    files_model,
    date_regexp,
    date_model,
    start_date=None,
    end_date=None,
    sub_sampling_step=1,
    files=None,
):
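    """
    Build a sorted, dated list of identification files.

    The date of each file is read either from a capturing group of
    ``date_regexp`` applied to the filename, or, when ``date_regexp``
    contains no group, from a netCDF attribute (``"attr"`` or
    ``"variable:attr"``). ``date_model`` is the ``strptime`` format of the
    extracted string. When ``files`` is given, ``data_dir`` and
    ``files_model`` are ignored.

    Hypothetical example, for files like /data/anticyclonic_20190501.nc::

        browse_dataset_in(
            data_dir="/data",
            files_model="anticyclonic_*.nc",
            date_regexp=".*_([0-9]{8}).nc",
            date_model="%Y%m%d",
        )
    """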
    pattern_regexp = re_compile(".*/" + date_regexp)
    if files is not None:
        filenames = bytes_(files)
    else:
        full_path = join_path(data_dir, files_model)
        logger.info("Search files : %s", full_path)
        filenames = bytes_(glob(full_path))

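    # One record per grid file: the filename and the date extracted from it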
    dataset_list = empty(
        len(filenames), dtype=[("filename", "S500"), ("date", "datetime64[D]")]
    )
    dataset_list["filename"] = filenames

    logger.info("%s grids available", dataset_list.shape[0])
    mode_attrs = False
    if "(" not in date_regexp:
        logger.debug("Attrs date : %s", date_regexp)
        mode_attrs = date_regexp.strip().split(":")
    else:
        logger.debug("Pattern date : %s", date_regexp)

    for item in dataset_list:
        str_date = None
        if mode_attrs:
            with Dataset(item["filename"].decode("utf-8")) as h:
                if len(mode_attrs) == 1:
                    str_date = getattr(h, mode_attrs[0])
                else:
                    str_date = getattr(h.variables[mode_attrs[0]], mode_attrs[1])
        else:
            result = pattern_regexp.match(item["filename"].decode("utf-8"))
            if result:
                str_date = result.groups()[0]

        if str_date is not None:
            item["date"] = datetime.strptime(str_date, date_model).date()

    dataset_list.sort(order=["date", "filename"])

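    # Tracking assumes a regular time step between grids; refuse mixed steps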
    steps = unique(dataset_list["date"][1:] - dataset_list["date"][:-1])
    if len(steps) > 1:
        raise Exception("Several days steps in grid dataset %s" % steps)

    if sub_sampling_step != 1:
        logger.info("Grid subsampling %d", sub_sampling_step)
        dataset_list = dataset_list[::sub_sampling_step]

    if start_date is not None or end_date is not None:
        logger.info(
            "Available grid from %s to %s",
            dataset_list[0]["date"],
            dataset_list[-1]["date"],
        )
        logger.info("Filtering grid by time %s, %s", start_date, end_date)
        # Build the mask per bound so an unset bound (None) is simply ignored
        mask = ones(dataset_list.shape, dtype=bool)
        if start_date is not None:
            mask &= dataset_list["date"] >= start_date
        if end_date is not None:
            mask &= dataset_list["date"] <= end_date

        dataset_list = dataset_list[mask]
    return dataset_list


def usage():
    """Usage
    """
    # Run using:
    parser = EddyParser("Tool to use identification step to compute tracking")
    parser.add_argument("yaml_file", help="Yaml file to configure py-eddy-tracker")
    parser.add_argument("--correspondance_in", help="Filename of saved correspondance")
    parser.add_argument("--correspondance_out", help="Filename to save correspondance")
    parser.add_argument(
        "--save_correspondance_and_stop",
        action="store_true",
        help="Stop tracking after correspondance computation,"
        " merging can be done with EddyFinalTracking",
    )
    parser.add_argument(
        "--zarr", action="store_true", help="Output will be wrote in zarr"
    )
    parser.add_argument("--unraw", action="store_true", help="Load unraw data")
    parser.add_argument(
        "--blank_period",
        type=int,
        default=0,
        help="Nb of detection which will not use at the end of the period",
    )
    args = parser.parse_args()

    # Read yaml configuration file
    with open(args.yaml_file, "r") as stream:
        config = yaml_load(stream)
    if args.correspondance_in is not None and not exists(args.correspondance_in):
        args.correspondance_in = None
    return (
        config,
        args.save_correspondance_and_stop,
        args.correspondance_in,
        args.correspondance_out,
        args.blank_period,
        args.zarr,
        not args.unraw,
    )
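
# Hypothetical invocation, assuming this module is exposed as the
# EddyTracking console script:
#   EddyTracking tracking.yaml --correspondance_out correspondances.nc --zarr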


if __name__ == "__main__":
    (
        CONFIG,
        SAVE_STOP,
        CORRESPONDANCES_IN,
        CORRESPONDANCES_OUT,
        BLANK_PERIOD,
        ZARR,
        RAW,
    ) = usage()
    # Create output directory
    SAVE_DIR = CONFIG["PATHS"].get("SAVE_DIR", None)
    if SAVE_DIR is not None and not exists(SAVE_DIR):
        mkdir(SAVE_DIR)

    YAML_CORRESPONDANCES_OUT = CONFIG["PATHS"].get("CORRESPONDANCES_OUT", None)
    if CORRESPONDANCES_IN is None:
        CORRESPONDANCES_IN = CONFIG["PATHS"].get("CORRESPONDANCES_IN", None)
    if CORRESPONDANCES_OUT is None:
        CORRESPONDANCES_OUT = YAML_CORRESPONDANCES_OUT
    if YAML_CORRESPONDANCES_OUT is None and CORRESPONDANCES_OUT is None:
        CORRESPONDANCES_OUT = "{path}/{sign_type}_correspondances.nc"

    CLASS = None
    CLASS_KW = dict()
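    # Optional: the YAML CLASS section names a custom observation class,
    # imported dynamically and passed to Correspondances below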
    if "CLASS" in CONFIG:
        CLASSNAME = CONFIG["CLASS"]["CLASS"]
        CLASS = getattr(
            # fromlist must be a list for __import__ to return the submodule
            __import__(CONFIG["CLASS"]["MODULE"], globals(), locals(), [CLASSNAME]),
            CLASSNAME,
        )
        CLASS_KW = CONFIG["CLASS"].get("OPTIONS", dict())

    NB_VIRTUAL_OBS_MAX_BY_SEGMENT = int(CONFIG.get("VIRTUAL_LENGTH_MAX", 0))

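    # Default date pattern: capture the digits before the .nc/.zarr suffix,
    # e.g. "20190501" from ".../anticyclonic_20190501.nc"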
    if isinstance(CONFIG["PATHS"]["FILES_PATTERN"], list):
        DATASET_LIST = browse_dataset_in(
            data_dir=None,
            files_model=None,
            files=CONFIG["PATHS"]["FILES_PATTERN"],
            date_regexp=".*_([0-9]*?).[nz].*",
            date_model="%Y%m%d",
        )
    else:
        DATASET_LIST = browse_dataset_in(
            data_dir=dirname(CONFIG["PATHS"]["FILES_PATTERN"]),
            files_model=basename(CONFIG["PATHS"]["FILES_PATTERN"]),
            date_regexp=".*_([0-9]*?).[nz].*",
            date_model="%Y%m%d",
        )

    if BLANK_PERIOD > 0:
        logger.info("Last %d files will be removed (blank period)", BLANK_PERIOD)
        DATASET_LIST = DATASET_LIST[:-BLANK_PERIOD]

    START_TIME = datetime.now()
    logger.info("Start tracking on %d files", len(DATASET_LIST))

    NB_OBS_MIN = int(CONFIG.get("TRACK_DURATION_MIN", 14))
    if NB_OBS_MIN > len(DATASET_LIST):
        raise Exception(
            "Input file number (%s) is shorter than TRACK_DURATION_MIN (%s)."
            % (len(DATASET_LIST), NB_OBS_MIN)
        )

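    # Associate eddies between consecutive identification files; a non-zero
    # virtual count lets a track survive gaps of missing detections by
    # inserting virtual observations (see VIRTUAL_LENGTH_MAX in the YAML)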
    CORRESPONDANCES = Correspondances(
        datasets=DATASET_LIST["filename"],
        virtual=NB_VIRTUAL_OBS_MAX_BY_SEGMENT,
        class_method=CLASS,
        class_kw=CLASS_KW,
        previous_correspondance=CORRESPONDANCES_IN,
    )
    CORRESPONDANCES.track()
    logger.info("Track finish")

    logger.info("Start merging")
    DATE_START, DATE_STOP = CORRESPONDANCES.period
    DICT_COMPLETION = dict(
        date_start=DATE_START,
        date_stop=DATE_STOP,
        date_prod=START_TIME,
        path=SAVE_DIR,
        sign_type=CORRESPONDANCES.current_obs.sign_legend,
    )

    CORRESPONDANCES.save(CORRESPONDANCES_OUT, DICT_COMPLETION)
    if SAVE_STOP:
        exit()

    # Merge correspondances into tracks; this part is skipped when
    # --save_correspondance_and_stop is set, and can then be done later
    # with EddyFinalTracking
    CORRESPONDANCES.prepare_merging()

    logger.info(
        "Longer track saved have %d obs", CORRESPONDANCES.nb_obs_by_tracks.max()
    )
    logger.info(
        "The mean length is %d observations before filtering",
        CORRESPONDANCES.nb_obs_by_tracks.mean(),
    )

    CORRESPONDANCES.get_unused_data(raw_data=RAW).write_file(
        path=SAVE_DIR, filename="%(path)s/%(sign_type)s_untracked.nc", zarr_flag=ZARR
    )

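    # Split by track duration: tracks shorter than TRACK_DURATION_MIN are
    # merged and written separately; the remainder become the final eddies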
    SHORT_CORRESPONDANCES = CORRESPONDANCES._copy()
    SHORT_CORRESPONDANCES.shorter_than(size_max=NB_OBS_MIN)

    CORRESPONDANCES.longer_than(size_min=NB_OBS_MIN)

    FINAL_EDDIES = CORRESPONDANCES.merge(raw_data=RAW)
    SHORT_TRACK = SHORT_CORRESPONDANCES.merge(raw_data=RAW)

    # Flag virtual observations (time == 0 marks positions interpolated by the
    # tracker rather than detected) and fill their fields by interpolation
    if CORRESPONDANCES.virtual:
        FINAL_EDDIES["virtual"][:] = FINAL_EDDIES["time"] == 0
        FINAL_EDDIES.filled_by_interpolation(FINAL_EDDIES["virtual"] == 1)
        SHORT_TRACK["virtual"][:] = SHORT_TRACK["time"] == 0
        SHORT_TRACK.filled_by_interpolation(SHORT_TRACK["virtual"] == 1)

    # Total running time
    FULL_TIME = datetime.now() - START_TIME
    logger.info("Mean duration by loop : %s", FULL_TIME / (len(DATASET_LIST) - 1))
    logger.info("Duration : %s", FULL_TIME)

    logger.info(
        "Longer track saved have %d obs", CORRESPONDANCES.nb_obs_by_tracks.max()
    )
    logger.info(
        "The mean length is %d observations after filtering",
        CORRESPONDANCES.nb_obs_by_tracks.mean(),
    )

    FINAL_EDDIES.write_file(path=SAVE_DIR, zarr_flag=ZARR)
    SHORT_TRACK.write_file(
        filename="%(path)s/%(sign_type)s_track_too_short.nc",
        path=SAVE_DIR,
        zarr_flag=ZARR,
    )

