# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest.mock import Mock

import opentelemetry.propagators.b3 as b3_format  # pylint: disable=no-name-in-module,import-error
import opentelemetry.sdk.trace as trace
import opentelemetry.sdk.trace.id_generator as id_generator
import opentelemetry.trace as trace_api
from opentelemetry.context import get_current
from opentelemetry.propagators.textmap import DefaultGetter

# Shared B3 propagator instance; the helper below and every test in this
# module extract from / inject into carriers through it.
FORMAT = b3_format.B3Format()


def get_child_parent_new_carrier(old_carrier):
    """Round-trip ``old_carrier`` through the B3 propagator via a child span.

    The span context extracted from ``old_carrier`` is wrapped in an SDK
    ``_Span`` named "parent". A second ``_Span`` named "child" is then built
    locally (``is_remote=False``). It reuses the parent's trace id, trace
    flags and trace state but gets a freshly generated random span id. That
    child span is injected into a new, empty carrier.

    Args:
        old_carrier: Mapping of incoming B3 header names to values.

    Returns:
        A ``(child, parent, new_carrier)`` tuple, where ``new_carrier`` is
        the dict populated by ``FORMAT.inject``.
    """
    extracted_ctx = FORMAT.extract(old_carrier)
    remote_sc = trace_api.get_current_span(extracted_ctx).get_span_context()
    parent = trace._Span("parent", remote_sc)

    # The child continues the same trace, so it inherits trace id, flags
    # and state from the extracted (remote) parent context.
    child_sc = trace_api.SpanContext(
        remote_sc.trace_id,
        id_generator.RandomIdGenerator().generate_span_id(),
        is_remote=False,
        trace_flags=remote_sc.trace_flags,
        trace_state=remote_sc.trace_state,
    )
    child = trace._Span(
        "child", child_sc, parent=parent.get_span_context()
    )

    new_carrier = {}
    FORMAT.inject(new_carrier, context=trace_api.set_span_in_context(child))
    return child, parent, new_carrier


class TestB3Format(unittest.TestCase):
    """Tests for extraction and injection by the B3 propagator."""

    # pylint: disable=too-many-public-methods

    @classmethod
    def setUpClass(cls):
        """Generate random ids shared by all tests, serialized in B3 form.

        One trace id, one span id and one parent span id are generated and
        formatted with ``b3_format.format_trace_id`` / ``format_span_id``.
        """
        generator = id_generator.RandomIdGenerator()
        cls.serialized_trace_id = b3_format.format_trace_id(
            generator.generate_trace_id()
        )
        cls.serialized_span_id = b3_format.format_span_id(
            generator.generate_span_id()
        )
        cls.serialized_parent_id = b3_format.format_span_id(
            generator.generate_span_id()
        )

    def setUp(self) -> None:
        """Patch ``trace_api.get_tracer_provider`` for the current test.

        Each test gets a fresh SDK ``TracerProvider``. The patch is undone
        via ``addCleanup``.
        """
        tracer_provider = trace.TracerProvider()
        patcher = unittest.mock.patch.object(
            trace_api, "get_tracer_provider", return_value=tracer_provider
        )
        patcher.start()
        self.addCleanup(patcher.stop)

    def test_extract_multi_header(self):
        """Test the extraction of B3 headers."""
        child, parent, new_carrier = get_child_parent_new_carrier(
            {
                FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
                FORMAT.SPAN_ID_KEY: self.serialized_span_id,
                FORMAT.PARENT_SPAN_ID_KEY: self.serialized_parent_id,
                FORMAT.SAMPLED_KEY: "1",
            }
        )

        self.assertEqual(
            new_carrier[FORMAT.TRACE_ID_KEY],
            b3_format.format_trace_id(child.context.trace_id),
        )
        self.assertEqual(
            new_carrier[FORMAT.SPAN_ID_KEY],
            b3_format.format_span_id(child.context.span_id),
        )
        self.assertEqual(
            new_carrier[FORMAT.PARENT_SPAN_ID_KEY],
            b3_format.format_span_id(parent.context.span_id),
        )
        self.assertTrue(parent.context.is_remote)
        self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1")

    def test_extract_single_header(self):
        """Test the extraction from a single b3 header."""
        child, parent, new_carrier = get_child_parent_new_carrier(
            {
                FORMAT.SINGLE_HEADER_KEY: "{}-{}".format(
                    self.serialized_trace_id, self.serialized_span_id
                )
            }
        )

        self.assertEqual(
            new_carrier[FORMAT.TRACE_ID_KEY],
            b3_format.format_trace_id(child.context.trace_id),
        )
        self.assertEqual(
            new_carrier[FORMAT.SPAN_ID_KEY],
            b3_format.format_span_id(child.context.span_id),
        )
        self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1")
        self.assertTrue(parent.context.is_remote)

        child, parent, new_carrier = get_child_parent_new_carrier(
            {
                FORMAT.SINGLE_HEADER_KEY: "{}-{}-1-{}".format(
                    self.serialized_trace_id,
                    self.serialized_span_id,
                    self.serialized_parent_id,
                )
            }
        )

        self.assertEqual(
            new_carrier[FORMAT.TRACE_ID_KEY],
            b3_format.format_trace_id(child.context.trace_id),
        )
        self.assertEqual(
            new_carrier[FORMAT.SPAN_ID_KEY],
            b3_format.format_span_id(child.context.span_id),
        )
        self.assertEqual(
            new_carrier[FORMAT.PARENT_SPAN_ID_KEY],
            b3_format.format_span_id(parent.context.span_id),
        )
        self.assertTrue(parent.context.is_remote)
        self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1")

    def test_extract_header_precedence(self):
        """A single b3 header should take precedence over multiple
        headers.
        """
        # Overwrite the last three hex digits with a suffix that is
        # guaranteed to differ from the original. A fixed "123" would make
        # this test vacuous whenever the random trace id already ends in it.
        suffix = "321" if self.serialized_trace_id.endswith("123") else "123"
        single_header_trace_id = self.serialized_trace_id[:-3] + suffix

        _, _, new_carrier = get_child_parent_new_carrier(
            {
                FORMAT.SINGLE_HEADER_KEY: "{}-{}".format(
                    single_header_trace_id, self.serialized_span_id
                ),
                FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
                FORMAT.SPAN_ID_KEY: self.serialized_span_id,
                FORMAT.SAMPLED_KEY: "1",
            }
        )

        self.assertEqual(
            new_carrier[FORMAT.TRACE_ID_KEY], single_header_trace_id
        )

    def test_enabled_sampling(self):
        """Test b3 sample key variants that turn on sampling."""
        for variant in ["1", "True", "true", "d"]:
            _, _, new_carrier = get_child_parent_new_carrier(
                {
                    FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
                    FORMAT.SPAN_ID_KEY: self.serialized_span_id,
                    FORMAT.SAMPLED_KEY: variant,
                }
            )

            self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1")

    def test_disabled_sampling(self):
        """Test b3 sample key variants that turn off sampling."""
        for variant in ["0", "False", "false", None]:
            _, _, new_carrier = get_child_parent_new_carrier(
                {
                    FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
                    FORMAT.SPAN_ID_KEY: self.serialized_span_id,
                    FORMAT.SAMPLED_KEY: variant,
                }
            )

            self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "0")

    def test_flags(self):
        """x-b3-flags set to "1" should result in propagation."""
        _, _, new_carrier = get_child_parent_new_carrier(
            {
                FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
                FORMAT.SPAN_ID_KEY: self.serialized_span_id,
                FORMAT.FLAGS_KEY: "1",
            }
        )

        self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1")

    def test_flags_and_sampling(self):
        """Propagate if b3 flags and sampling are set."""
        # Both the debug flag and the sampled header are present; without
        # SAMPLED_KEY this would just duplicate test_flags.
        _, _, new_carrier = get_child_parent_new_carrier(
            {
                FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
                FORMAT.SPAN_ID_KEY: self.serialized_span_id,
                FORMAT.FLAGS_KEY: "1",
                FORMAT.SAMPLED_KEY: "1",
            }
        )

        self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1")

    def test_derived_ctx_is_returned_for_success(self):
        """Ensure returned context is derived from the given context."""
        old_ctx = {"k1": "v1"}
        new_ctx = FORMAT.extract(
            {
                FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
                FORMAT.SPAN_ID_KEY: self.serialized_span_id,
                FORMAT.FLAGS_KEY: "1",
            },
            old_ctx,
        )
        self.assertIn("current-span", new_ctx)
        for key, value in old_ctx.items():
            self.assertIn(key, new_ctx)
            self.assertEqual(new_ctx[key], value)

    def test_derived_ctx_is_returned_for_failure(self):
        """Ensure returned context is derived from the given context."""
        old_ctx = {"k2": "v2"}
        new_ctx = FORMAT.extract({}, old_ctx)
        self.assertNotIn("current-span", new_ctx)
        for key, value in old_ctx.items():
            self.assertIn(key, new_ctx)
            self.assertEqual(new_ctx[key], value)

    def test_64bit_trace_id(self):
        """64 bit trace ids should be padded to 128 bit trace ids."""
        trace_id_64_bit = self.serialized_trace_id[:16]

        _, _, new_carrier = get_child_parent_new_carrier(
            {
                FORMAT.TRACE_ID_KEY: trace_id_64_bit,
                FORMAT.SPAN_ID_KEY: self.serialized_span_id,
                FORMAT.FLAGS_KEY: "1",
            }
        )

        self.assertEqual(
            new_carrier[FORMAT.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit
        )

    def test_extract_invalid_single_header(self):
        """Given unparsable header, do not modify context"""
        old_ctx = {}

        carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"}
        new_ctx = FORMAT.extract(carrier, old_ctx)

        self.assertDictEqual(new_ctx, old_ctx)

    def test_extract_missing_trace_id(self):
        """Given no trace ID, do not modify context"""
        old_ctx = {}

        carrier = {
            FORMAT.SPAN_ID_KEY: self.serialized_span_id,
            FORMAT.FLAGS_KEY: "1",
        }
        new_ctx = FORMAT.extract(carrier, old_ctx)

        self.assertDictEqual(new_ctx, old_ctx)

    def test_extract_invalid_trace_id(self):
        """Given invalid trace ID, do not modify context"""
        old_ctx = {}

        carrier = {
            FORMAT.TRACE_ID_KEY: "abc123",
            FORMAT.SPAN_ID_KEY: self.serialized_span_id,
            FORMAT.FLAGS_KEY: "1",
        }
        new_ctx = FORMAT.extract(carrier, old_ctx)

        self.assertDictEqual(new_ctx, old_ctx)

    def test_extract_invalid_span_id(self):
        """Given invalid span ID, do not modify context"""
        old_ctx = {}

        carrier = {
            FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
            FORMAT.SPAN_ID_KEY: "abc123",
            FORMAT.FLAGS_KEY: "1",
        }
        new_ctx = FORMAT.extract(carrier, old_ctx)

        self.assertDictEqual(new_ctx, old_ctx)

    def test_extract_missing_span_id(self):
        """Given no span ID, do not modify context"""
        old_ctx = {}

        carrier = {
            FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
            FORMAT.FLAGS_KEY: "1",
        }
        new_ctx = FORMAT.extract(carrier, old_ctx)

        self.assertDictEqual(new_ctx, old_ctx)

    def test_extract_empty_carrier(self):
        """Given no headers at all, do not modify context"""
        old_ctx = {}

        carrier = {}
        new_ctx = FORMAT.extract(carrier, old_ctx)

        self.assertDictEqual(new_ctx, old_ctx)

    @staticmethod
    def test_inject_empty_context():
        """If the current context has no span, don't add headers"""
        new_carrier = {}
        FORMAT.inject(new_carrier, get_current())
        assert len(new_carrier) == 0

    @staticmethod
    def test_default_span():
        """Make sure propagator does not crash when working with NonRecordingSpan"""

        class CarrierGetter(DefaultGetter):
            def get(self, carrier, key):
                return carrier.get(key, None)

        ctx = FORMAT.extract({}, getter=CarrierGetter())
        FORMAT.inject({}, context=ctx)

    def test_fields(self):
        """Make sure the fields attribute returns the fields used in inject"""

        tracer = trace.TracerProvider().get_tracer("sdk_tracer_provider")

        mock_setter = Mock()

        with tracer.start_as_current_span("parent"):
            with tracer.start_as_current_span("child"):
                FORMAT.inject({}, setter=mock_setter)

        inject_fields = set()

        # Each recorded call is (name, args, kwargs); args[1] is the key
        # passed to the setter.
        for call in mock_setter.mock_calls:
            inject_fields.add(call[1][1])

        self.assertEqual(FORMAT.fields, inject_fields)

    def test_extract_none_context(self):
        """Given a None context and no headers, extract still returns a
        context, holding the invalid span as the current span.
        """
        old_ctx = None

        carrier = {}
        new_ctx = FORMAT.extract(carrier, old_ctx)
        self.assertIsNotNone(new_ctx)
        self.assertEqual(new_ctx["current-span"], trace_api.INVALID_SPAN)
