"""A custom hash function implementation that properly supports pytorch."""
import io
import sys

from joblib.hashing import Hasher, NumpyHasher


class NoMemoizeHasher(Hasher):
    """Variant of the joblib ``Hasher`` that never stores objects in the pickle memo."""

    def memoize(self, obj):
        # Deliberate no-op: nothing is ever recorded in the memo.
        #
        # Background: memoization inside the joblib hasher can misbehave with nested objects, see
        # https://github.com/joblib/joblib/issues/1283
        # Because of how cloning is implemented in tpcp, we expect to run into this more often than other
        # applications would.
        #
        # As far as we understand, the memo mainly exists to keep the pickled output small.
        # Skipping it is actually faster, but self-referential objects can no longer be handled
        # (compare https://docs.python.org/3/library/pickle.html#pickle.Pickler.fast) and hashing them ends in a
        # `RecursionError`.
        # We consciously accept that limitation, assuming that accidental self-references are the rarer problem.
        return None

    def hash(self, obj, return_digest=True):
        """Hash ``obj`` and work around two known edge cases.

        The following special handling is applied on top of the parent implementation:

        - Self-referential objects: As memoization is disabled, hashing such an object ends in a
          `RecursionError`.
          This error is caught and re-raised as a `ValueError` with a more descriptive message.
        - Objects defined in the `__main__` module: These can lead to pickle issues (the exact reason is unclear).
          As a workaround, when the object reports `__main__` as its module and the main module has no attribute
          with the object's name, the object is temporarily attached to the main module under that name.
          The attribute is removed again once hashing has finished, independent of the outcome.
          This is admittedly hacky, but no cleaner solution is known.

        Parameters
        ----------
        obj
            The object to be hashed.
        return_digest
            Passed on unchanged to the `hash` method of the parent class.

        Raises
        ------
        ValueError
            If hashing fails with a `RecursionError` (i.e. the object is self-referential).

        """
        # Holds `(module, attribute_name)` if we had to patch the main module, otherwise stays `None`.
        injected = None
        if getattr(obj, "__module__", None) == "__main__":
            try:
                attr_name = obj.__name__
            except AttributeError:
                # Instances usually have no `__name__`, so fall back to the name of their class.
                attr_name = obj.__class__.__name__
            main_module = sys.modules["__main__"]
            if not hasattr(main_module, attr_name):
                setattr(main_module, attr_name, obj)
                injected = (main_module, attr_name)
        try:
            result = super().hash(obj, return_digest)
        except RecursionError as err:
            raise ValueError(
                "The custom hasher used in tpcp does not support hashing of self-referential objects."
            ) from err
        finally:
            # Undo the temporary entry in the main module, if we created one.
            if injected is not None:
                delattr(*injected)
        return result


class NoMemoizeNumpyHasher(NumpyHasher, NoMemoizeHasher):
    """Numpy-aware joblib hasher (``NumpyHasher``) combined with the memoize-free ``NoMemoizeHasher``."""


class TorchHasher(NoMemoizeNumpyHasher):
    """A hasher that can handle torch models.

    Under the hood this uses the new implementation of `torch.save` to serialize the model.
    This produces consistent output.

    Note: I never did any performance checks with large models.

    Parameters
    ----------
    hash_name
        Name of the hashing algorithm, passed on to the parent hasher.
    coerce_mmap
        Passed on to the parent hasher.

    """

    def __init__(self, hash_name="md5", coerce_mmap=False):
        super().__init__(hash_name, coerce_mmap)
        # Imported lazily, so that torch stays an optional dependency of this module.
        import torch  # noqa: import-outside-toplevel

        self.torch = torch

    def save(self, obj):
        """Feed ``obj`` into the hash, serializing torch modules and tensors via `torch.save`.

        All other objects are handled by the parent implementation.
        """
        if isinstance(obj, (self.torch.nn.Module, self.torch.Tensor)):
            # The serialized bytes must be read back from the buffer itself.
            # `io.BytesIO(initial_bytes)` only copies its argument, so a `bytes` object passed to the constructor
            # is never modified by later writes and would stay empty.
            # Hashing such an object would add no data at all, making all tensors/modules indistinguishable.
            buffer = io.BytesIO()
            self.torch.save(obj, buffer)
            self._hash.update(buffer.getvalue())
            return
        super().save(obj)


# This function is modified based on
# https://github.com/joblib/joblib/blob/4dafaff788a3b5402acfed091558b4c511982959/joblib/hashing.py#L244
def custom_hash(obj, hash_name="md5", coerce_mmap=False):
    """Compute a hash that identifies Python objects, including ones containing numpy arrays and torch models.

    This is an adapted version of `joblib.hash` that picks a hasher based on the libraries that are already
    imported: the torch-aware hasher if `torch` is loaded, otherwise the numpy-aware hasher if `numpy` is loaded,
    and the plain hasher if neither is.

    Parameters
    ----------
    obj
        The object to be hashed
    hash_name: 'md5' or 'sha1'
        Hashing algorithm used. sha1 is supposedly safer, but md5 is faster.
    coerce_mmap: boolean
        Make no difference between np.memmap and np.ndarray

    Raises
    ------
    ValueError
        If `hash_name` is not one of the supported algorithms.

    """
    valid_hash_names = ("md5", "sha1")
    if hash_name not in valid_hash_names:
        raise ValueError(f"Valid options for 'hash_name' are {valid_hash_names}. Got hash_name={hash_name!r} instead.")

    # Only libraries that are already imported can have contributed objects, so checking `sys.modules` avoids
    # importing torch/numpy just for hashing.
    if "torch" in sys.modules:
        return TorchHasher(hash_name=hash_name, coerce_mmap=coerce_mmap).hash(obj)
    if "numpy" in sys.modules:
        return NoMemoizeNumpyHasher(hash_name=hash_name, coerce_mmap=coerce_mmap).hash(obj)
    return NoMemoizeHasher(hash_name=hash_name).hash(obj)
