# Source code for mlflow.tracking._workspace.fluent

from __future__ import annotations

import threading
from typing import Callable, TypeVar

from mlflow.entities.workspace import Workspace
from mlflow.exceptions import MlflowException, RestException
from mlflow.protos import databricks_pb2
from mlflow.protos.databricks_pb2 import FEATURE_DISABLED
from mlflow.store.workspace.abstract_store import WorkspaceNameValidator
from mlflow.tracking.client import MlflowClient
from mlflow.utils.annotations import experimental
from mlflow.utils.workspace_context import set_workspace as set_context_workspace
from mlflow.utils.workspace_utils import DEFAULT_WORKSPACE_NAME

# Generic return type of the callable handed to ``_workspace_client_call``; the
# helper returns exactly what that callable returns.
T = TypeVar("T")

# Process-local lock held by ``set_workspace`` around its validate-then-set
# sequence (in this module it is only acquired there).
_workspace_lock = threading.Lock()


def _workspace_client_call(func: Callable[[MlflowClient], T]) -> T:
    """Run ``func`` against a freshly constructed ``MlflowClient``.

    A ``RestException`` carrying the ``ENDPOINT_NOT_FOUND`` error code is
    translated into an ``MlflowException`` with ``FEATURE_DISABLED``, because it
    means the tracking server does not expose the workspace endpoints. Any other
    ``RestException`` propagates unchanged.

    Args:
        func: Callable that receives the client and performs the actual request.

    Returns:
        Whatever ``func`` returns.

    Raises:
        MlflowException: With ``FEATURE_DISABLED`` when the server reports
            ``ENDPOINT_NOT_FOUND``.
        RestException: Any other REST error raised while running ``func``.
    """
    mlflow_client = MlflowClient()
    try:
        result = func(mlflow_client)
    except RestException as rest_error:
        missing_endpoint = databricks_pb2.ErrorCode.Name(databricks_pb2.ENDPOINT_NOT_FOUND)
        if rest_error.error_code != missing_endpoint:
            # Not a "workspace APIs unavailable" signal: surface it untouched.
            raise
        raise MlflowException(
            "The configured tracking server does not expose workspace APIs. "
            "Ensure workspace is enabled.",
            error_code=FEATURE_DISABLED,
        ) from rest_error
    return result


@experimental(version="3.10.0")
def set_workspace(workspace: str | None) -> None:
    """Select the workspace that subsequent MLflow operations should target.

    Args:
        workspace: Name of the workspace to activate. ``None`` is passed straight
            through to the workspace context (presumably clearing any selection —
            confirm in ``mlflow.utils.workspace_context``). The default workspace
            name is accepted without validation.

    Raises:
        MlflowException: If ``WorkspaceNameValidator`` rejects the name.
    """
    with _workspace_lock:
        # Only real, non-default names have to satisfy the naming rules.
        needs_validation = workspace is not None and workspace != DEFAULT_WORKSPACE_NAME
        if needs_validation:
            WorkspaceNameValidator.validate(workspace)
        set_context_workspace(workspace)
@experimental(version="3.10.0")
def list_workspaces() -> list[Workspace]:
    """Fetch every workspace the current user is allowed to see.

    Returns:
        The workspaces reported by ``MlflowClient.list_workspaces``.

    Raises:
        MlflowException: With ``FEATURE_DISABLED`` when the tracking server does
            not expose workspace APIs.
    """

    def _list_all(client: MlflowClient) -> list[Workspace]:
        return client.list_workspaces()

    return _workspace_client_call(_list_all)
@experimental(version="3.10.0")
def get_workspace(name: str) -> Workspace:
    """Look up a single workspace by name.

    Args:
        name: Name of the workspace to fetch.

    Returns:
        The metadata returned by ``MlflowClient.get_workspace``.

    Raises:
        MlflowException: With ``FEATURE_DISABLED`` when the tracking server does
            not expose workspace APIs.
    """

    def _fetch_one(client: MlflowClient) -> Workspace:
        return client.get_workspace(name)

    return _workspace_client_call(_fetch_one)
@experimental(version="3.10.0")
def create_workspace(
    name: str, description: str | None = None, default_artifact_root: str | None = None
) -> Workspace:
    """Register a brand-new workspace on the tracking server.

    Args:
        name: Workspace name; must satisfy ``WorkspaceNameValidator`` (lowercase
            alphanumeric with optional internal hyphens).
        description: Optional free-text description.
        default_artifact_root: Optional artifact root URI; when omitted the
            server default is used.

    Returns:
        The workspace as created by the server.

    Raises:
        MlflowException: If the name is invalid, already exists, or no artifact
            root is available.
    """
    # Validate locally before any network round trip.
    WorkspaceNameValidator.validate(name)
    request_kwargs = {
        "name": name,
        "description": description,
        "default_artifact_root": default_artifact_root,
    }
    return _workspace_client_call(lambda client: client.create_workspace(**request_kwargs))
@experimental(version="3.10.0")
def update_workspace(
    name: str, description: str | None = None, default_artifact_root: str | None = None
) -> Workspace:
    """Change the metadata of an existing workspace.

    Args:
        name: Name of the workspace to modify. The default workspace name is
            accepted without validation; any other name must pass
            ``WorkspaceNameValidator``.
        description: Replacement description, or ``None`` to keep the current one.
        default_artifact_root: Replacement artifact root URI, an empty string to
            clear it, or ``None`` to keep the current value.

    Returns:
        The workspace as updated by the server.

    Raises:
        MlflowException: If the workspace does not exist or no artifact root is
            available.
    """
    if name != DEFAULT_WORKSPACE_NAME:
        WorkspaceNameValidator.validate(name)
    request_kwargs = {
        "name": name,
        "description": description,
        "default_artifact_root": default_artifact_root,
    }
    return _workspace_client_call(lambda client: client.update_workspace(**request_kwargs))
@experimental(version="3.10.0")
def delete_workspace(name: str) -> None:
    """Remove a workspace from the tracking server.

    Args:
        name: Name of the workspace to delete. The default workspace name skips
            local name validation; the server decides whether deletion is allowed.

    Raises:
        MlflowException: With ``FEATURE_DISABLED`` when the tracking server does
            not expose workspace APIs.
    """
    is_default = name == DEFAULT_WORKSPACE_NAME
    if not is_default:
        WorkspaceNameValidator.validate(name)

    def _delete(client: MlflowClient):
        return client.delete_workspace(name=name)

    # The client's return value is intentionally discarded.
    _workspace_client_call(_delete)
# Public names exported by ``from ... import *``. ``Workspace`` is re-exported
# from ``mlflow.entities.workspace`` (imported at the top of this module).
__all__ = [
    "Workspace",
    "set_workspace",
    "list_workspaces",
    "get_workspace",
    "create_workspace",
    "update_workspace",
    "delete_workspace",
]