import re
from unittest.mock import Mock, patch

import numpy as np
import pandas as pd
import pytest

from rdt.transformers import CategoricalTransformer, OneHotEncodingTransformer

# Pattern for strings shaped like '123-45-6789' (3-2-4 digit groups separated by
# dashes); the tests use it to check the output of the 'ssn' Faker category.
# NOTE(review): callers use ``RE_SSN.match``, which only anchors at the start of
# the string, so trailing extra characters would still be accepted.
RE_SSN = re.compile(r'\d\d\d-\d\d-\d\d\d\d')


class TestCategoricalTransformer:

    def test___init__(self):
        """Passed arguments must be stored as attributes."""
        # Run
        transformer = CategoricalTransformer(
            anonymize='anonymize_value',
            fuzzy='fuzzy_value',
            clip='clip_value',
        )

        # Asserts
        assert transformer.anonymize == 'anonymize_value'
        assert transformer.fuzzy == 'fuzzy_value'
        assert transformer.clip == 'clip_value'

    def test__get_faker_anonymize_tuple(self):
        """If anonymize is a tuple, first value is the category, rest are arguments."""
        # Setup
        transformer = CategoricalTransformer(anonymize=('credit_card_number', 'visa'))

        # Run
        faker = transformer._get_faker()

        # Asserts
        assert callable(faker)

        fake = faker()
        assert isinstance(fake, str)
        assert len(fake) == 16

    def test__get_faker_anonymize_list(self):
        """If anonymize is a list, first value is the category, rest are arguments."""
        # Setup
        transformer = CategoricalTransformer(anonymize=['credit_card_number', 'visa'])

        # Run
        faker = transformer._get_faker()

        # Asserts
        assert callable(faker)

        fake = faker()
        assert isinstance(fake, str)
        assert len(fake) == 16

    def test__get_faker_anonymize_str(self):
        """If anonymize is a string, it is used directly as the Faker category."""
        # Setup
        transformer = CategoricalTransformer(anonymize='ssn')

        # Run
        faker = transformer._get_faker()

        # Asserts
        assert callable(faker)

        fake = faker()
        assert isinstance(fake, str)
        assert RE_SSN.match(fake)

    def test__get_faker_anonymize_category_not_exist(self):
        """An unknown Faker category must raise a ValueError."""
        # Setup
        transformer = CategoricalTransformer(anonymize='whatever')

        # Run
        with pytest.raises(ValueError):
            transformer._get_faker()

    def test__anonymize(self):
        """Anonymizing returns one SSN-formatted fake value per input value."""
        # Setup
        transformer = CategoricalTransformer(anonymize='ssn')

        # Run
        data = pd.Series(['foo', 'bar', 'foo', 'tar'])
        result = transformer._anonymize(data)

        # Asserts
        assert len(result) == 4
        assert result.map(RE_SSN.match).astype(bool).all()

    def test__get_intervals(self):
        """Each category gets an interval whose width equals its frequency.

        Each tuple holds the interval start, end, its midpoint and width / 6.
        """
        # Setup
        data = pd.Series(['bar', 'foo', 'foo', 'tar'])

        # Run
        result = CategoricalTransformer._get_intervals(data)

        # Asserts
        expected_intervals = {
            'foo': (0, 0.5, 0.25, 0.5 / 6),
            'tar': (0.5, 0.75, 0.625, 0.25 / 6),
            'bar': (0.75, 1, 0.875, 0.25 / 6)
        }
        assert result == expected_intervals

    def test_fit_no_anonymize(self):
        """Without anonymize, ``fit`` stores intervals keyed by the original values."""
        # Setup
        transformer = CategoricalTransformer()

        # Run
        data = np.array(['bar', 'foo', 'foo', 'tar'])
        transformer.fit(data)

        # Asserts
        expected_intervals = {
            'foo': (0, 0.5, 0.25, 0.5 / 6),
            'tar': (0.5, 0.75, 0.625, 0.25 / 6),
            'bar': (0.75, 1, 0.875, 0.25 / 6)
        }
        assert transformer.intervals == expected_intervals

    def test_fit_anonymize(self):
        """With anonymize, interval keys are fake values but the intervals are unchanged."""
        # Setup
        transformer = CategoricalTransformer(anonymize='ssn')

        # Run
        data = np.array(['bar', 'foo', 'foo', 'tar'])
        transformer.fit(data)

        # Asserts
        expected_intervals = {
            (0, 0.5, 0.25, 0.5 / 6),
            (0.5, 0.75, 0.625, 0.25 / 6),
            (0.75, 1, 0.875, 0.25 / 6)
        }
        unexpected_keys = {'bar', 'foo', 'tar'}

        assert set(transformer.intervals.values()) == expected_intervals
        # None of the original (non-anonymized) values may leak into the keys.
        assert unexpected_keys.isdisjoint(transformer.intervals.keys())

    def test__get_value_no_fuzzy(self):
        """Without fuzzy, the interval midpoint (third element) is returned."""
        # Setup
        transformer = Mock()
        transformer.fuzzy = False
        transformer.intervals = {
            'foo': (0, 0.5, 0.25, 0.5 / 6),
        }

        # Run
        result = CategoricalTransformer._get_value(transformer, 'foo')

        # Asserts
        assert result == 0.25

    @patch('scipy.stats.norm.rvs')
    def test__get_value_fuzzy(self, rvs_mock):
        """With fuzzy, the value is drawn from ``scipy.stats.norm.rvs`` (mocked here)."""
        # Setup
        rvs_mock.return_value = 0.2745

        transformer = Mock()
        transformer.fuzzy = True
        transformer.intervals = {
            'foo': (0, 0.5, 0.25, 0.5 / 6),
        }

        # Run
        result = CategoricalTransformer._get_value(transformer, 'foo')

        # Asserts
        assert result == 0.2745

    @patch('rdt.transformers.categorical.MAPS', new_callable=dict)
    def test_transform_array_anonymize(self, mock_maps):
        """With anonymize, values are mapped via ``MAPS[id(transformer)]`` before lookup."""
        # Setup
        transformer = CategoricalTransformer(anonymize='ssn')
        transformer.intervals = {
            'foo_x': (0, 0.5, 0.25, 0.5 / 6),
            'bar_x': (0.5, 0.75, 0.625, 0.25 / 6),
            'tar_x': (0.75, 1, 0.875, 0.25 / 6)
        }
        mock_maps[id(transformer)] = {
            'foo': 'foo_x',
            'bar': 'bar_x',
            'tar': 'tar_x'
        }

        # Run
        data = np.array(['foo', 'bar', 'tar'])
        result = transformer.transform(data)

        # Asserts
        assert list(result) == [0.25, 0.625, 0.875]

    def test__normalize_no_clip(self):
        """Without clip, out-of-range values wrap around into [0, 1)."""
        # Setup
        data = pd.Series([-0.43, 0.1234, 1.5, -1.31])

        transformer = Mock()
        transformer.clip = False

        # Run
        result = CategoricalTransformer._normalize(transformer, data)

        # Asserts
        expect = pd.Series([0.57, 0.1234, 0.5, 0.69], dtype=float)

        pd.testing.assert_series_equal(result, expect)

    def test__normalize_clip(self):
        """With clip=True, out-of-range values are clipped to [0, 1]."""
        # Setup
        data = pd.Series([-0.43, 0.1234, 1.5, -1.31])

        transformer = Mock()
        transformer.clip = True

        # Run
        result = CategoricalTransformer._normalize(transformer, data)

        # Asserts
        expect = pd.Series([0.0, 0.1234, 1.0, 0.0], dtype=float)

        pd.testing.assert_series_equal(result, expect)

    def test_reverse_transform_array(self):
        """Normalized values map back to the category whose interval contains them."""
        # Setup
        data = np.array([-0.6, 0.2, 0.6, -0.2])
        normalized_data = pd.Series([0.4, 0.2, 0.6, 0.8])

        intervals = {
            'foo': (0, 0.5),
            'bar': (0.5, 0.75),
            'tar': (0.75, 1),
        }

        transformer = Mock()
        transformer._normalize.return_value = normalized_data
        transformer.intervals = intervals

        # Run
        result = CategoricalTransformer.reverse_transform(transformer, data)

        # Asserts
        expect = pd.Series(['foo', 'foo', 'bar', 'tar'])

        pd.testing.assert_series_equal(result, expect)


class TestOneHotEncodingTransformer:
    """Tests for ``OneHotEncodingTransformer`` fit / transform / reverse_transform."""

    @staticmethod
    def _fitted(values):
        """Return a transformer fitted on ``pd.Series(values)`` plus that series."""
        series = pd.Series(values)
        encoder = OneHotEncodingTransformer()
        encoder.fit(series)
        return encoder, series

    def test_fit_no_nans(self):
        """Every distinct value is recorded as a dummy."""
        # Setup / Run
        encoder, _ = self._fitted(['a', 'b', 'c'])

        # Assert
        np.testing.assert_array_equal(encoder.dummies, ['a', 'b', 'c'])

    def test_fit_nans(self):
        """A missing value is recorded as a NaN dummy."""
        # Setup / Run
        encoder, _ = self._fitted(['a', 'b', None])

        # Assert
        np.testing.assert_array_equal(encoder.dummies, ['a', 'b', np.nan])

    def test_fit_single(self):
        """A constant column yields exactly one dummy."""
        # Setup / Run
        encoder, _ = self._fitted(['a', 'a', 'a'])

        # Assert
        np.testing.assert_array_equal(encoder.dummies, ['a'])

    def test_transform_no_nans(self):
        """Three distinct values encode to a 3x3 identity matrix."""
        # Setup
        encoder, series = self._fitted(['a', 'b', 'c'])

        # Run
        encoded = encoder.transform(series)

        # Assert
        np.testing.assert_array_equal(encoded, np.eye(3, dtype=int))

    def test_transform_nans(self):
        """The NaN dummy gets its own column in the encoding."""
        # Setup
        encoder, series = self._fitted(['a', 'b', None])

        # Run
        encoded = encoder.transform(series)

        # Assert
        np.testing.assert_array_equal(encoded, np.eye(3, dtype=int))

    def test_transform_single(self):
        """A single dummy encodes every row as a one-column matrix of ones."""
        # Setup
        encoder, series = self._fitted(['a', 'a', 'a'])

        # Run
        encoded = encoder.transform(series)

        # Assert
        np.testing.assert_array_equal(encoded, np.ones((3, 1), dtype=int))

    def test_reverse_transform_no_nans(self):
        """An identity matrix decodes back to the fitted values."""
        # Setup
        encoder, _ = self._fitted(['a', 'b', 'c'])

        # Run
        decoded = encoder.reverse_transform(np.eye(3, dtype=int))

        # Assert
        pd.testing.assert_series_equal(decoded, pd.Series(['a', 'b', 'c']))

    def test_reverse_transform_nans(self):
        """The NaN column decodes back to a missing value."""
        # Setup
        encoder, _ = self._fitted(['a', 'b', None])

        # Run
        decoded = encoder.reverse_transform(np.eye(3, dtype=int))

        # Assert
        pd.testing.assert_series_equal(decoded, pd.Series(['a', 'b', None]))

    def test_reverse_transform_single(self):
        """A one-column matrix of ones decodes to the single fitted value."""
        # Setup
        encoder, _ = self._fitted(['a', 'a', 'a'])

        # Run
        decoded = encoder.reverse_transform(np.ones((3, 1), dtype=int))

        # Assert
        pd.testing.assert_series_equal(decoded, pd.Series(['a', 'a', 'a']))

    def test_reverse_transform_1d(self):
        """With a single dummy, a 1-D array of ones is also accepted."""
        # Setup
        encoder, _ = self._fitted(['a', 'a', 'a'])

        # Run
        decoded = encoder.reverse_transform(np.ones(3, dtype=int))

        # Assert
        pd.testing.assert_series_equal(decoded, pd.Series(['a', 'a', 'a']))
