#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Dict, Optional, Sequence

from refinery.units.blockwise import Arg, BlockTransformation
from refinery.lib.tools import isbuffer


class map(BlockTransformation):
    """
    Each block of the input data which occurs as a block of the index argument
    is replaced by the corresponding block of the image argument. Blocks which
    do not occur in the index argument are left unchanged.
    """
    # Lookup table consulted by process_block. It is only populated while
    # process() runs the generic (non-bytestream) transformation and is reset
    # to None afterwards.
    _map: Optional[Dict[int, int]]

    def __init__(
        self,
        index: Arg.NumSeq(help='index characters'),
        image: Arg.NumSeq(help='image characters'),
        blocksize=1
    ):
        super().__init__(blocksize=blocksize, index=index, image=image)
        self._map = None

    def process(self, data):
        """
        Substitute every block of `data` that occurs in the index sequence with
        the block at the same position of the image sequence.

        When the unit operates on a byte stream, a 256-entry translation table
        is built and applied to the data directly; the result is returned as a
        bytearray. Otherwise, the lookup table is stored on the instance and the
        work is delegated to the parent class, which is expected to invoke
        `process_block` for each block.

        Raises:
            ValueError: if the index sequence contains duplicates, or if it is
                longer than the image sequence.
        """
        index: Sequence[int] = self.args.index
        image: Sequence[int] = self.args.image
        if not self.bytestream:
            # Buffer arguments are chunked so that their entries are blocks of
            # the configured block size rather than single bytes.
            if isbuffer(index):
                self.log_info(F'chunking index sequence into blocks of size {self.args.blocksize}')
                index = list(self.chunk(index))
                self.log_debug(F'index sequence: {index}')
            if isbuffer(image):
                self.log_info(F'chunking image sequence into blocks of size {self.args.blocksize}')
                image = list(self.chunk(image))
                self.log_debug(F'image sequence: {image}')
        if len(set(index)) != len(index):
            raise ValueError('The index sequence contains duplicates.')
        if len(index) > len(image):
            raise ValueError('The index sequence is longer than the image sequence.')
        if self.bytestream:
            lookup = dict(zip(index, image))
            # Complete translation table: byte values that do not occur in the
            # index sequence are mapped to themselves.
            table = bytes(lookup.get(c, c) for c in range(0x100))
            if not isinstance(data, bytearray):
                data = bytearray(data)
            # bytearray.translate performs the substitution in a single pass
            # in C; the slice assignment keeps the update in place so that the
            # same bytearray object is returned.
            data[:] = data.translate(table)
            return data
        try:
            self._map = dict(zip(index, image))
            return super().process(data)
        finally:
            # Never leave a stale lookup table on the instance, even when the
            # parent implementation raises.
            self._map = None

    def process_block(self, token):
        """
        Return the image of `token` under the current lookup table, or `token`
        itself when it does not occur in the index sequence.
        """
        return self._map.get(token, token)
