# -*- encoding: utf-8 -*-

from dataclasses import dataclass, fields
from datetime import date
import logging as log
import math
import os
from pathlib import Path
from typing import Any, Dict, Union

import toml
import numpy as np
import healpy as hp
import pysm3
import pysm3.units as u

import litebird_sim as lbs


@dataclass
class _InstrumentFreq:
    bandcenter_ghz: float = 0.0  # bandcenter_ghz
    bandwidth_ghz: float = 0.0  # d.bandwidth_ghz
    fwhm_arcmin: float = 0.0  # fwhm_arcmin
    p_sens_ukarcmin: float = 0.0  # pol_sensitivity_ukarcmin

    @staticmethod
    def from_dict(dictionary):
        return _InstrumentFreq(
            bandcenter_ghz=dictionary["bandcenter_ghz"],
            bandwidth_ghz=dictionary["bandwidth_ghz"],
            fwhm_arcmin=dictionary["fwhm_arcmin"],
            p_sens_ukarcmin=dictionary["p_sens_ukarcmin"],
        )

    @staticmethod
    def from_npy(filename):
        data = np.load(filename, allow_picke=True).item()
        return _InstrumentFreq(
            bandcenter_ghz=data["freq"],
            bandwidth_ghz=data["freq_band"],
            fwhm_arcmin=data["beam"],
            p_sens_ukarcmin=data["p_sens"],
        )

    @staticmethod
    def from_toml(filename):
        data = toml.load(filename)
        return _InstrumentFreq(
            bandcenter_ghz=data["freq"],
            bandwidth_ghz=data["freq_band"],
            fwhm_arcmin=data["beam"],
            p_sens_ukarcmin=data["p_sens"],
        )


@dataclass
class MbsSavedMapInfo:
    """Description of a Healpix map that :class:`.Mbs` wrote to disk

    :meth:`.Mbs.run_all` returns a list of these objects, one for each
    FITS file it saved. Each object has the following fields:

    - ``path``: a :class:`pathlib.Path` pointing to the Healpix FITS
      file on disk;

    - ``component``: which sky component the map holds, one of
      ``cmb``, ``fg``, ``noise``, or ``coadded``;

    - ``channel``: the name of the frequency channel; it can be
      ``None``, e.g., for the raw CMB realization;

    - ``mc_num``: the index of the Monte Carlo realization, or ``None``
      when the map is not tagged with one.

    """

    # Location of the FITS file on disk
    path: Union[None, Path] = None
    # Kind of map: "cmb", "fg", "noise", or "coadded"
    component: str = ""
    # Frequency channel the map refers to, if any
    channel: Union[None, str] = None
    # Monte Carlo realization index, if any
    mc_num: Union[None, int] = None


@dataclass
class MbsParameters:
    """A class that specifies how sky maps should be generated by PySM3

    This class is used to specify how an instance of the :class:`.Mbs`
    class generates sky maps. You can choose to include the following
    components in the maps:

    - CMB signal (use ``make_cmb=True``);

    - Foreground signal (use ``make_fg=True``);

    - Noise (use ``make_noise=True``); this should be used only when
      running map-based simulations, because if you are creating
      time-ordered data, chances are that you want to inject noise
      directly in the timelines.

    The full list of parameters and their default values are listed
    here:

    - ``nside`` (default: 512): value of the ``NSIDE`` parameter to
      use when creating the maps

    - ``save`` (default: ``False``): when ``True``, Healpix maps are
      saved as FITS files in the output path (specified by a
      :class:`.Simulation` object)

    - ``gaussian_smooth`` (default: ``False``): when ``True``, maps
      are smoothed using a Gaussian beam whose width is the FWHM of
      the frequency channel.

    - ``bandpass_int`` (default: ``False``): when ``True``, maps are
      integrated over the bandpass of each detector

    - ``coadd`` (default: ``None``): when ``True`` and ``save`` is also
      ``True``, the signal maps (CMB plus foregrounds) are summed and
      saved as coadded maps. If left to ``None``, it is set to
      ``not save``, so coadded maps on disk must be requested
      explicitly with ``coadd=True``

    - ``parallel_mc`` (default: ``False``): when ``True``, Monte Carlo
      realizations are computed in parallel using MPI

    - ``make_noise`` (default: ``False``): when ``True``, noise maps
      are generated and added to the final result;

    - ``nmc_noise`` (default: 1): number of Monte Carlo runs for noise
      maps

    - ``seed_noise`` (default: ``None``): when specified, it must be a
      positive integer number used as the seed of the pseudo-random
      number generator

    - ``n_split`` (default: ``False``): when specified as an integer,
      number of splits used to generate noise maps; a value of 1 is
      equivalent to ``False``

    - ``make_cmb`` (default: ``True``): when ``True``, a CMB map is
      generated and added to the final result

    - ``cmb_ps_file`` (default: ``""``): if specified, it must be the
      path to a FITS file containing the CMB power spectrum to be used
      to generate the CMB map

    - ``cmb_r`` (default: 0.0): value of the tensor-to-scalar ratio to
      be used when generating the CMB map

    - ``nmc_cmb`` (default: 1): number of Monte Carlo realizations for
      CMB maps

    - ``seed_cmb`` (default: None): if specified as an integer, the
      seed used to initialize the random number generator used to
      create a CMB realization

    - ``make_fg`` (default: ``False``): when ``True``, a foreground map
      is generated and added to the maps

    - ``fg_models`` (default: ``None``): this specifies the foreground
      models as a list of strings, each specifying a foreground
      component. Alternatively, it can be provided as a dictionary
      which associates a name to a component, e.g., ``{"ame":
      "pysm_ame_1"}``. A ``"make_fg"`` key in the dictionary is ignored.

    - ``output_string`` (default: ``None``): a string used to build the
      file names of the Healpix FITS files saved by the :class:`.Mbs`
      class. If empty or ``None``, it is set to ``date_YYMMDD`` using
      the current date.

    - ``units`` (default: ``K_CMB``): string used to specify the
      measurement unit of the pixels in the maps generated by the
      :class:`.Mbs` class. It follows the naming of the pysm3 units,
      e.g. "K_CMB" / "K_RJ" / "uK_CMB" / "uK_RJ"

    - ``maps_in_ecliptic`` (default: ``False``): when ``True`` the maps
      contained in the dictionary returned by `Mbs.run_all` are converted
      in ecliptic coordinates using the `healpy` routine `rotate_map_alms`

    """

    nside: int = 512
    save: bool = False
    gaussian_smooth: bool = False
    bandpass_int: bool = False
    coadd: Union[bool, None] = None
    parallel_mc: bool = False
    make_noise: bool = False
    nmc_noise: int = 1
    seed_noise: Union[int, None] = None
    n_split: Union[int, bool] = False
    make_cmb: bool = True
    cmb_ps_file: str = ""
    cmb_r: float = 0.0
    nmc_cmb: int = 1
    seed_cmb: Union[int, None] = None
    make_fg: bool = False
    fg_models: Union[Dict[str, Any], None] = None
    output_string: Union[str, None] = None
    units: str = "K_CMB"
    maps_in_ecliptic: bool = False

    def __post_init__(self):
        """Normalize the parameters and resolve the defaults that depend on others"""

        # A single split is the same as no split at all
        if self.n_split == 1:
            self.n_split = False

        # If not given explicitly, ``coadd`` is the opposite of ``save``
        if self.coadd is None:
            self.coadd = not self.save

        if self.fg_models:
            if isinstance(self.fg_models, list):
                # A plain list of models: use each model as its own name
                self.fg_models = {x: x for x in self.fg_models}
            elif isinstance(self.fg_models, dict):
                # ``fg_models`` is often the "fg" section of a parameter file,
                # which also holds the "make_fg" flag. Build a filtered copy
                # instead of deleting the key in place, so the caller's
                # dictionary is left untouched.
                self.fg_models = {
                    key: value
                    for (key, value) in self.fg_models.items()
                    if key != "make_fg"
                }

        if not self.output_string:
            self.output_string = "date_" + date.today().strftime("%y%m%d")

    @staticmethod
    def from_dict(dictionary):
        """Create a :class:`.MbsParameters` instance from a dictionary

        Typically, the dictionary passed to this static method should
        be taken from a TOML file, possibly through the `parameters`
        field of a :class:`.Simulation` object, like in the following
        example::

            import litebird_sim as lbs

            sim = lbs.Simulation(parameter_file="config.toml")
            mbs_params = lbs.MbsParameters.from_dict(sim.parameters["mbs_params"])

        Parameters are looked up in the sub-dictionaries ``general``,
        ``noise``, ``cmb``, ``fg``, and ``output`` (later sections win if
        a name is repeated). The whole ``fg`` section, minus its
        ``make_fg`` key, is used as ``fg_models``. The input dictionary
        is never modified.
        """
        valid_names = {field.name for field in fields(MbsParameters)}
        kwargs = {}
        for sub_dict_name in ["general", "noise", "cmb", "fg", "output"]:
            sub_dict = dictionary.get(sub_dict_name, {})
            for (name, value) in sub_dict.items():
                if name in valid_names:
                    kwargs[name] = value

        # Copy the section, so that removing "make_fg" in __post_init__
        # cannot alter the caller's dictionary
        kwargs["fg_models"] = dict(dictionary.get("fg", {}))

        # Construct in a single step, so that defaults depending on other
        # fields (e.g. ``coadd``) are resolved from the values in the
        # dictionary and not from the class defaults
        return MbsParameters(**kwargs)


def _from_sens_to_rms(sens, nside):
    """Turn a map-depth sensitivity into a per-pixel RMS

    ``sens`` is divided by the Healpix pixel resolution, expressed in
    arcmin, for the given ``nside``.
    """
    pixel_resolution_arcmin = hp.nside2resol(nside, arcmin=True)
    return sens / pixel_resolution_arcmin


class Mbs:
    """A class that generates sky maps.

    This class is used to generate synthetic maps of the sky starting
    from the definition of a set of frequency channels and detectors.

    The class uses PySM3 to generate the maps. The parameters used to
    set up the way sky components are generated are passed through an
    instance of a :class:`.MbsParameters` class, like in the following
    example::

        import litebird_sim as lbs

        sim = lbs.Simulation()
        params = lbs.MbsParameters(make_cmb=True)
        mbs = lbs.Mbs(
            simulation=sim,
            parameters=params,
            channel_list=[
                lbs.FreqChannelInfo.from_imo(
                    sim.imo,
                    "/releases/v1.0/satellite/LFT/L1-040/channel_info",
                ),
            ],
        )
        (healpix_maps, file_paths) = mbs.run_all()

    """

    def __init__(
        self,
        simulation,
        parameters: Union[Dict[str, Any], str, MbsParameters, None] = None,
        instrument=None,
        detector_list=None,
        channel_list=None,
    ):
        """Set up the map generator

        Args:
            simulation: the :class:`.Simulation` object; its ``imo``,
                ``parameters`` and ``base_path`` attributes are used.
            parameters: a :class:`.MbsParameters` object, the name of a
                section in ``simulation.parameters``, or a dictionary
                accepted by :meth:`.MbsParameters.from_dict`. When omitted,
                a fresh default :class:`.MbsParameters` is used.
            instrument: optional dictionary mapping channel names to
                ``_InstrumentFreq`` objects or to dictionaries with the
                keys ``bandcenter_ghz``, ``bandwidth_ghz``,
                ``fwhm_arcmin`` and ``p_sens_ukarcmin``.
            detector_list: optional detector (or list of detectors).
            channel_list: optional frequency channel (or list of channels).
        """
        self.sim = simulation
        self.imo = self.sim.imo

        if parameters is None:
            # Build the default here: a default-argument instance would be
            # created once at import time and shared by every Mbs object
            self.params = MbsParameters()
        elif isinstance(parameters, MbsParameters):
            self.params = parameters
        elif isinstance(parameters, str):
            self.params = MbsParameters.from_dict(simulation.parameters[parameters])
        else:
            self.params = MbsParameters.from_dict(parameters)

        self.instrument = instrument
        self.det_list = detector_list
        self.ch_list = channel_list
        self.pysm_units = u.Unit(self.params.units)

    def _mpi_rank_and_size(self):
        """Return ``(rank, size)`` for the Monte Carlo loops

        This is the MPI world communicator when ``parallel_mc`` is set,
        ``(0, 1)`` otherwise.
        """
        if self.params.parallel_mc:
            comm = lbs.MPI_COMM_WORLD
            return (comm.Get_rank(), comm.Get_size())
        return (0, 1)

    def _parse_instrument_from_det_list(self):
        """Fill ``self.instrument`` from ``self.det_list``, one entry per detector"""
        self.instrument = {}
        try:
            len(self.det_list)
        except TypeError:
            # A single detector object was passed instead of a sequence
            self.det_list = [self.det_list]

        for d in self.det_list:
            # The name ends up in file names, so spaces are replaced
            name = d.name.replace(" ", "_")
            self.instrument[name] = _InstrumentFreq(
                bandcenter_ghz=d.bandcenter_ghz,
                bandwidth_ghz=d.bandwidth_ghz,
                fwhm_arcmin=d.fwhm_arcmin,
                p_sens_ukarcmin=d.pol_sensitivity_ukarcmin,
            )

    def _parse_instrument_from_ch_list(self):
        """Fill ``self.instrument`` from ``self.ch_list``, one entry per channel"""
        self.instrument = {}
        try:
            len(self.ch_list)
        except TypeError:
            # A single channel object was passed instead of a sequence
            self.ch_list = [self.ch_list]

        for c in self.ch_list:
            # The name ends up in file names, so spaces are replaced
            name = c.channel.replace(" ", "_")
            self.instrument[name] = _InstrumentFreq(
                bandcenter_ghz=c.bandcenter_ghz,
                bandwidth_ghz=c.bandwidth_ghz,
                fwhm_arcmin=c.fwhm_arcmin,
                p_sens_ukarcmin=c.pol_sensitivity_channel_ukarcmin,
            )

    @staticmethod
    def _telescope_for_channel(channel_name):
        """Return "LFT", "MFT" or "HFT" by looking for L/M/H in the channel name

        Raises:
            ValueError: if none of the letters appears in the name.
        """
        if "L" in channel_name:
            return "LFT"
        if "M" in channel_name:
            return "MFT"
        if "H" in channel_name:
            return "HFT"
        raise ValueError(
            f"unable to infer the telescope of channel '{channel_name}'"
        )

    def _parse_instrument(self):
        """Set ``self.instrument`` to a dictionary of ``_InstrumentFreq`` objects

        The sources are tried in this order: the detector list, the
        channel list, the ``instrument`` argument, and finally the
        ``instrument`` section of the simulation parameters (which may
        point to a custom file or to IMO telescopes/channels).

        Raises:
            RuntimeError: if no instrument definition can be found.
            NameError: if a custom instrument file has an unknown extension.
        """
        if self.det_list:
            self._parse_instrument_from_det_list()
            return

        if self.ch_list:
            self._parse_instrument_from_ch_list()
            return

        if self.instrument:
            log.info("using the passed instrument to generate maps")
            if isinstance(self.instrument, dict):
                # Accept both ready-made objects and plain dictionaries
                self.instrument = {
                    name: (
                        value
                        if isinstance(value, _InstrumentFreq)
                        else _InstrumentFreq.from_dict(value)
                    )
                    for (name, value) in self.instrument.items()
                }
            return

        try:
            config_inst = self.sim.parameters["instrument"]
        except (TypeError, KeyError):
            raise RuntimeError("you must specify a instrument/channel/detector for Mbs")

        # Always define these attributes, so that the checks below never
        # hit a missing attribute when only one of them is configured
        self.telescopes = None
        self.channels = None
        custom_instrument = None

        try:
            custom_instrument = config_inst["custom_instrument"]
        except KeyError:
            self.imo_version = config_inst["imo_version"]
            try:
                self.telescopes = config_inst["telescopes"]
            except KeyError:
                try:
                    self.channels = config_inst["channels"]
                except KeyError:
                    pass

        if custom_instrument:
            # NOTE(review): ``from_toml``/``from_npy`` return a single
            # ``_InstrumentFreq``, while the generators iterate over
            # ``self.instrument.keys()``; confirm the expected file layout.
            if custom_instrument.endswith(".toml"):
                self.instrument = _InstrumentFreq.from_toml(filename=custom_instrument)
            elif custom_instrument.endswith(".npy"):
                self.instrument = _InstrumentFreq.from_npy(filename=custom_instrument)
            else:
                raise NameError("wrong instrument dictonary format")

            return

        if self.telescopes:
            # Collect the channel names of every requested telescope
            channels = [
                ch
                for tel in self.telescopes
                for ch in self.imo.query(
                    f"/releases/v{self.imo_version}/satellite/{tel}/instrument_info"
                ).metadata["channel_names"]
            ]
        elif self.channels is not None:
            channels = self.channels
        else:
            raise RuntimeError("you must pass a 'instrument' dictionary to Mbs()")

        self.instrument = {}
        for ch in channels:
            tel = self._telescope_for_channel(ch)
            metadata = self.imo.query(
                f"/releases/v{self.imo_version}/satellite/{tel}/{ch}/channel_info"
            ).metadata
            self.instrument[ch] = _InstrumentFreq(
                bandcenter_ghz=metadata["bandcenter_ghz"],
                bandwidth_ghz=metadata["bandwidth_ghz"],
                fwhm_arcmin=metadata["fwhm_arcmin"],
                p_sens_ukarcmin=metadata["pol_sensitivity_channel_ukarcmin"],
            )

    def _emission_in_band(self, sky, chnl_info):
        """Return the emission of the PySM ``sky`` for one channel

        The result is in ``self.pysm_units``. With ``bandpass_int`` the
        emission is integrated with flat weights over the band, sampled
        at roughly 1 GHz steps. Otherwise it is evaluated at the band center.
        """
        freq = chnl_info.bandcenter_ghz
        if self.params.bandpass_int:
            band = chnl_info.bandwidth_ghz
            fmin = freq - band / 2.0
            fmax = freq + band / 2.0
            fsteps = int(np.ceil(fmax - fmin) + 1)
            bandpass_frequencies = np.linspace(fmin, fmax, fsteps) * u.GHz
            weights = np.ones(len(bandpass_frequencies))
            emission = sky.get_emission(bandpass_frequencies, weights)
            return emission * pysm3.bandpass_unit_conversion(
                bandpass_frequencies, weights, self.pysm_units
            )

        emission = sky.get_emission(freq * u.GHz)
        return emission.to(
            self.pysm_units, equivalencies=u.cmb_equivalencies(freq * u.GHz)
        )

    def _apply_beam(self, sky_map, chnl_info):
        """Smooth ``sky_map`` with the channel's Gaussian beam if requested

        Smoothing happens only when ``gaussian_smooth`` is set.
        """
        if not self.params.gaussian_smooth:
            return sky_map
        return hp.smoothing(sky_map, fwhm=np.radians(chnl_info.fwhm_arcmin / 60.0))

    def generate_noise(self):
        """Generate white-noise maps for every channel

        The polarization RMS per pixel comes from each channel's
        polarization sensitivity; the temperature RMS is smaller by
        ``sqrt(2)``.

        Returns:
            A pair ``(noise_map_matrix, saved_maps)``. ``noise_map_matrix``
            is an array of shape ``(n_channels, 3, npix)`` holding the
            last realization computed by this rank, or ``None`` when
            ``save`` is set. ``saved_maps`` lists the written files.
        """
        instr = self.instrument
        nside = self.params.nside
        npix = hp.nside2npix(nside)
        output_directory = self.sim.base_path / "noise"
        seed_noise = self.params.seed_noise
        n_split = self.params.n_split
        file_str = self.params.output_string
        channels = instr.keys()
        save = self.params.save
        col_units = [self.params.units] * 3
        saved_maps = []

        rank, size = self._mpi_rank_and_size()

        if rank == 0 and save:
            output_directory.mkdir(parents=True, exist_ok=True)

        # Round up so that every rank gets the same number of realizations
        nmc_noise = math.ceil(self.params.nmc_noise / size) * size
        if nmc_noise != self.params.nmc_noise:
            log.warning("setting nmc_noise = %d (rank %d)", nmc_noise, rank)
        perrank = nmc_noise // size

        chnl_seed = 12
        if save:
            noise_map_matrix = None
        else:
            noise_map_matrix = np.zeros((len(channels), 3, npix))

        for chnl_idx, chnl in enumerate(channels):
            freq = instr[chnl].bandcenter_ghz
            # Per-channel seed offset, so that channels differ even when
            # the same ``seed_noise`` is used
            chnl_seed += 67
            p_sens = (instr[chnl].p_sens_ukarcmin * u.uK_CMB).to_value(
                self.pysm_units, equivalencies=u.cmb_equivalencies(freq * u.GHz)
            )
            P_rms = _from_sens_to_rms(p_sens, nside)
            T_rms = P_rms / np.sqrt(2.0)
            tot_rms = np.array([T_rms, P_rms, P_rms]).reshape(3, 1)

            for nmc in range(rank * perrank, (rank + 1) * perrank):
                if seed_noise:
                    np.random.seed(seed_noise + nmc + chnl_seed)
                nmc_str = f"{nmc:04d}"
                nmc_output_directory = output_directory / nmc_str
                if save:
                    nmc_output_directory.mkdir(parents=True, exist_ok=True)

                if n_split:
                    # Each split is noisier by sqrt(n_split); their mean
                    # has the nominal RMS
                    split_rms = tot_rms * np.sqrt(n_split)
                    noise_map = np.zeros((3, npix))
                    for hm in range(n_split):
                        noise_map_split = np.random.randn(3, npix) * split_rms
                        noise_map += noise_map_split
                        # Splits go to disk only when saving is enabled:
                        # otherwise their directory does not even exist
                        if save:
                            file_name = (
                                f"{chnl}_noise_SPLIT_{hm+1:04d}of{n_split:04d}"
                                + f"_{nmc_str}_{file_str}.fits"
                            )
                            cur_map_path = nmc_output_directory / file_name
                            lbs.write_healpix_map_to_file(
                                cur_map_path, noise_map_split, column_units=col_units
                            )
                            saved_maps.append(
                                MbsSavedMapInfo(
                                    path=cur_map_path,
                                    component="noise",
                                    channel=chnl,
                                    mc_num=nmc,
                                )
                            )
                    noise_map = noise_map / n_split
                else:
                    noise_map = np.random.randn(3, npix) * tot_rms

                if save:
                    file_name = f"{chnl}_noise_FULL_{nmc_str}_{file_str}.fits"
                    cur_map_path = nmc_output_directory / file_name
                    lbs.write_healpix_map_to_file(
                        cur_map_path, noise_map, column_units=col_units
                    )
                    saved_maps.append(
                        MbsSavedMapInfo(
                            path=cur_map_path,
                            component="noise",
                            channel=chnl,
                            mc_num=nmc,
                        )
                    )
                else:
                    noise_map_matrix[chnl_idx] = noise_map

        return (noise_map_matrix, saved_maps)

    def _load_cmb_power_spectrum(self):
        """Return the CMB power spectrum used by :meth:`generate_cmb`

        This is the spectrum in ``cmb_ps_file`` when given. Otherwise it is
        the bundled Planck 2018 scalar spectrum plus ``cmb_r`` times the
        bundled tensor (r=1) spectrum.
        """
        if self.params.cmb_ps_file:
            return hp.read_cl(self.params.cmb_ps_file)

        datautils_dir = Path(__file__).parent.parent / "datautils"
        cl_cmb_scalar = hp.read_cl(
            datautils_dir / "Cls_Planck2018_for_PTEP_2020_r0.fits"
        )
        cl_cmb_tensor = (
            hp.read_cl(datautils_dir / "Cls_Planck2018_for_PTEP_2020_tensor_r1.fits")
            * self.params.cmb_r
        )
        return cl_cmb_scalar + cl_cmb_tensor

    def generate_cmb(self):
        """Generate CMB realizations and project them on every channel

        Each realization is always written to disk first, because PySM
        reads it back through ``pysm3.CMBMap``.

        Returns:
            A pair ``(cmb_map_matrix, saved_maps)``. ``cmb_map_matrix``
            has shape ``(n_channels, 3, npix)`` and holds the last
            realization computed by this rank, or is ``None`` when
            ``save`` is set. The per-channel entries of ``saved_maps``
            carry both ``channel`` and ``mc_num``.
        """
        instr = self.instrument
        nside = self.params.nside
        npix = hp.nside2npix(nside)
        output_directory = self.sim.base_path / "cmb"
        file_str = self.params.output_string
        channels = instr.keys()
        seed_cmb = self.params.seed_cmb
        save = self.params.save
        col_units = [self.params.units] * 3
        saved_maps = []

        rank, size = self._mpi_rank_and_size()

        if rank == 0:
            output_directory.mkdir(parents=True, exist_ok=True)

        cl_cmb = self._load_cmb_power_spectrum()

        # Round up so that every rank gets the same number of realizations
        nmc_cmb = math.ceil(self.params.nmc_cmb / size) * size
        if nmc_cmb != self.params.nmc_cmb:
            log.warning("setting nmc_cmb = %d (rank %d)", nmc_cmb, rank)
        perrank = nmc_cmb // size

        if save:
            cmb_map_matrix = None
        else:
            cmb_map_matrix = np.zeros((len(channels), 3, npix))

        for nmc in range(rank * perrank, (rank + 1) * perrank):
            if seed_cmb:
                np.random.seed(seed_cmb + nmc)
            nmc_str = f"{nmc:04d}"
            nmc_output_directory = output_directory / nmc_str
            # Each rank writes the realizations it owns, so each rank must
            # create their directories (rank 0 only handles its own range)
            nmc_output_directory.mkdir(parents=True, exist_ok=True)

            cmb_temp = hp.synfast(cl_cmb, nside, new=True)
            cur_map_path = nmc_output_directory / f"cmb_{nmc_str}_{file_str}.fits"
            lbs.write_healpix_map_to_file(
                cur_map_path, cmb_temp, column_units=col_units
            )
            saved_maps.append(
                MbsSavedMapInfo(path=cur_map_path, component="cmb", mc_num=nmc)
            )
            sky = pysm3.Sky(
                nside=nside,
                component_objects=[
                    pysm3.CMBMap(nside, map_IQU=(Path(cur_map_path)).absolute())
                ],
            )

            for chnl_idx, chnl in enumerate(channels):
                cmb_map = self._emission_in_band(sky, instr[chnl])
                cmb_map_smt = self._apply_beam(cmb_map, instr[chnl])
                if save:
                    file_name = f"{chnl}_cmb_{nmc_str}_{file_str}.fits"
                    chnl_map_path = nmc_output_directory / file_name
                    lbs.write_healpix_map_to_file(
                        chnl_map_path, cmb_map_smt, column_units=col_units
                    )
                    saved_maps.append(
                        MbsSavedMapInfo(
                            path=chnl_map_path,
                            component="cmb",
                            channel=chnl,
                            mc_num=nmc,
                        )
                    )
                else:
                    cmb_map_matrix[chnl_idx] = cmb_map_smt

        return (cmb_map_matrix, saved_maps)

    def generate_fg(self):
        """Generate foreground maps for every model in ``fg_models``

        Only rank 0 does the work; other ranks return immediately.

        Returns:
            A pair ``(maps, saved_maps)``. ``maps`` is a dictionary from
            component name to an array of shape ``(n_channels, 3, npix)``,
            or ``None`` when ``save`` is set.
        """
        instr = self.instrument
        nside = self.params.nside
        npix = hp.nside2npix(nside)
        output_directory = self.sim.base_path / "foregrounds"
        file_str = self.params.output_string
        channels = instr.keys()
        save = self.params.save
        # ``fg_models`` may be None when no model was configured
        fg_models = self.params.fg_models or {}
        col_units = [self.params.units] * 3
        saved_maps = []

        rank, _ = self._mpi_rank_and_size()

        if rank == 0 and save:
            output_directory.mkdir(parents=True, exist_ok=True)

        dict_fg = {}

        if rank != 0:
            return (None if save else dict_fg, saved_maps)

        for cmp, fg_config_file_name in fg_models.items():
            cmp_dir = output_directory / cmp
            if save:
                cmp_dir.mkdir(parents=True, exist_ok=True)

            # Names containing "lb" or "pysm" refer to the configuration
            # files bundled in the "fg_models" directory next to this module
            if ("lb" in fg_config_file_name) or ("pysm" in fg_config_file_name):
                fg_config_file = (
                    Path(__file__).parent / "fg_models" / f"{fg_config_file_name}.cfg"
                )
            else:
                fg_config_file = f"{fg_config_file_name}"

            sky = pysm3.Sky(nside=nside, component_config=fg_config_file)

            if not save:
                fg_map_matrix = np.zeros((len(channels), 3, npix))

            for chnl_idx, chnl in enumerate(channels):
                sky_extrap = self._emission_in_band(sky, instr[chnl])
                sky_extrap_smt = self._apply_beam(sky_extrap, instr[chnl])
                if save:
                    cur_map_path = cmp_dir / f"{chnl}_{cmp}_{file_str}.fits"
                    lbs.write_healpix_map_to_file(
                        cur_map_path, sky_extrap_smt, column_units=col_units
                    )
                    saved_maps.append(
                        MbsSavedMapInfo(path=cur_map_path, component="fg", channel=chnl)
                    )
                else:
                    fg_map_matrix[chnl_idx] = sky_extrap_smt

            if not save:
                dict_fg[cmp] = fg_map_matrix

        return (None if save else dict_fg, saved_maps)

    def write_coadded_maps(self, saved_maps):
        """Sum foreground and CMB maps on disk and save the coadded maps

        For each channel, all foreground components read from disk are
        summed. That sum is then added to every per-channel CMB
        realization listed in ``saved_maps``; if there is none, the
        foreground sum alone is saved. New entries are appended to
        ``saved_maps`` in place.
        """
        root_dir = self.sim.base_path
        fg_dir = root_dir / "foregrounds"
        nside = self.params.nside
        file_str = self.params.output_string
        coadd_dir = root_dir / "coadd_signal_maps"
        col_units = [self.params.units] * 3

        coadd_dir.mkdir(parents=True, exist_ok=True)

        # Foreground files exist only if they were generated in this run
        if self.params.make_fg and self.params.fg_models:
            components = list(self.params.fg_models.keys())
        else:
            components = []

        for chnl in self.instrument.keys():
            fg_tot = np.zeros((3, hp.nside2npix(nside)))
            for cmp in components:
                fg_path = fg_dir / cmp / f"{chnl}_{cmp}_{file_str}.fits"
                try:
                    fg_cmp = hp.read_map(fg_path, (0, 1, 2))
                except IndexError:
                    # Intensity-only component: pad Q and U with zeros
                    fg_cmp = hp.read_map(fg_path)
                    fg_cmp = np.array(
                        [fg_cmp, np.zeros_like(fg_cmp), np.zeros_like(fg_cmp)]
                    )
                fg_tot += fg_cmp

            # Use the CMB maps already projected on this channel (same beam
            # and units as the foregrounds), one per realization this rank
            # produced; materialize the list before appending to saved_maps
            cmb_maps = sorted(
                (
                    info
                    for info in saved_maps
                    if info.component == "cmb"
                    and info.channel == chnl
                    and info.mc_num is not None
                ),
                key=lambda info: info.mc_num,
            )

            if not cmb_maps:
                tot_file_path = coadd_dir / f"{chnl}_coadd_signal_map_{file_str}.fits"
                lbs.write_healpix_map_to_file(
                    tot_file_path, fg_tot, column_units=col_units
                )
                saved_maps.append(
                    MbsSavedMapInfo(
                        path=tot_file_path, component="coadded", channel=chnl
                    )
                )
                continue

            for cmb_info in cmb_maps:
                nmc = cmb_info.mc_num
                cmb = hp.read_map(cmb_info.path, (0, 1, 2))
                map_tot = fg_tot + cmb

                nmc_str = f"{nmc:04d}"
                nmc_dir = coadd_dir / nmc_str
                nmc_dir.mkdir(parents=True, exist_ok=True)
                tot_file_path = (
                    nmc_dir / f"{chnl}_coadd_signal_map_{nmc_str}_{file_str}.fits"
                )
                lbs.write_healpix_map_to_file(
                    tot_file_path, map_tot, column_units=col_units
                )
                saved_maps.append(
                    MbsSavedMapInfo(
                        path=tot_file_path,
                        component="coadded",
                        channel=chnl,
                        mc_num=nmc,
                    )
                )

    def run_all(self):
        """Call PySM and generate the sky maps

        It returns a pair ``(maps, saved_maps)``. ``maps`` is a
        dictionary associating channels to NumPy arrays containing the
        maps; it is ``None`` when ``save`` is set or on MPI ranks other
        than 0. ``saved_maps`` is a list of :class:`.MbsSavedMapInfo`
        objects pointing to the FITS files containing the maps.
        """
        self._parse_instrument()
        channels = self.instrument.keys()
        npix = hp.nside2npix(self.params.nside)
        save = self.params.save
        saved_maps = []

        rank, _ = self._mpi_rank_and_size()

        tot = None if save else np.zeros((len(channels), 3, npix))

        if self.params.make_noise:
            log.info("generating and saving noise simulations")
            noise, noise_maps = self.generate_noise()
            saved_maps += noise_maps
            if not save:
                tot += noise

        if self.params.make_cmb:
            log.info("generating and saving cmb simulations")
            cmb, cmb_maps = self.generate_cmb()
            saved_maps += cmb_maps
            if not save:
                tot += cmb

        if self.params.make_fg:
            log.info("generating and saving fg simulations")
            fg, fg_maps = self.generate_fg()
            saved_maps += fg_maps
            if not save:
                for fg_matrix in fg.values():
                    tot += fg_matrix

        if rank != 0:
            return (None, saved_maps)

        if save:
            if self.params.coadd:
                log.info("saving coadded signal maps")
                self.write_coadded_maps(saved_maps)
            return (None, saved_maps)

        rotator = (
            hp.Rotator(coord=["G", "E"]) if self.params.maps_in_ecliptic else None
        )
        tot_dict = {}
        for nch, chnl in enumerate(channels):
            if rotator is not None:
                tot[nch] = rotator.rotate_map_alms(tot[nch], lmax=4 * self.params.nside)
            tot_dict[chnl] = tot[nch]
        return (tot_dict, saved_maps)
