# Source code for mlflow.tracking._model_registry.fluent

from mlflow.exceptions import MlflowException
from mlflow.entities.model_registry import ModelVersion
from mlflow.protos.databricks_pb2 import RESOURCE_ALREADY_EXISTS, ErrorCode
from mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository
from mlflow.tracking import MlflowClient
from mlflow.utils.logging_utils import eprint
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS


def register_model(
    model_uri, name, await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS
) -> ModelVersion:
    """
    Create a new model version in model registry for the model files specified by ``model_uri``.
    Note that this method assumes the model registry backend URI is the same as that of the
    tracking backend.

    :param model_uri: URI referring to the MLmodel directory. Use a ``runs:/`` URI if you want to
                      record the run ID with the model in model registry. ``models:/`` URIs are
                      currently not supported.
    :param name: Name of the registered model under which to create a new model version. If a
                 registered model with the given name does not exist, it will be created
                 automatically.
    :param await_registration_for: Number of seconds to wait for the model version to finish
                            being created and reach ``READY`` status. By default, the function
                            waits for five minutes. Specify 0 or None to skip waiting.
    :return: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object created by
             backend.
    :raises MlflowException: Re-raised from ``create_registered_model`` when its error code is
                             anything other than ``RESOURCE_ALREADY_EXISTS``. Errors raised by
                             ``create_model_version`` are not caught and propagate unchanged.

    .. code-block:: python
        :caption: Example

        import mlflow.sklearn
        from sklearn.ensemble import RandomForestRegressor

        mlflow.set_tracking_uri("sqlite:////tmp/mlruns.db")
        params = {"n_estimators": 3, "random_state": 42}

        # Log MLflow entities
        with mlflow.start_run() as run:
            rfr = RandomForestRegressor(**params).fit([[0, 1]], [1])
            mlflow.log_params(params)
            mlflow.sklearn.log_model(rfr, artifact_path="sklearn-model")

        model_uri = "runs:/{}/sklearn-model".format(run.info.run_id)
        mv = mlflow.register_model(model_uri, "RandomForestRegressionModel")
        print("Name: {}".format(mv.name))
        print("Version: {}".format(mv.version))

    .. code-block:: text
        :caption: Output

        Name: RandomForestRegressionModel
        Version: 1
    """
    client = MlflowClient()

    # Make sure the registered model exists. An "already exists" error is expected and benign:
    # in that case we simply add a new version to the existing registered model.
    try:
        create_model_response = client.create_registered_model(name)
    except MlflowException as e:
        if e.error_code != ErrorCode.Name(RESOURCE_ALREADY_EXISTS):
            # Bare ``raise`` re-raises the original exception with its traceback intact.
            raise
        eprint(
            "Registered model '%s' already exists. Creating a new version of this model..." % name
        )
    else:
        eprint("Successfully registered model '%s'." % create_model_response.name)

    if RunsArtifactRepository.is_runs_uri(model_uri):
        # For ``runs:/`` URIs, store the URI returned by ``get_underlying_uri`` (presumably the
        # run's concrete artifact location) and link the version to the run ID parsed from it.
        source = RunsArtifactRepository.get_underlying_uri(model_uri)
        (run_id, _) = RunsArtifactRepository.parse_runs_uri(model_uri)
    else:
        # Any other URI is recorded verbatim, with no associated run.
        source = model_uri
        run_id = None

    create_version_response = client.create_model_version(
        name, source=source, run_id=run_id, await_creation_for=await_registration_for
    )
    eprint(
        "Created version '{version}' of model '{model_name}'.".format(
            version=create_version_response.version, model_name=create_version_response.name
        )
    )
    return create_version_response