import importlib
from packaging.version import Version
import mlflow
from mlflow.dspy.save import FLAVOR_NAME
from mlflow.models.model import _MODEL_TRACKER
from mlflow.tracing.provider import trace_disabled
from mlflow.tracing.utils import construct_full_inputs
from mlflow.utils.annotations import experimental
from mlflow.utils.autologging_utils import (
    autologging_integration,
    get_autologging_config,
    safe_patch,
)
@experimental
def autolog(
    log_traces: bool = True,
    log_traces_from_compile: bool = False,
    log_traces_from_eval: bool = True,
    log_compiles: bool = False,
    log_evals: bool = False,
    disable: bool = False,
    silent: bool = False,
    log_models: bool = True,
):
    """
    Enables (or disables) and configures autologging from DSPy to MLflow.

    Autologging registers an ``MlflowCallback`` in ``dspy.settings.callbacks`` and patches
    ``dspy.Module.__call__``, ``Teleprompter.compile`` (on each direct ``Teleprompter``
    subclass) and ``Evaluate.__call__``.

    Args:
        log_traces: If ``True``, traces are logged for DSPy models. If ``False``,
            no traces are collected during inference. Default to ``True``.
        log_traces_from_compile: If ``True``, traces are logged when compiling (optimizing)
            DSPy programs. If ``False``, traces are only logged from normal model inference and
            disabled when compiling. Default to ``False``.
        log_traces_from_eval: If ``True``, traces are logged for DSPy models when running DSPy's
            `built-in evaluator <https://dspy.ai/learn/evaluation/metrics/#evaluation>`_.
            If ``False``, traces are only logged from normal model inference and disabled when
            running the evaluator. Default to ``True``.
        log_compiles: If ``True``, information about the optimization process is logged when
            `Teleprompter.compile()` is called: the state of the compiled program
            (``best_model.json``), the optimizer parameters merged with the arguments of the
            compile call, and the ``trainset`` / ``valset`` datasets when they are passed.
            Default to ``False``.
        log_evals: If ``True``, information about the evaluation call is logged when
            `Evaluate.__call__()` is called. Default to ``False``.
        disable: If ``True``, disables the DSPy autologging integration. If ``False``,
            enables the DSPy autologging integration.
        silent: If ``True``, suppress all event logs and warnings from MLflow during DSPy
            autologging. If ``False``, show all events and warnings.
        log_models: If ``True``, automatically create a LoggedModel when the model
            used for inference is not already logged. The created LoggedModel contains no model
            artifacts, but it will be used to associate all traces generated by the model. If
            ``False``, no LoggedModel is created and the traces will not be associated with any
            model. Default to ``True``.

            .. Note:: Experimental: This argument may change or be removed in a future release
                without warning.
    """
    # NB: The @autologging_integration annotation is used for adding shared logic. However, one
    # caveat is that the wrapped function is NOT executed when disable=True is passed. This prevents
    # us from running cleaning up logging when autologging is turned off. To workaround this, we
    # annotate _autolog() instead of this entrypoint, and define the cleanup logic outside it.
    # This needs to be called before doing any safe-patching (otherwise safe-patch will be no-op).
    # TODO: since this implementation is inconsistent, explore a universal way to solve the issue.
    _autolog(
        log_traces=log_traces,
        log_traces_from_compile=log_traces_from_compile,
        log_traces_from_eval=log_traces_from_eval,
        log_compiles=log_compiles,
        log_evals=log_evals,
        disable=disable,
        silent=silent,
        log_models=log_models,
    )

    # NB: A bare ``import importlib`` does not guarantee that the ``metadata`` submodule is
    # loaded, so import it explicitly before ``importlib.metadata.version`` is used below.
    import importlib.metadata

    import dspy

    from mlflow.dspy.callback import MlflowCallback
    from mlflow.dspy.util import log_dspy_dataset, save_dspy_module_state

    # Enable tracing by registering the MlflowCallback (at most once); when disabling, strip
    # every MlflowCallback from the DSPy settings instead.
    if not disable:
        if not any(isinstance(c, MlflowCallback) for c in dspy.settings.callbacks):
            dspy.settings.configure(callbacks=[*dspy.settings.callbacks, MlflowCallback()])
    else:
        dspy.settings.configure(
            callbacks=[c for c in dspy.settings.callbacks if not isinstance(c, MlflowCallback)]
        )

    def patch_module(original, self, *args, **kwargs):
        """Set the active model ID for the duration of a ``Module.__call__`` invocation."""
        if model_id := _MODEL_TRACKER.get(id(self)):
            # This module instance is already tracked: reuse its model ID.
            _MODEL_TRACKER.set_active_model_id(model_id)
        elif not _MODEL_TRACKER._is_active_model_id_set and log_models:
            # Untracked outermost module: create a LoggedModel and track this instance.
            logged_model = mlflow.create_external_model(name=self.__class__.__name__)
            _MODEL_TRACKER.set(id(self), logged_model.model_id)
            _MODEL_TRACKER.set_active_model_id(logged_model.model_id)
        else:
            _MODEL_TRACKER.set_active_model_id(None)
            return original(self, *args, **kwargs)

        # This is needed because we should not create LoggedModel for internal objects
        _MODEL_TRACKER._is_active_model_id_set = True
        try:
            return original(self, *args, **kwargs)
        finally:
            _MODEL_TRACKER._is_active_model_id_set = False

    # Patch teleprompter/evaluator not to generate traces by default
    def patch_fn(original, self, *args, **kwargs):
        """Wrap ``Teleprompter.compile`` / ``Evaluate.__call__`` with the autologging logic."""
        # NB: Since calling mlflow.dspy.autolog() again does not unpatch a function, the
        # autologging flags are read at call time to decide whether to generate traces.
        # ``Teleprompter`` and ``Evaluate`` are imported further down in ``autolog``; they are
        # bound by the time this patch function can run.

        @trace_disabled
        def _trace_disabled_fn(self, *args, **kwargs):
            return original(self, *args, **kwargs)

        def _compile_fn(self, *args, **kwargs):
            # Track the nesting depth of compile calls on the active callback, if any.
            if callback := _active_callback():
                callback.optimizer_stack_level += 1
            try:
                if get_autologging_config(FLAVOR_NAME, "log_traces_from_compile"):
                    result = original(self, *args, **kwargs)
                else:
                    result = _trace_disabled_fn(self, *args, **kwargs)
                return result
            finally:
                if callback:
                    callback.optimizer_stack_level -= 1
                    if callback.optimizer_stack_level == 0:
                        # Reset the callback state after the completion of root compile
                        callback.reset()

        if isinstance(self, Teleprompter):
            if not get_autologging_config(FLAVOR_NAME, "log_compiles"):
                return _compile_fn(self, *args, **kwargs)

            program = _compile_fn(self, *args, **kwargs)
            # Save the state of the best model in json format
            # so that users can see the demonstrations and instructions.
            save_dspy_module_state(program, "best_model.json")

            # Teleprompter.get_params is introduced in dspy 2.6.15.
            # Copy the returned mapping so that merging the compile arguments below never
            # mutates a dict owned by the optimizer.
            params = (
                dict(self.get_params())
                if Version(importlib.metadata.version("dspy")) >= Version("2.6.15")
                else {}
            )
            # Construct the dict of arguments passed to the compile call
            inputs = construct_full_inputs(original, self, *args, **kwargs)
            # Update params with the arguments passed to the compile call
            params.update(inputs)
            # Log the merged optimizer params and compile arguments, keeping only the
            # values of primitive types.
            mlflow.log_params(
                {k: v for k, v in params.items() if isinstance(v, (int, float, str, bool))}
            )

            if trainset := inputs.get("trainset"):
                log_dspy_dataset(trainset, "trainset.json")
            if valset := inputs.get("valset"):
                log_dspy_dataset(valset, "valset.json")
            return program

        if isinstance(self, Evaluate) and get_autologging_config(
            FLAVOR_NAME, "log_traces_from_eval"
        ):
            return original(self, *args, **kwargs)

        return _trace_disabled_fn(self, *args, **kwargs)

    from dspy import Module
    from dspy.evaluate import Evaluate
    from dspy.teleprompt import Teleprompter

    safe_patch(
        FLAVOR_NAME,
        Module,
        "__call__",
        patch_module,
    )

    compile_patch = "compile"
    # NB: The patch is applied to each direct subclass of Teleprompter rather than to the
    # Teleprompter base class itself (``__subclasses__()`` returns immediate subclasses only).
    for cls in Teleprompter.__subclasses__():
        if hasattr(cls, compile_patch):
            safe_patch(
                FLAVOR_NAME,
                cls,
                compile_patch,
                patch_fn,
                manage_run=get_autologging_config(FLAVOR_NAME, "log_compiles"),
            )

    call_patch = "__call__"
    if hasattr(Evaluate, call_patch):
        safe_patch(
            FLAVOR_NAME,
            Evaluate,
            call_patch,
            patch_fn,
        )
# This is required by mlflow.autolog().
# The attribute is set by hand because the public ``autolog`` entry point is not itself
# wrapped with @autologging_integration (only ``_autolog`` is) -- presumably the decorator
# would otherwise attach it; verify against mlflow.utils.autologging_utils.
autolog.integration_name = FLAVOR_NAME
@autologging_integration(FLAVOR_NAME)
def _autolog(
    log_traces: bool,
    log_traces_from_compile: bool,
    log_traces_from_eval: bool,
    log_compiles: bool,
    log_evals: bool,
    disable: bool = False,
    silent: bool = False,
    log_models: bool = True,
):
    """
    Intentionally empty carrier for the ``@autologging_integration`` decorator.

    The public ``autolog`` entry point forwards all of its arguments here so that the shared
    logic added by the decorator runs for the DSPy flavor, while the callback setup and
    patching are performed by ``autolog`` itself. The function body does nothing.
    """
def _active_callback():
    """
    Return the first ``MlflowCallback`` found in ``dspy.settings.callbacks``.

    Returns:
        The first registered ``MlflowCallback`` instance, or ``None`` when no such callback
        is present.
    """
    import dspy

    from mlflow.dspy.callback import MlflowCallback

    mlflow_callbacks = (cb for cb in dspy.settings.callbacks if isinstance(cb, MlflowCallback))
    return next(mlflow_callbacks, None)