import numpy as np
import pandas as pd

# Training samples in spaCy's ``(text, {"entities": [(start, end, label)]})``
# form. Offsets are character positions with an exclusive end; the comment
# beside every span shows the substring ``text[start:end]`` it covers.
NER_TRAINING_DATA = [
    (
        "what is SEMRUSH PRO? Can you run complex queries ? "
        "Can you identify active usage ?",
        {
            "entities": [
                (21, 32, "Questions About the Product"),  # "Can you run"
                (51, 67, "Questions About the Product"),  # "Can you identify"
            ],
        },
    ),
    (
        "Thank you for your subscription renewal",
        {
            "entities": [
                (19, 39, "Renew"),  # "subscription renewal"
            ],
        },
    ),
    (
        "you can upgrade your account for an old price,"
        "while you can upgrade your account for $399.95/month",
        {
            "entities": [
                (8, 28, "Potential Upsell"),  # "upgrade your account"
                (60, 80, "Potential Upsell"),  # "upgrade your account"
            ],
        },
    ),
    (
        "I like EMSI ordered the pro package",
        {
            "entities": [
                (12, 23, "Product Usage"),  # "ordered the"
            ],
        },
    ),
    (
        "Here you go, your account is created",
        {
            "entities": [
                (0, 11, "Action item accomplished"),  # "Here you go"
                (29, 36, "Action item accomplished"),  # "created"
            ],
        },
    ),
]


# Held-out sample in the same ``(text, {"entities": [...]})`` form; only the
# word "renewal" (characters 32-39, end exclusive) is annotated.
NER_TEST_DATA = [
    (
        "Thank you for your subscription renewal",
        {
            "entities": [
                (32, 39, "Renew"),  # "renewal"
            ],
        },
    ),
]

# Token-level tag set following spaCy's BILUO scheme: every entity type under
# each of the B / I / L / U prefixes (grouped by prefix), followed by the
# single "O" (outside) tag. The prefixed tags are generated from one tuple of
# entity types so the four groups cannot drift apart through a typo in one copy.
NER_CLASS_LABELS = [
    f"{prefix}-{entity_type}"
    for prefix in ("B", "I", "L", "U")
    for entity_type in (
        "Questions About the Product",
        "Potential Upsell",
        "Action item accomplished",
        "Renew",
        "Product Usage",
    )
] + ["O"]


class TestSpacyNerConstants:
    """Expected ground-truth values for the spaCy NER test over ``NER_TRAINING_DATA``.

    * ``gt_data`` - one row per training sample, all in the ``"training"``
      split, holding the raw text and its token boundaries.
    * ``gt_embs`` - a ``float16`` matrix with ``_num_gold_spans`` rows of 64
      values each.
    * ``gt_probs`` - one row per gold span, tagged with epoch
      ``num_epochs - 1``; every span is expected to have an empty prediction
      and the ``"missed_label"`` error type.
    """

    num_epochs = 2
    # Derived rather than hard-coded so the per-sample columns of ``gt_data``
    # always match the length of its ``"text"`` column.
    _num_samples = len(NER_TRAINING_DATA)
    # Total gold entity spans across NER_TRAINING_DATA (2 + 1 + 2 + 1 + 2).
    _num_gold_spans = 8

    gt_data = pd.DataFrame(
        data={
            "id": range(_num_samples),
            "split": ["training"] * _num_samples,
            "text": [data_sample[0] for data_sample in NER_TRAINING_DATA],
            # Per sample: flattened character offsets of each token,
            # [start_0, end_0, start_1, end_1, ...] with exclusive ends, e.g.
            # sample 1 tokenizes as "Thank" (0, 5), "you" (6, 9), ...
            "text_token_indices": [
                np.array(
                    [
                        0,
                        4,
                        5,
                        7,
                        8,
                        15,
                        16,
                        19,
                        19,
                        20,
                        21,
                        24,
                        25,
                        28,
                        29,
                        32,
                        33,
                        40,
                        41,
                        48,
                        49,
                        50,
                        51,
                        54,
                        55,
                        58,
                        59,
                        67,
                        68,
                        74,
                        75,
                        80,
                        81,
                        82,
                    ]
                ),
                np.array([0, 5, 6, 9, 10, 13, 14, 18, 19, 31, 32, 39]),
                np.array(
                    [
                        0,
                        3,
                        4,
                        7,
                        8,
                        15,
                        16,
                        20,
                        21,
                        28,
                        29,
                        32,
                        33,
                        35,
                        36,
                        39,
                        40,
                        45,
                        45,
                        46,
                        46,
                        51,
                        52,
                        55,
                        56,
                        59,
                        60,
                        67,
                        68,
                        72,
                        73,
                        80,
                        81,
                        84,
                        85,
                        86,
                        86,
                        92,
                        92,
                        93,
                        93,
                        98,
                    ]
                ),
                np.array([0, 1, 2, 6, 7, 11, 12, 19, 20, 23, 24, 27, 28, 35]),
                np.array([0, 4, 5, 8, 9, 11, 11, 12, 13, 17, 18, 25, 26, 28, 29, 36]),
            ],
            "data_schema_version": [1] * _num_samples,
        }
    )
    # One 64-value float16 row per gold span (8 rows). Presumably aligned
    # row-for-row with the gold spans in ``gt_probs`` -- TODO confirm.
    gt_embs = np.array(
        [
            [
                1.7881e00,
                2.9082e00,
                1.1357e00,
                3.4453e00,
                2.5449e00,
                2.5605e00,
                1.0840e-01,
                1.0332e00,
                5.9180e-01,
                1.1562e00,
                1.2793e00,
                1.4565e-02,
                1.0625e00,
                1.9912e00,
                -2.6562e-01,
                1.0039e00,
                -2.5293e00,
                -1.6670e00,
                2.4238e00,
                1.4521e00,
                -7.5098e-01,
                1.4775e00,
                2.5488e00,
                2.2422e00,
                1.1201e00,
                1.6670e00,
                3.0215e00,
                9.3994e-01,
                -9.0234e-01,
                5.8008e00,
                2.1602e00,
                8.4814e-01,
                -7.0752e-01,
                3.1714e-01,
                2.0781e00,
                2.4023e00,
                1.0361e00,
                5.6396e-01,
                1.6182e00,
                1.2051e00,
                -8.0176e-01,
                5.1172e00,
                3.0371e00,
                2.0371e00,
                -4.1235e-01,
                1.6035e00,
                5.1445e00,
                2.5488e00,
                2.5098e00,
                1.5586e00,
                -1.6510e-02,
                9.1113e-01,
                2.6660e00,
                1.8965e00,
                1.1904e00,
                2.0039e00,
                1.1650e00,
                2.5684e00,
                -5.6061e-02,
                1.0674e00,
                -2.0898e00,
                4.0508e00,
                1.5547e00,
                3.1519e-01,
            ],
            [
                2.2637e00,
                2.7773e00,
                1.9443e00,
                2.0371e00,
                9.2957e-02,
                3.1875e00,
                -4.7241e-01,
                -7.7246e-01,
                5.1484e00,
                -1.3350e00,
                -3.4448e-01,
                -3.2861e-01,
                3.7910e00,
                1.5176e00,
                7.0264e-01,
                1.6284e-01,
                -1.0791e00,
                3.4106e-01,
                1.0195e00,
                1.4941e00,
                -5.3613e-01,
                -6.0254e-01,
                4.5117e00,
                1.2588e00,
                3.0371e00,
                -4.6802e-01,
                1.4961e00,
                2.1543e00,
                -1.2773e00,
                6.8164e00,
                1.4385e00,
                3.0020e00,
                5.9229e-01,
                1.3438e00,
                2.1699e00,
                1.9639e00,
                -1.0020e00,
                1.9375e00,
                5.6299e-01,
                1.9717e00,
                2.1426e00,
                3.4473e00,
                2.7852e00,
                4.6445e00,
                -3.2031e00,
                1.0437e-01,
                4.4297e00,
                6.3750e00,
                6.4062e-01,
                1.4648e00,
                -2.7695e00,
                9.1064e-01,
                3.1699e00,
                1.0605e00,
                2.2129e00,
                4.2969e00,
                4.7900e-01,
                4.7632e-01,
                -1.9604e-01,
                6.1426e-01,
                -1.3506e00,
                -2.6245e-02,
                2.5273e00,
                -1.1121e-01,
            ],
            [
                1.4404e00,
                1.6855e00,
                1.8721e00,
                2.4597e-01,
                1.4893e00,
                2.2148e00,
                7.9102e-01,
                3.5371e00,
                5.3906e-01,
                -1.0898e00,
                1.3193e00,
                2.0752e-01,
                5.4297e-01,
                1.5264e00,
                1.9346e00,
                -1.2683e-01,
                -2.6699e00,
                2.7773e00,
                2.3145e00,
                4.0503e-01,
                -4.4653e-01,
                2.3499e-01,
                4.8789e00,
                1.6777e00,
                5.6592e-01,
                1.6055e00,
                1.2041e00,
                3.2051e00,
                -1.2366e-01,
                4.2305e00,
                2.1465e00,
                2.8262e00,
                1.2231e-01,
                1.0908e00,
                1.8672e00,
                3.7549e-01,
                1.4248e00,
                4.2148e00,
                4.9829e-01,
                0.0000e00,
                -1.0684e00,
                2.1992e00,
                1.7656e00,
                1.7207e00,
                2.7441e00,
                -2.0469e00,
                -1.3945e00,
                7.1133e00,
                4.8398e00,
                1.0215e00,
                -6.3428e-01,
                1.8027e00,
                6.9678e-01,
                -1.0771e00,
                -9.9414e-01,
                3.6426e-01,
                -1.5850e00,
                1.2568e00,
                -3.1421e-01,
                1.2529e00,
                5.3940e-03,
                1.9365e00,
                3.3652e00,
                -4.8169e-01,
            ],
            [
                3.1680e00,
                4.2852e00,
                1.6895e00,
                1.7676e00,
                1.0830e00,
                2.3926e00,
                1.6885e00,
                4.1675e-01,
                4.0000e00,
                -4.8169e-01,
                6.9971e-01,
                8.1445e-01,
                2.8594e00,
                6.6211e-01,
                3.0684e00,
                2.2422e00,
                -1.5557e00,
                -1.0059e00,
                1.6533e00,
                -1.2469e-01,
                -2.3022e-01,
                1.2275e00,
                2.5762e00,
                -8.0383e-02,
                2.7832e00,
                1.5020e00,
                2.2305e00,
                2.7852e00,
                -2.6688e-02,
                4.3164e00,
                1.9316e00,
                2.6440e-01,
                -5.9131e-01,
                3.8892e-01,
                1.8955e00,
                1.9883e00,
                1.5508e00,
                1.7051e00,
                -2.7026e-01,
                2.3477e00,
                -2.1152e00,
                3.5898e00,
                1.5449e00,
                1.3525e00,
                5.7080e-01,
                8.0713e-01,
                1.7363e00,
                3.9902e00,
                1.7031e00,
                2.8760e-01,
                1.2080e00,
                2.1680e00,
                1.8594e00,
                -2.8906e-01,
                2.0469e00,
                2.6367e00,
                1.3350e00,
                1.9473e00,
                4.5215e-01,
                2.1816e00,
                -2.0059e00,
                1.2627e00,
                2.9102e00,
                3.2251e-01,
            ],
            [
                3.0332e00,
                2.0430e00,
                1.1201e00,
                1.2646e00,
                3.6836e00,
                -2.3340e-01,
                7.4170e-01,
                1.8887e00,
                1.2061e00,
                -1.2959e00,
                1.7139e00,
                4.1602e-01,
                1.3730e00,
                2.1992e00,
                1.0469e00,
                1.5957e00,
                -3.8672e00,
                -3.7354e-01,
                2.3340e00,
                1.0830e00,
                1.8525e00,
                1.4004e00,
                3.6211e00,
                1.7471e00,
                5.3271e-01,
                7.2119e-01,
                1.6709e00,
                1.1143e00,
                -5.7080e-01,
                5.0938e00,
                1.0254e00,
                2.1113e00,
                -7.4219e-01,
                1.0215e00,
                5.1953e00,
                1.3525e00,
                1.6240e00,
                1.3232e00,
                9.9902e-01,
                3.4082e00,
                -1.6953e00,
                4.2695e00,
                2.9883e00,
                2.2852e00,
                1.5322e00,
                6.4209e-01,
                3.3965e00,
                3.8574e00,
                4.1992e00,
                3.4155e-01,
                9.3359e-01,
                -2.3474e-01,
                1.8047e00,
                7.7783e-01,
                1.1006e00,
                5.4297e-01,
                1.3513e-01,
                1.0527e00,
                3.3325e-01,
                1.1934e00,
                -1.8145e00,
                -3.1372e-02,
                2.3945e00,
                -8.0518e-01,
            ],
            [
                5.1172e00,
                2.3711e00,
                1.3086e00,
                9.2969e-01,
                5.8691e-01,
                -4.8779e-01,
                1.2783e00,
                1.0127e00,
                3.4805e00,
                1.6016e00,
                2.2637e00,
                1.5723e00,
                2.4414e-01,
                -2.5859e00,
                2.1055e00,
                5.4023e00,
                -1.2422e00,
                -1.5547e00,
                3.2949e00,
                8.2324e-01,
                -3.1270e00,
                1.0459e00,
                3.3945e00,
                -9.0820e-02,
                2.2344e00,
                3.0977e00,
                3.6934e00,
                1.8291e00,
                -2.5977e-01,
                2.8672e00,
                5.7373e-01,
                -3.1519e-01,
                -5.5127e-01,
                5.8594e-01,
                1.4707e00,
                -2.6001e-01,
                -3.5596e-01,
                3.6172e00,
                9.2432e-01,
                1.3877e00,
                3.5156e00,
                1.8525e00,
                2.3867e00,
                2.6289e00,
                3.7617e00,
                2.4199e00,
                2.0645e00,
                3.5527e00,
                2.8027e00,
                3.6406e00,
                -4.0015e-01,
                6.2646e-01,
                8.3496e-01,
                1.7227e00,
                4.6055e00,
                1.0605e00,
                5.2031e00,
                1.7881e00,
                7.9590e-01,
                -2.1033e-01,
                -6.2598e-01,
                1.3135e00,
                2.0586e00,
                -5.6055e-01,
            ],
            [
                9.7705e-01,
                3.4941e00,
                9.8145e-01,
                7.1484e-01,
                -3.0029e-02,
                3.1230e00,
                1.0254e00,
                6.5723e-01,
                1.6758e00,
                -7.9150e-01,
                6.1572e-01,
                2.6641e00,
                3.4619e-01,
                -1.3193e00,
                5.7129e-02,
                5.0391e-01,
                -1.2734e00,
                -1.2581e-02,
                1.5703e00,
                9.5752e-01,
                -8.1885e-01,
                -2.9248e-01,
                2.5273e00,
                6.5332e-01,
                1.0049e00,
                1.0176e00,
                1.6182e00,
                2.2383e00,
                -5.9912e-01,
                4.4609e00,
                2.7383e00,
                9.4775e-01,
                -2.0645e00,
                2.7051e00,
                1.6885e00,
                1.7676e00,
                1.1699e00,
                1.1592e00,
                7.8223e-01,
                2.4492e00,
                -3.7524e-01,
                2.7324e00,
                2.0488e00,
                2.0469e00,
                1.8787e-01,
                8.8281e-01,
                -1.8823e-01,
                5.3672e00,
                2.4160e00,
                -1.7566e-01,
                2.3657e-01,
                3.9038e-01,
                2.8008e00,
                -1.3506e00,
                2.5391e00,
                6.9336e-01,
                2.3767e-01,
                4.0039e00,
                -3.0737e-01,
                -6.8066e-01,
                -1.1738e00,
                8.0371e-01,
                1.9463e00,
                3.9941e-01,
            ],
            [
                3.5117e00,
                2.9004e00,
                0.0000e00,
                8.6963e-01,
                1.4893e00,
                2.1738e00,
                -1.5469e00,
                9.6875e-01,
                0.0000e00,
                7.3340e-01,
                -1.5293e00,
                2.8301e00,
                -6.3770e-01,
                0.0000e00,
                3.2031e00,
                0.0000e00,
                -1.7520e00,
                0.0000e00,
                3.9805e00,
                3.5859e00,
                -2.0935e-01,
                2.0898e00,
                4.8853e-01,
                4.4570e00,
                1.2021e00,
                1.5215e00,
                3.9609e00,
                3.1323e-01,
                -3.5898e00,
                5.1445e00,
                3.1562e00,
                -4.5859e00,
                -2.1309e00,
                2.8301e00,
                -1.5439e00,
                0.0000e00,
                1.7354e00,
                3.3936e-01,
                -6.6113e-01,
                1.4150e00,
                -2.7949e00,
                4.7227e00,
                1.1826e00,
                -1.0518e00,
                -7.0850e-01,
                8.6816e-01,
                -1.3350e00,
                4.6797e00,
                3.2285e00,
                1.8457e00,
                1.5186e00,
                2.8691e00,
                2.8652e00,
                0.0000e00,
                3.5312e00,
                0.0000e00,
                -2.5293e-01,
                -1.4023e00,
                2.6777e00,
                2.4570e00,
                1.7314e00,
                0.0000e00,
                2.8301e00,
                -4.5142e-01,
            ],
        ],
        dtype=np.float16,
    )

    gt_probs = pd.DataFrame(
        data={
            "sample_id": [0, 0, 1, 2, 2, 3, 4, 4],
            "split": ["training"] * _num_gold_spans,
            "epoch": [num_epochs - 1] * _num_gold_spans,
            "is_gold": [True] * _num_gold_spans,
            "is_pred": [False] * _num_gold_spans,
            # Token indices with exclusive end; e.g. sample 0's span (5, 8)
            # covers the tokens "Can you run", i.e. characters 21-32.
            "span_start": [5, 11, 4, 2, 13, 3, 0, 7],
            "span_end": [8, 14, 6, 5, 16, 5, 3, 8],
            "gold": [
                "Questions About the Product",
                "Questions About the Product",
                "Renew",
                "Potential Upsell",
                "Potential Upsell",
                "Product Usage",
                "Action item accomplished",
                "Action item accomplished",
            ],
            "pred": [""] * _num_gold_spans,
            "data_error_potential": [
                0.509,
                0.5088,
                0.5083,
                0.508,
                0.5085,
                0.5125,
                0.5119,
                0.5092,
            ],
            "galileo_error_type": ["missed_label"] * _num_gold_spans,
        }
    )
