# Source code for mlflow.genai.git_versioning

import logging

from typing_extensions import Self

import mlflow
from mlflow.genai.git_versioning.git_info import GitInfo, GitOperationError
from mlflow.telemetry.events import GitModelVersioningEvent
from mlflow.telemetry.track import record_usage_event
from mlflow.tracking.fluent import _set_active_model
from mlflow.utils.annotations import experimental

# Module-level logger, named after this module via ``__name__``; used below for
# the warning/info messages emitted while building a GitContext.
_logger = logging.getLogger(__name__)


class GitContext:
    def __init__(self, remote_name: str = "origin") -> None:
        try:
            self.info = GitInfo.from_env(remote_name=remote_name)
        except GitOperationError as e:
            _logger.warning(
                f"Encountered an error while retrieving git information: {e}. "
                f"Git model versioning is disabled."
            )
            self.info = None
            self.active_model = None
            return

        git_tags = self.info.to_mlflow_tags()
        filter_string = self.info.to_search_filter_string()
        models = mlflow.search_logged_models(
            filter_string=filter_string,
            max_results=1,
            output_format="list",
        )
        match models:
            case [m]:
                _logger.info(
                    f"Using existing model with branch '{self.info.branch}', "
                    f"commit '{self.info.commit}', dirty state '{self.info.dirty}'."
                )
                model = m
                # Update tags to ensure they're current (especially git diff)
                mlflow.set_logged_model_tags(model_id=model.model_id, tags=git_tags)
            case _:
                _logger.info(
                    "No existing model found with the current git information. "
                    "Creating a new model."
                )
                model = mlflow.initialize_logged_model(tags=git_tags)

        self.active_model = _set_active_model(model_id=model.model_id)

    def __enter__(self) -> Self:
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        disable_git_model_versioning()


# Global variable to track the active git context.
# Assigned by `_enable_git_model_versioning`, reset to None by
# `disable_git_model_versioning`, and read by `_get_active_git_context`.
_active_context: GitContext | None = None


@record_usage_event(GitModelVersioningEvent)
def _enable_git_model_versioning(remote_name: str) -> GitContext:
    """Build a ``GitContext`` and register it as the module's active context.

    Args:
        remote_name: Name of the git remote, forwarded to ``GitContext``.

    Returns:
        The newly created ``GitContext``, which is also stored in the
        module-level ``_active_context``.
    """
    # The return annotation previously claimed ``None`` although the context is
    # returned; it now matches the actual return value.
    global _active_context
    context = GitContext(remote_name=remote_name)
    _active_context = context
    return context


@experimental(version="3.4.0")
def enable_git_model_versioning(remote_name: str = "origin") -> GitContext:
    """
    Enable automatic Git-based model versioning for MLflow traces.

    This function enables automatic version tracking based on your Git repository state.
    When enabled, MLflow will:

    - Detect the current Git branch, commit hash, and dirty state
    - Create or reuse a LoggedModel matching this exact Git state
    - Link all subsequent traces to this LoggedModel version
    - Capture uncommitted changes as diffs when the repository is dirty

    Args:
        remote_name: The name of the git remote to use for repository URL detection.
            Defaults to "origin".

    Returns:
        A GitContext instance containing:

        - info: GitInfo object with branch, commit, dirty state, and diff information
        - active_model: The active LoggedModel linked to current Git state

    Example:

    .. code-block:: python

        import mlflow.genai

        # Enable Git-based versioning
        context = mlflow.genai.enable_git_model_versioning()
        print(f"Branch: {context.info.branch}, Commit: {context.info.commit[:8]}")
        # Output: Branch: main, Commit: abc12345


        # All traces are now automatically linked to this Git version
        @mlflow.trace
        def my_app():
            return "result"


        # Can also use as a context manager
        with mlflow.genai.enable_git_model_versioning() as context:
            # Traces within this block are linked to the Git version
            result = my_app()

    Note:
        If Git is not available or the current directory is not a Git repository,
        a warning is issued and versioning is disabled (context.info will be None).
    """
    # Thin public wrapper: the helper builds the GitContext and stores it as the
    # module's active context.
    return _enable_git_model_versioning(remote_name)
@experimental(version="3.4.0")
def disable_git_model_versioning() -> None:
    """
    Disable Git-based model versioning and clear the active model context.

    This function stops automatic Git-based version tracking and clears any active
    LoggedModel context. After calling this, traces will no longer be automatically
    linked to Git-based versions.

    This is automatically called when exiting a context manager created with
    enable_git_model_versioning().

    Example:

    .. code-block:: python

        import mlflow.genai

        # Enable versioning
        context = mlflow.genai.enable_git_model_versioning()
        # ... do work with versioning enabled ...

        # Disable versioning
        mlflow.genai.disable_git_model_versioning()
        # Traces are no longer linked to Git versions
    """
    # Drop the module-level context first, then clear MLflow's active model.
    global _active_context
    _active_context = None
    mlflow.clear_active_model()
def _get_active_git_context() -> GitContext | None:
    """
    Get the currently active git context, if any.

    Returns:
        The active GitContext instance or None if no context is active.
    """
    return _active_context