"""
This module provides a set of functions to manage the global tracer provider for MLflow tracing.

Every tracing operation in MLflow *MUST* be managed through this module, instead of directly
using the OpenTelemetry APIs. This is because MLflow needs to control the initialization of the
tracer provider and ensure that it won't interfere with the other external libraries that might
use OpenTelemetry e.g. PromptFlow, Snowpark.
"""

import contextvars
import functools
import json
import logging
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING

from opentelemetry import context as context_api
from opentelemetry import trace
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import SpanProcessor, TracerProvider
from opentelemetry.sdk.trace.sampling import TraceIdRatioBased

import mlflow
from mlflow.entities.trace_location import (
    MlflowExperimentLocation,
    TraceLocationBase,
    UCSchemaLocation,
)
from mlflow.environment_variables import (
    MLFLOW_TRACE_ENABLE_OTLP_DUAL_EXPORT,
    MLFLOW_TRACE_SAMPLING_RATIO,
    MLFLOW_USE_DEFAULT_TRACER_PROVIDER,
)
from mlflow.exceptions import MlflowException, MlflowTracingException
from mlflow.tracing.config import reset_config
from mlflow.tracing.constant import SpanAttributeKey
from mlflow.tracing.destination import TraceDestination, UserTraceDestinationRegistry
from mlflow.tracing.utils.exception import raise_as_trace_exception
from mlflow.tracing.utils.once import Once
from mlflow.tracing.utils.otlp import (
    get_otlp_exporter,
    should_export_otlp_metrics,
    should_use_otlp_exporter,
)
from mlflow.tracing.utils.warning import suppress_warning
from mlflow.utils.databricks_utils import (
    is_in_databricks_model_serving_environment,
    is_mlflow_tracing_enabled_in_model_serving,
)

if TYPE_CHECKING:
    from mlflow.entities import Span


# A trace destination specified by the user via the `set_destination` function.
# It is consulted when span processors are built and cleared again by `reset()`.
_MLFLOW_TRACE_USER_DESTINATION = UserTraceDestinationRegistry()


# Module-level logger used for the info and warning messages emitted by this module.
_logger = logging.getLogger(__name__)


class _TracerProviderWrapper:
    """
    Facade that hides which tracer provider MLflow is currently using.

    The choice is re-evaluated on every call from the ``MLFLOW_USE_DEFAULT_TRACER_PROVIDER``
    environment variable:

    - When it is truthy, an isolated tracer provider stored on this object is used, together
      with a private ``Once`` guard. This keeps MLflow from interfering with an environment
      where MLflow and the OpenTelemetry SDK are used for different purposes.
    - Otherwise, the global OpenTelemetry tracer provider singleton (and its own "set once"
      guard) is used, so traces created by MLflow and by the OpenTelemetry SDK are exported
      to the same destination.
    """

    def __init__(self):
        # Both fields are only used in the isolated mode.
        self._isolated_tracer_provider = None
        self._isolated_tracer_provider_once = Once()

    @property
    def once(self) -> Once:
        """The ``Once`` guard that belongs to the currently selected provider mode."""
        return (
            self._isolated_tracer_provider_once
            if MLFLOW_USE_DEFAULT_TRACER_PROVIDER.get()
            else trace._TRACER_PROVIDER_SET_ONCE
        )

    def get(self) -> TracerProvider:
        """Return the tracer provider for the currently selected mode."""
        if not MLFLOW_USE_DEFAULT_TRACER_PROVIDER.get():
            return trace.get_tracer_provider()
        return self._isolated_tracer_provider

    def set(self, tracer_provider: TracerProvider):
        """Replace the tracer provider for the currently selected mode."""
        if not MLFLOW_USE_DEFAULT_TRACER_PROVIDER.get():
            # Write the module global directly instead of going through the OpenTelemetry
            # setter: the setter honours the once flag and would ignore the update. The once
            # flag is checked in `get_or_init_tracer`; everywhere else the provider must be
            # forcibly replaced.
            trace._TRACER_PROVIDER = tracer_provider
            return
        self._isolated_tracer_provider = tracer_provider

    def get_or_init_tracer(self, module_name: str) -> trace.Tracer:
        """Initialize the provider if the once guard has not fired yet, then return a tracer."""
        self.once.do_once(_initialize_tracer_provider)
        active_provider = self.get()
        return active_provider.get_tracer(module_name)

    def reset(self):
        """Drop the provider of the selected mode and re-arm its once guard."""
        if not MLFLOW_USE_DEFAULT_TRACER_PROVIDER.get():
            trace._TRACER_PROVIDER = None
            trace._TRACER_PROVIDER_SET_ONCE._done = False
            return
        self._isolated_tracer_provider = None
        self._isolated_tracer_provider_once._done = False

# Module-wide facade instance. The functions in this module go through it to read, replace
# and reset the tracer provider.
provider = _TracerProviderWrapper()


def start_span_in_context(name: str, experiment_id: str | None = None) -> trace.Span:
    """
    Start a new OpenTelemetry span in the current context.

    Note that this function doesn't set the started span as the active span in the context. To do
    that, the upstream also need to call `use_span()` function in the OpenTelemetry trace APIs.

    Args:
        name: The name of the span.
        experiment_id: The ID of the experiment to log the span to. If not specified, the span will
            be logged to the active experiment or explicitly set trace destination. The value is
            only honoured for root spans; for a non-root span a warning is logged and the value
            is dropped again.

    Returns:
        The newly created OpenTelemetry span.
    """
    # The experiment ID travels to the span processor as a JSON-encoded span attribute.
    attributes = {}
    if experiment_id:
        attributes[SpanAttributeKey.EXPERIMENT_ID] = json.dumps(experiment_id)
    span = _get_tracer(__name__).start_span(name, attributes=attributes)

    # A truthy `_parent` is treated as "this span has a parent", i.e. it is not a root span.
    if experiment_id and getattr(span, "_parent", None):
        _logger.warning(
            "The `experiment_id` parameter can only be used for root spans, but the span "
            f"`{name}` is not a root span. The specified value `{experiment_id}` will be ignored."
        )
        # `span` is the raw OpenTelemetry span returned by the tracer, not an MLflow wrapper, so
        # there is no `span._span` to go through. Remove the key from the span's own attribute
        # store instead.
        # NOTE(review): relies on the OpenTelemetry SDK span keeping its mutable attributes in
        # `_attributes` (the public `attributes` view is read-only) - confirm against the
        # supported SDK versions.
        span_attributes = getattr(span, "_attributes", None)
        if span_attributes is not None:
            span_attributes.pop(SpanAttributeKey.EXPERIMENT_ID, None)
    return span


def start_detached_span(
    name: str,
    parent: trace.Span | None = None,
    experiment_id: str | None = None,
    start_time_ns: int | None = None,
) -> trace.Span:
    """
    Start a new OpenTelemetry span that is not part of the current trace context, but with the
    explicit parent span ID if provided.

    Args:
        name: The name of the span.
        parent: The parent OpenTelemetry span. If not provided, no explicit parent context is
            passed to the tracer.
        experiment_id: The ID of the experiment. This is used to associate the span with a specific
            experiment in MLflow. The value is only honoured for root spans; for a non-root span
            a warning is logged and the value is dropped again.
        start_time_ns: The start time of the span in nanoseconds.
            If not provided, the current timestamp is used.

    Returns:
        The newly created OpenTelemetry span.
    """
    tracer = _get_tracer(__name__)
    context = trace.set_span_in_context(parent) if parent else None
    attributes = {}

    # Set start time and experiment to attribute so we can pass it to the span processor.
    # Both values are JSON-encoded; falsy values (None, 0, "") are not recorded.
    if start_time_ns:
        attributes[SpanAttributeKey.START_TIME_NS] = json.dumps(start_time_ns)
    if experiment_id:
        attributes[SpanAttributeKey.EXPERIMENT_ID] = json.dumps(experiment_id)
    span = tracer.start_span(name, context=context, attributes=attributes, start_time=start_time_ns)

    # A truthy `_parent` is treated as "this span has a parent", i.e. it is not a root span.
    if experiment_id and getattr(span, "_parent", None):
        _logger.warning(
            "The `experiment_id` parameter can only be used for root spans, but the span "
            f"`{name}` is not a root span. The specified value `{experiment_id}` will be ignored."
        )
        # `span` is the raw OpenTelemetry span returned by the tracer, not an MLflow wrapper, so
        # there is no `span._span` to go through. Remove the key from the span's own attribute
        # store instead.
        # NOTE(review): relies on the OpenTelemetry SDK span keeping its mutable attributes in
        # `_attributes` (the public `attributes` view is read-only) - confirm against the
        # supported SDK versions.
        span_attributes = getattr(span, "_attributes", None)
        if span_attributes is not None:
            span_attributes.pop(SpanAttributeKey.EXPERIMENT_ID, None)
    return span


@contextmanager
def safe_set_span_in_context(span: "Span"):
    """
    Context manager that makes the given span the active span for the duration of the
    ``with`` block and detaches it again on exit, even if the block raises.

    Args:
        span: An MLflow span object to set as the active span.

    Example:

    .. code-block:: python

        with safe_set_span_in_context(span):
            # `span` is the active span inside this block.
            ...

        # The span has been detached from the context again at this point.
    """
    attach_token = set_span_in_context(span)
    try:
        yield
    finally:
        # Always hand the token back so the context is restored on every exit path.
        detach_span_from_context(attach_token)


def set_span_in_context(span: "Span") -> contextvars.Token:
    """
    Make the OpenTelemetry span wrapped by the given MLflow span the active span in the
    current context.

    Args:
        span: An MLflow span object to set as the active span.

    Returns:
        The token produced by attaching the new context. Pass it to
        `detach_span_from_context` to undo the change.
    """
    # Build a context that carries the underlying OpenTelemetry span, then attach it.
    otel_context = trace.set_span_in_context(span._span)
    return context_api.attach(otel_context)


def detach_span_from_context(token: contextvars.Token):
    """
    Remove the active span from the current context.

    Args:
        token: The token returned by the `set_span_in_context` function when the span was
            attached.
    """
    context_api.detach(token)


def set_destination(destination: TraceLocationBase, *, context_local: bool = False):
    """
    Set a custom span location to which MLflow will export the traces.

    A destination specified by this function will take precedence over
    other configurations, such as tracking URI, OTLP environment variables.

    Calling this function also rebuilds the tracer provider so that the new destination is
    picked up.

    Args:
        destination: A trace location object that specifies the location of the trace data.
            Currently, the following locations are supported:

            - :py:class:`~mlflow.entities.trace_location.MlflowExperimentLocation`: Logs traces to
                an MLflow experiment.
            - :py:class:`~mlflow.entities.trace_location.UCSchemaLocation`: Logs traces to a
                Databricks Unity Catalog schema. Only available in Databricks.

            Legacy ``TraceDestination`` objects are also accepted and are converted to the
            corresponding location object.

        context_local: If False (default), the destination is set globally. If True, the destination
            is isolated per async task or thread, providing isolation in concurrent applications.

    Raises:
        MlflowException: If ``destination`` is not a trace location object.

    Example:

        **Logging traces to MLflow Experiment:**

        .. code-block:: python

            from mlflow.entities.trace_location import MlflowExperimentLocation

            mlflow.tracing.set_destination(MlflowExperimentLocation(experiment_id="123"))

        Note: This has the same effect as setting the active MLflow experiment via the
        ``MLFLOW_EXPERIMENT_ID`` environment variable or the ``mlflow.set_experiment`` API,
        but with narrower scope.

        **Logging traces to Databricks Unity Catalog:**

        .. code-block:: python

            from mlflow.entities.trace_location import UCSchemaLocation

            mlflow.tracing.set_destination(
                UCSchemaLocation(catalog_name="catalog", schema_name="schema")
            )

        **Isolate the destination between async tasks or threads:**

        .. code-block:: python

            from mlflow.entities.trace_location import MlflowExperimentLocation

            mlflow.tracing.set_destination(
                MlflowExperimentLocation(experiment_id="123"),
                context_local=True,
            )

        The destination set with the ``context_local`` flag will only be effective within the
        current async task or thread. This is particularly useful when you want to send traces
        to different destinations from a multi-tenant application.

        **Reset the destination:**

        .. code-block:: python

            mlflow.tracing.reset()

    """
    if isinstance(destination, TraceDestination):
        # NB: Deprecation warnings are issued in the constructor of the destination classes
        # so we don't need to issue a warning here.
        destination = destination.to_location()

    if not isinstance(destination, TraceLocationBase):
        raise MlflowException.invalid_parameter_value(
            f"Invalid destination type: {type(destination)}. "
            "The destination must be an instance of TraceLocation."
        )

    if isinstance(destination, UCSchemaLocation):
        # Unity Catalog destinations need a Databricks tracking URI; switch to it unless one
        # is already configured.
        tracking_uri = mlflow.get_tracking_uri()
        if tracking_uri is None or not tracking_uri.startswith("databricks"):
            mlflow.set_tracking_uri("databricks")
            _logger.info(
                "Automatically setting the tracking URI to `databricks` "
                "because the tracing destination is set to Databricks."
            )

    _MLFLOW_TRACE_USER_DESTINATION.set(destination, context_local=context_local)
    # Rebuild the provider so its span processors reflect the new destination.
    _initialize_tracer_provider()


def _get_tracer(module_name: str) -> trace.Tracer:
    """
    Get a tracer instance for the given module name.

    If the tracer provider is not initialized, this function will initialize the tracer provider.
    Other simultaneous calls to this function will block until the initialization is done
    (NOTE(review): this depends on the behavior of the `Once` helper - confirm).

    Args:
        module_name: The name handed to the provider's `get_tracer`; callers in this module
            pass `__name__`.

    Returns:
        The tracer obtained through the module-level `provider` facade.
    """
    return provider.get_or_init_tracer(module_name)


def _get_trace_exporter():
    """
    Get the exporter instance that is used by the current tracer provider.

    Returns:
        The ``span_exporter`` of the first span processor registered on the current tracer
        provider, or None when there is no provider or the provider carries no span
        processors (for example the no-op provider installed while tracing is disabled).
    """
    tracer_provider = provider.get()
    if not tracer_provider:
        return None

    # Not every provider exposes span processors: `disable()` and `reset()` install a
    # `NoOpTracerProvider`, which has none. Look the internals up defensively instead of
    # raising AttributeError / IndexError.
    active_processor = getattr(tracer_provider, "_active_span_processor", None)
    processors = getattr(active_processor, "_span_processors", None)
    if not processors:
        return None

    # The first registered processor is used. `_get_span_processors` appends the OTLP
    # processor first when OTLP export is enabled, so with dual export that one is returned.
    return processors[0].span_exporter


def _initialize_tracer_provider(disabled=False):
    """
    Build a tracer provider from the current configuration and install it via `provider.set`.

    This function ALWAYS replaces the current tracer provider, regardless of the current
    state. Callers are responsible for making sure initialization happens only once and for
    updating the provider's "once" flag accordingly.

    Args:
        disabled: If True, no span processors are created and a no-op tracer provider is
            installed instead.
    """
    span_processors = _get_span_processors(disabled=disabled)
    if not span_processors:
        # Nothing to export to (or tracing is disabled): install the no-op provider.
        provider.set(trace.NoOpTracerProvider())
        return

    # Silence a few OpenTelemetry log messages permanently:
    # - "Failed to detach context" may indicate incorrect context handling, but in many cases
    #   it is a false positive that does not affect the generated trace. It has to be
    #   suppressed permanently rather than for the scope of one call, because for streaming
    #   it can be emitted whenever the iterator is consumed.
    # - The "ended span" messages are occasionally raised in valid cases, e.g. a trace timed
    #   out while some spans were still active, and are not actionable.
    for logger_name, message in (
        ("opentelemetry.context", "Failed to detach context"),
        ("opentelemetry.sdk.trace", "Setting attribute on ended span"),
        ("opentelemetry.sdk.trace", "Calling end() on an ended span"),
    ):
        suppress_warning(logger_name, message)

    if os.getenv("OTEL_SERVICE_NAME") or os.getenv("OTEL_RESOURCE_ATTRIBUTES"):
        # NB: The OTel resource env vars are set explicitly; pass no resource so that they
        # are propagated to the OTel spans.
        resource = None
    else:
        # Use a minimal resource to avoid triggering resource aggregation, which causes an
        # issue in LiteLLM tracing: https://github.com/mlflow/mlflow/issues/16296
        # Telemetry attributes: https://opentelemetry.io/docs/specs/semconv/resource/#telemetry-sdk
        resource = Resource(
            {
                "telemetry.sdk.language": "python",
                "telemetry.sdk.name": "mlflow",
                "telemetry.sdk.version": mlflow.__version__,
            }
        )

    new_provider = TracerProvider(resource=resource, sampler=_get_trace_sampler())
    for span_processor in span_processors:
        new_provider.add_span_processor(span_processor)

    provider.set(new_provider)


def _get_trace_sampler() -> TraceIdRatioBased | None:
    """
    Build the trace sampler from the ``MLFLOW_TRACE_SAMPLING_RATIO`` setting.

    Returns:
        A ``TraceIdRatioBased`` sampler when a ratio within [0.0, 1.0] is configured. None
        when no ratio is configured, or when the configured ratio is out of range (a warning
        is logged in that case), so that default sampling applies.
    """
    sampling_ratio = MLFLOW_TRACE_SAMPLING_RATIO.get()
    if sampling_ratio is None:
        return None

    if 0.0 <= sampling_ratio <= 1.0:
        return TraceIdRatioBased(sampling_ratio)

    _logger.warning(
        f"{MLFLOW_TRACE_SAMPLING_RATIO} must be between 0.0 and 1.0, got {sampling_ratio}. "
        "Ignoring the invalid value and using default sampling (1.0)."
    )
    return None


def _get_span_processors(disabled: bool = False) -> list[SpanProcessor]:
    """
    Build the span processors that match the current tracing configuration.

    The selection works as follows:

    1. If OTLP export is configured, an OTLP processor is added. Unless dual export is
       enabled, it is the only processor.
    2. Otherwise (or in addition, with dual export), one processor is chosen from, in order:
       the destination set via `set_destination`, the inference table when running in
       Databricks model serving with tracing enabled, or the MLflow tracking backend.

    Args:
        disabled: If True, returns an empty list of processors because tracing is disabled.

    Returns:
        List of span processors to be added to the TracerProvider.
    """
    if disabled:
        return []

    span_processors = []

    if should_use_otlp_exporter():
        from mlflow.tracing.processor.otel import OtelSpanProcessor

        otlp_exporter = get_otlp_exporter()
        # Only export metrics from the Otel processor if dual export is not enabled. Otherwise,
        # both Otel and MLflow processors will export metrics, causing duplication
        export_otlp_metrics = (
            should_export_otlp_metrics() and not MLFLOW_TRACE_ENABLE_OTLP_DUAL_EXPORT.get()
        )
        span_processors.append(
            OtelSpanProcessor(span_exporter=otlp_exporter, export_metrics=export_otlp_metrics)
        )

        if not MLFLOW_TRACE_ENABLE_OTLP_DUAL_EXPORT.get():
            return span_processors

    # TODO: Update this logic to pluggable registry where
    #  1. Partners can implement span processor/exporter and destination class.
    #  2. They can register their implementation to the registry via entry points.
    #  3. MLflow will pick the implementation based on given destination id.
    user_destination = _MLFLOW_TRACE_USER_DESTINATION.get()
    if user_destination:
        # in PrPr, users must set the destination to UCSchemaLocation to export traces to UC
        if isinstance(user_destination, UCSchemaLocation):
            from mlflow.tracing.export.uc_table import DatabricksUCTableSpanExporter
            from mlflow.tracing.processor.uc_table import DatabricksUCTableSpanProcessor

            uc_exporter = DatabricksUCTableSpanExporter(tracking_uri=mlflow.get_tracking_uri())
            span_processors.append(DatabricksUCTableSpanProcessor(span_exporter=uc_exporter))
        elif isinstance(user_destination, MlflowExperimentLocation):
            if is_in_databricks_model_serving_environment():
                _logger.info(
                    "Traces will be sent to the destination set by `mlflow.tracing.set_destination`"
                    " API. To enable saving traces to both MLflow experiment and inference table, "
                    "remove this API call from your model and set `MLFLOW_EXPERIMENT_ID` env var "
                    "instead."
                )
            span_processors.append(
                _get_mlflow_span_processor(tracking_uri=mlflow.get_tracking_uri())
            )
        # Any other destination type adds no processor.
        return span_processors

    if is_in_databricks_model_serving_environment():
        if is_mlflow_tracing_enabled_in_model_serving():
            from mlflow.tracing.export.inference_table import InferenceTableSpanExporter
            from mlflow.tracing.processor.inference_table import InferenceTableSpanProcessor

            inference_exporter = InferenceTableSpanExporter()
            span_processors.append(InferenceTableSpanProcessor(inference_exporter))
        return span_processors

    # Default: export to the MLflow tracking backend.
    span_processors.append(_get_mlflow_span_processor(tracking_uri=mlflow.get_tracking_uri()))
    return span_processors


def _get_mlflow_span_processor(tracking_uri: str):
    """
    Create a new MLflow V3 span processor that exports to the given tracking URI.

    Args:
        tracking_uri: The tracking URI handed to the span exporter.

    Returns:
        A ``MlflowV3SpanProcessor`` wrapping a freshly created ``MlflowV3SpanExporter``.
    """
    # Databricks and SQL backends support V3 traces.
    # Imported lazily, like the other processor/exporter imports in this module.
    from mlflow.tracing.export.mlflow_v3 import MlflowV3SpanExporter
    from mlflow.tracing.processor.mlflow_v3 import MlflowV3SpanProcessor

    return MlflowV3SpanProcessor(
        span_exporter=MlflowV3SpanExporter(tracking_uri=tracking_uri),
        export_metrics=should_export_otlp_metrics(),
    )


@raise_as_trace_exception
def disable():
    """
    Disable tracing. This is a no-op when tracing is already disabled.

    .. note::

        This function sets up `OpenTelemetry` to use
        `NoOpTracerProvider <https://github.com/open-telemetry/opentelemetry-python/blob/4febd337b019ea013ccaab74893bd9883eb59000/opentelemetry-api/src/opentelemetry/trace/__init__.py#L222>`_
        and effectively disables all tracing operations.

    Example:

    .. code-block:: python
        :test:

        import mlflow


        @mlflow.trace
        def f():
            return 0


        # Tracing is enabled by default
        f()
        assert len(mlflow.search_traces()) == 1

        # Disable tracing
        mlflow.tracing.disable()
        f()
        assert len(mlflow.search_traces()) == 1

    """
    if is_tracing_enabled():
        _initialize_tracer_provider(disabled=True)
        # Mark initialization as done so that the next tracing operation does not
        # re-initialize the provider and silently turn tracing back on.
        provider.once._done = True


@raise_as_trace_exception
def enable():
    """
    Enable tracing. If tracing is already enabled and the tracer provider has been
    initialized, only an info message is logged.

    Example:

    .. code-block:: python
        :test:

        import mlflow


        @mlflow.trace
        def f():
            return 0


        # Tracing is enabled by default
        f()
        assert len(mlflow.search_traces()) == 1

        # Disable tracing
        mlflow.tracing.disable()
        f()
        assert len(mlflow.search_traces()) == 1

        # Re-enable tracing
        mlflow.tracing.enable()
        f()
        assert len(mlflow.search_traces()) == 2

    """
    already_enabled = is_tracing_enabled() and provider.once._done
    if not already_enabled:
        _initialize_tracer_provider()
        # The provider has just been built explicitly; mark initialization as done.
        provider.once._done = True
        return

    _logger.info("Tracing is already enabled")


def trace_disabled(f):
    """
    A decorator that temporarily disables tracing for the duration of the decorated function.

    If tracing is enabled when the wrapper is called, it is disabled before ``f`` runs and
    re-enabled afterwards, whether ``f`` returns or raises. Errors raised while disabling or
    re-enabling tracing are logged as warnings and never prevent ``f`` from running. Exceptions
    raised by ``f`` itself always propagate to the caller.

    .. code-block:: python

        @trace_disabled
        def f():
            with mlflow.start_span("my_span") as span:
                span.set_attribute("my_key", "my_value")

            return


        # This function will not generate any trace
        f()

    :meta private:
    """

    def _warn_toggle_failure(error):
        # Called from inside an `except` block so that a truthy `exc_info` attaches the
        # exception currently being handled.
        _logger.warning(
            f"An error occurred while disabling or re-enabling tracing: {error} "
            "The original function will still be executed, but the tracing "
            "state may not be as expected. For full traceback, set "
            "logging level to debug.",
            exc_info=_logger.isEnabledFor(logging.DEBUG),
        )

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Only restore tracing if this call is the one that switched it off.
        restore_tracing = False
        try:
            if is_tracing_enabled():
                disable()
                restore_tracing = True
        except MlflowTracingException as e:
            _warn_toggle_failure(e)

        # The call to `f` is deliberately outside the handlers above and below: we should only
        # catch the exception from disable() and enable() and let other exceptions propagate,
        # including an MlflowTracingException raised by `f` itself.
        try:
            return f(*args, **kwargs)
        finally:
            if restore_tracing:
                try:
                    enable()
                except MlflowTracingException as e:
                    # Log only, so that a failure here cannot replace an exception from `f`.
                    _warn_toggle_failure(e)

    return wrapper


def reset():
    """
    Reset MLflow tracing to its initial, un-initialized state.

    Besides clearing the flag that indicates whether the tracer provider has been initialized
    (so that it is re-initialized when the next tracing operation is performed), this also
    clears the destination set via `set_destination` and resets the tracing configuration.
    """
    # Set NoOp tracer provider to reset the global tracer to the initial state.
    _initialize_tracer_provider(disabled=True)
    # Drop the stored provider and flip the "once" flag to False so that
    # the next tracing operation will re-initialize the provider.
    provider.reset()

    # Reset the custom destination set by the user
    _MLFLOW_TRACE_USER_DESTINATION.reset()

    # Reset the tracing configuration to defaults
    reset_config()


@raise_as_trace_exception
def is_tracing_enabled() -> bool:
    """
    Report whether tracing is currently enabled, judged by the state of the tracer.

    Tracing counts as enabled when either of the following holds:

    1. The tracer provider has not been initialized yet (the default state before any
       tracing operation).
    2. The current tracer, after unwrapping a ``ProxyTracer``, is not a ``NoOpTracer``.
    """
    if provider.once._done:
        tracer = _get_tracer(__name__)
        # Occasionally a ProxyTracer instance wraps the actual tracer; look through it.
        underlying = tracer._tracer if isinstance(tracer, trace.ProxyTracer) else tracer
        return not isinstance(underlying, trace.NoOpTracer)

    return True
