# AUTOGENERATED! DO NOT EDIT! File to edit: 04_cli.ipynb (unless otherwise specified).

__all__ = ["md5sum_string", "md5sum_file", "S4RefSimTool", "command_line_script"]

# Cell

import os
import toml
import healpy as hp
import numpy as np
import h5py
from pathlib import Path
import logging as log
from datetime import date

from .core import get_telescope, parse_channels

from s4_design_sim_tool import __version__

from .foregrounds import load_sky_emission
from .atmosphere import load_atmosphere, get_telecope_years
from .noise import load_noise

# Cell

import hashlib


def md5sum_string(string):
    """Return the hexadecimal MD5 digest of ``string`` encoded as UTF-8."""
    digest = hashlib.md5()
    digest.update(string.encode("utf-8"))
    return digest.hexdigest()


def md5sum_file(filename):
    """Compute md5 checksum of the contents of a file

    The file is read in text mode as UTF-8, so universal-newline translation
    applies. Its decoded contents are hashed with ``md5sum_string``.

    Parameters
    ----------
    filename : str or Path
        Path of the file to checksum

    Returns
    -------
    str
        Hexadecimal MD5 digest
    """
    # Use a context manager so the handle is closed right away instead of
    # leaking until garbage collection. UTF-8 is explicit because TOML files
    # are UTF-8 by specification, so the checksum must not depend on the locale.
    with open(filename, "r", encoding="utf-8") as input_file:
        return md5sum_string(input_file.read())


# Cell


class S4RefSimTool:
    """Simulate CMB-S4 maps (sky + atmosphere + noise) from a TOML configuration."""

    def __init__(self, config_filename, output_folder="output"):
        """Simulate CMB-S4 maps based on the experiment configuration

        Parameters
        ----------
        config_filename : str or Path
            CMB-S4 configuration stored in a TOML file
            see for example s4_reference_design.toml in the repository
        output_folder : str or Path
            Output path, created (including parents) if it does not exist
        """
        self.config_filename = config_filename
        self.config = toml.load(self.config_filename)
        self.output_filename_template = "cmbs4_{tag}_KCMB_{telescope}-{band}_{site}_nside{nside}_{split}_of_{nsplits}.fits"
        self.output_folder = Path(output_folder)
        self.output_folder.mkdir(parents=True, exist_ok=True)

    def run(self, channels="all", sites=("Pole", "Chile"), write_outputs=True):
        """Run the simulation

        For every requested site and channel whose telescope-years are
        non-zero, this writes two kinds of map:

        * the full-mission map (split 0, labelled ``1_of_1`` in the filename);
        * ``number_of_splits`` split maps.

        Maps go to ``<output_folder>/<telescope>-<channel>_<site>/``.
        Output files that already exist are skipped.

        Parameters
        ----------
        channels : str or list[str]
            list of channel tags, e.g.
            * ["LFS1", "LFS2"] or
            * "SAT" or "LAT"
            * "all" (default)
        sites : sequence of str
            ["Pole"] or ["Chile"], by default ("Pole", "Chile")
        write_outputs : bool
            if True write the outputs to disk; returning them instead
            (False) is not implemented yet

        Raises
        ------
        NotImplementedError
            if ``write_outputs`` is False
        ValueError
            if the configuration requests 8 or more splits
        """
        if not write_outputs:
            # Fail fast: in-memory output is unsupported, so don't spend time
            # simulating maps that would then be thrown away.
            raise NotImplementedError("Only writing FITS output for now")

        nsplits = self.config["experiment"].get("number_of_splits", 1)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if nsplits >= 8:
            raise ValueError(
                "We currently only have 7 independent realizations of atmosphere and noise"
            )

        # Hash the configuration once per run rather than re-reading the file
        # for every output map.
        config_md5 = md5sum_file(self.config_filename)

        for site in sites:
            for channel in parse_channels(channels):
                if get_telecope_years(self.config, site, channel) == 0:
                    continue
                telescope = get_telescope(channel)
                nside = 512 if telescope == "SAT" else 4096
                subfolder = self.output_folder / f"{telescope}-{channel}_{site.lower()}"
                subfolder.mkdir(parents=True, exist_ok=True)

                log.info("Simulate channel %s at %s", channel, site)
                sky_emission = load_sky_emission(
                    self.config["sky_emission"], site, channel
                )
                # split 0 is the full mission; splits 1..nsplits are the subsets.
                for split in range(nsplits + 1):
                    output_filename = self._output_filename(
                        telescope, channel, site, nside, split, nsplits
                    )
                    output_path = subfolder / output_filename
                    if output_path.exists():
                        log.info("File %s already exists, SKIP", output_filename)
                        continue
                    output_map = self._simulate_split(
                        site, channel, split, nsplits, sky_emission
                    )
                    log.info("Writing %s", output_filename)
                    self._write_map(
                        output_path,
                        output_map,
                        site,
                        channel,
                        split,
                        nsplits,
                        config_md5,
                    )

    def _output_filename(self, telescope, channel, site, nside, split, nsplits):
        """Build the output FITS filename for one split.

        ``split == 0`` (full mission) is labelled ``1_of_1``; ``split >= 1`` is
        labelled ``{split}_of_{nsplits}``. The site is lower-cased.
        """
        return self.output_filename_template.format(
            nside=nside,
            telescope=telescope,
            band=channel,
            site=site.lower(),
            tag="sky_atmosphere_noise",
            split=max(1, split),  # split=0 is full mission and we want 1
            nsplits=1 if split == 0 else nsplits,
        )

    def _simulate_split(self, site, channel, split, nsplits, sky_emission):
        """Return the sky + atmosphere + noise map for one split.

        Missing (NaN) pixels are replaced with ``hp.UNSEEN``.
        """
        output_map = load_atmosphere(self.config, site, channel, realization=split)
        output_map += load_noise(self.config, site, channel, realization=split)
        if split > 0:
            # Only the atmosphere + noise part of a split map is scaled by
            # sqrt(nsplits); the sky emission below is added unscaled.
            output_map *= np.sqrt(nsplits)
        output_map += sky_emission
        # Use UNSEEN instead of nan for missing pixels
        output_map[np.isnan(output_map)] = hp.UNSEEN
        return output_map

    def _write_map(
        self, output_path, output_map, site, channel, split, nsplits, config_md5
    ):
        """Write one map to FITS with provenance information in the header."""
        hp.write_map(
            output_path,
            output_map,
            column_units="K_CMB",
            extra_header=[
                ("SOFTWARE", "s4_design_sim_tool"),
                ("SW_VERS", __version__),
                ("SKY_VERS", "1.0"),
                ("ATM_VERS", "1.0"),
                ("NOI_VERS", "1.0"),
                ("SITE", site),
                ("SPLIT", split),
                ("NSPLITS", nsplits),
                ("CHANNEL", channel),
                ("DATE", str(date.today())),
                ("CONFMD5", config_md5),
            ],
            coord="Q",
            overwrite=True,
        )


# Cell


def command_line_script(args=None):
    """Command-line entry point: parse arguments and run the simulation.

    Parameters
    ----------
    args : list[str], optional
        Argument vector to parse; when None, argparse reads ``sys.argv[1:]``.
    """
    import logging

    logging.basicConfig(level=logging.INFO)

    import argparse

    parser = argparse.ArgumentParser(description="Run s4_design_sim_tool")
    parser.add_argument("config", type=str, help="Configuration file")

    # (flag, help text, default) for every optional string argument,
    # registered in this order so --help output is unchanged.
    optional_arguments = (
        (
            "--channels",
            "Channels e.g. all, SAT, LAT, LFL1 or comma separated list of channels",
            "all",
        ),
        ("--site", "Pole, Chile or all, default all", "all"),
        ("--output_folder", "Output folder, optional", "output"),
    )
    for flag, help_text, default_value in optional_arguments:
        parser.add_argument(
            flag,
            type=str,
            help=help_text,
            required=False,
            default=default_value,
        )

    options = parser.parse_args(args)
    chosen_sites = ["Chile", "Pole"] if options.site == "all" else [options.site]

    simulation = S4RefSimTool(options.config, output_folder=options.output_folder)
    simulation.run(channels=options.channels, sites=chosen_sites)
