mlflow.pytorch
The mlflow.pytorch module provides an API for logging and loading PyTorch models. This module exports PyTorch models with the following flavors:
- PyTorch (native) format: This is the main flavor that can be loaded back into PyTorch.
- mlflow.pyfunc: Produced for use by generic pyfunc-based deployment tools and batch inference.
mlflow.pytorch.get_default_conda_env()
Returns: The default Conda environment for MLflow Models produced by calls to save_model() and log_model().
mlflow.pytorch.load_model(model_uri, **kwargs)
Load a PyTorch model from a local file or a run.
Parameters:
- model_uri – The location, in URI format, of the MLflow model, for example:
  - /Users/me/path/to/local/model
  - relative/path/to/local/model
  - s3://my_bucket/path/to/model
  - runs:/<mlflow_run_id>/run-relative/path/to/model
  For more information about supported URI schemes, see Referencing Artifacts.
- kwargs – Keyword arguments to pass to the torch.load method.
Returns: A PyTorch model.
>>> import torch
>>> import mlflow
>>> import mlflow.pytorch
>>> # set values
>>> model_path_dir = ...
>>> run_id = "96771d893a5e46159d9f3b49bf9013e2"
>>> pytorch_model = mlflow.pytorch.load_model("runs:/" + run_id + "/" + model_path_dir)
>>> y_pred = pytorch_model(x_new_data)
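The URI forms listed for model_uri differ mainly in their scheme. The sketch below uses only the standard library to build a runs:/ URI from a run ID and a run-relative path, and to report the scheme of each example form. The helper names build_runs_uri and uri_scheme are hypothetical and not part of MLflow, and treating a scheme-less path as "file" is an assumption of this sketch rather than a description of how MLflow resolves URIs.

```python
from urllib.parse import urlparse


def build_runs_uri(run_id, artifact_path):
    """Join a run ID and a run-relative path into a runs:/ URI (hypothetical helper)."""
    return "runs:/{}/{}".format(run_id, artifact_path.strip("/"))


def uri_scheme(model_uri):
    """Return the URI scheme, or "file" for plain local paths (assumption of this sketch)."""
    scheme = urlparse(model_uri).scheme
    return scheme if scheme else "file"


examples = [
    "/Users/me/path/to/local/model",
    "relative/path/to/local/model",
    "s3://my_bucket/path/to/model",
    build_runs_uri("96771d893a5e46159d9f3b49bf9013e2", "models"),
]
schemes = [uri_scheme(uri) for uri in examples]
```

Building the URI in one helper avoids the doubled or missing slashes that manual string concatenation can produce.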
mlflow.pytorch.log_model(pytorch_model, artifact_path, conda_env=None, code_paths=None, pickle_module=None, **kwargs)
Log a PyTorch model as an MLflow artifact for the current run.
Parameters:
- pytorch_model – PyTorch model to be saved. Must accept a single torch.FloatTensor as input and produce a single output tensor. Any code dependencies of the model’s class, including the class definition itself, should be included in one of the following locations:
  - The package(s) listed in the model’s Conda environment, specified by the conda_env parameter.
  - One or more of the files specified by the code_paths parameter.
- artifact_path – Run-relative artifact path.
- conda_env – Path to a Conda environment file. If provided, this describes the environment this model should be run in. At minimum, it should specify the dependencies contained in get_default_conda_env(). If None, the default get_default_conda_env() environment is added to the model. The following is an example dictionary representation of a Conda environment:
  {
      'name': 'mlflow-env',
      'channels': ['defaults'],
      'dependencies': [
          'python=3.7.0',
          'pytorch=0.4.1',
          'torchvision=0.2.1'
      ]
  }
- code_paths – A list of local filesystem paths to Python file dependencies (or directories containing file dependencies). These files are prepended to the system path when the model is loaded.
- pickle_module – The module that PyTorch should use to serialize (“pickle”) the specified pytorch_model. This is passed as the pickle_module parameter to torch.save(). By default, this module is also used to deserialize (“unpickle”) the PyTorch model at load time.
- kwargs – Keyword arguments to pass to the torch.save method.
>>> import torch
>>> import mlflow
>>> import mlflow.pytorch
>>> # X data
>>> x_data = torch.Tensor([[1.0], [2.0], [3.0]])
>>> # Y data with its expected value: labels
>>> y_data = torch.Tensor([[2.0], [4.0], [6.0]])
>>> # Partial Model example modified from Sung Kim
>>> # https://github.com/hunkim/PyTorchZeroToAll
>>> class Model(torch.nn.Module):
>>>     def __init__(self):
>>>         super(Model, self).__init__()
>>>         self.linear = torch.nn.Linear(1, 1)  # One in and one out
>>>     def forward(self, x):
>>>         y_pred = self.linear(x)
>>>         return y_pred
>>> # our model
>>> model = Model()
>>> criterion = torch.nn.MSELoss(size_average=False)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
>>> # Training loop
>>> for epoch in range(500):
>>>     # Forward pass: Compute predicted y by passing x to the model
>>>     y_pred = model(x_data)
>>>     # Compute and print loss
>>>     loss = criterion(y_pred, y_data)
>>>     print(epoch, loss.data.item())
>>>     # Zero gradients, perform a backward pass, and update the weights.
>>>     optimizer.zero_grad()
>>>     loss.backward()
>>>     optimizer.step()
>>>
>>> # After training
>>> for hv in [4.0, 5.0, 6.0]:
>>>     hour_var = torch.Tensor([[hv]])
>>>     y_pred = model(hour_var)
>>>     print("predict (after training)", hv, model(hour_var).data[0][0])
>>> # log the model
>>> with mlflow.start_run() as run:
>>>     mlflow.log_param("epochs", 500)
>>>     mlflow.pytorch.log_model(model, "models")
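The conda_env description asks that a custom environment specify at least the dependencies of the default one. A minimal sketch of such a check follows, using only the standard library. The function missing_dependencies is hypothetical, and the required list here is copied from the example dictionary in this section as a stand-in; it is an assumption, not the actual output of get_default_conda_env().

```python
import re


def missing_dependencies(conda_env, required):
    """Return package names from `required` that `conda_env` does not list."""

    def package_name(spec):
        # Conda specs look like "name=version"; keep only the name part.
        return re.split(r"[=<>! ]", spec.strip(), maxsplit=1)[0].lower()

    listed = {
        package_name(dep)
        for dep in conda_env.get("dependencies", [])
        if isinstance(dep, str)  # skip nested entries such as {"pip": [...]}
    }
    return [package_name(r) for r in required if package_name(r) not in listed]


# Assumed stand-in for the default dependencies (taken from the example dictionary).
assumed_defaults = ["python=3.7.0", "pytorch=0.4.1", "torchvision=0.2.1"]

custom_env = {
    "name": "mlflow-env",
    "channels": ["defaults"],
    "dependencies": ["python=3.7.0", "pytorch=0.4.1"],
}
missing = missing_dependencies(custom_env, assumed_defaults)
```

Here missing contains 'torchvision', which signals that the custom environment should be extended before it is passed as conda_env.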
mlflow.pytorch.save_model(pytorch_model, path, conda_env=None, mlflow_model=<mlflow.models.Model object>, code_paths=None, pickle_module=None, **kwargs)
Save a PyTorch model to a path on the local file system.
Parameters:
- pytorch_model – PyTorch model to be saved. Must accept a single torch.FloatTensor as input and produce a single output tensor. Any code dependencies of the model’s class, including the class definition itself, should be included in one of the following locations:
  - The package(s) listed in the model’s Conda environment, specified by the conda_env parameter.
  - One or more of the files specified by the code_paths parameter.
- path – Local path where the model is to be saved.
- conda_env – Either a dictionary representation of a Conda environment or the path to a Conda environment yaml file. If provided, this describes the environment this model should be run in. At minimum, it should specify the dependencies contained in get_default_conda_env(). If None, the default get_default_conda_env() environment is added to the model. The following is an example dictionary representation of a Conda environment:
  {
      'name': 'mlflow-env',
      'channels': ['defaults'],
      'dependencies': [
          'python=3.7.0',
          'pytorch=0.4.1',
          'torchvision=0.2.1'
      ]
  }
- mlflow_model – The mlflow.models.Model that this flavor is being added to.
- code_paths – A list of local filesystem paths to Python file dependencies (or directories containing file dependencies). These files are prepended to the system path when the model is loaded.
- pickle_module – The module that PyTorch should use to serialize (“pickle”) the specified pytorch_model. This is passed as the pickle_module parameter to torch.save(). By default, this module is also used to deserialize (“unpickle”) the PyTorch model at load time.
- kwargs – Keyword arguments to pass to the torch.save method.
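The pickle_module description says that the module used to serialize the model is, by default, also used to deserialize it at load time. The sketch below illustrates that idea with the standard library only: the name of the serializing module is recorded at save time and re-imported at load time. The functions save_with and load_with_recorded_module and the file names are hypothetical; this is an illustration of the pattern, not MLflow's or PyTorch's implementation.

```python
import importlib
import os
import pickle
import tempfile


def save_with(obj, directory, pickle_module=pickle):
    """Serialize obj with the given module and record the module's name beside it."""
    with open(os.path.join(directory, "model.pkl"), "wb") as f:
        pickle_module.dump(obj, f)
    with open(os.path.join(directory, "pickle_module.txt"), "w") as f:
        f.write(pickle_module.__name__)


def load_with_recorded_module(directory):
    """Deserialize using the module whose name was recorded at save time."""
    with open(os.path.join(directory, "pickle_module.txt")) as f:
        module = importlib.import_module(f.read().strip())
    with open(os.path.join(directory, "model.pkl"), "rb") as f:
        return module.load(f)


with tempfile.TemporaryDirectory() as tmp:
    # A plain dictionary stands in for a model object in this sketch.
    save_with({"weights": [0.5, 1.5]}, tmp)
    restored = load_with_recorded_module(tmp)
```

Any module exposing compatible dump and load functions could be passed as pickle_module in this sketch, provided it is importable in the environment where the object is loaded.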
>>> import torch
>>> import mlflow
>>> import mlflow.pytorch
>>> # create model and set values
>>> pytorch_model = Model()
>>> pytorch_model_path = ...
>>> # train our model
>>> for epoch in range(500):
>>>     y_pred = pytorch_model(x_data)
>>>     ...
>>> # save the model
>>> with mlflow.start_run() as run:
>>>     mlflow.log_param("epochs", 500)
>>>     mlflow.pytorch.save_model(pytorch_model, pytorch_model_path)
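The code_paths description notes that the listed files are prepended to the system path when the model is loaded. The effect of prepending can be sketched with the standard library alone: a directory placed at the front of sys.path becomes importable, which is what lets a model's class definition be found at load time. The module name my_model_defs_example and its contents are hypothetical, and the sketch does not reproduce MLflow's loading code.

```python
import importlib
import os
import sys
import tempfile

with tempfile.TemporaryDirectory() as code_dir:
    # Write a stand-in dependency file (hypothetical module name and contents).
    with open(os.path.join(code_dir, "my_model_defs_example.py"), "w") as f:
        f.write("HIDDEN_UNITS = 16\n")

    # Prepending the directory makes it the first place Python searches.
    sys.path.insert(0, code_dir)
    try:
        importlib.invalidate_caches()
        module = importlib.import_module("my_model_defs_example")
        hidden_units = module.HIDDEN_UNITS
    finally:
        # Undo the changes so the sketch leaves the interpreter state clean.
        sys.path.remove(code_dir)
        sys.modules.pop("my_model_defs_example", None)
```

Because the directory is prepended rather than appended, a file in code_paths would shadow an installed module of the same name, so distinctive file names are a sensible precaution.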