mlflow.pytorch

The mlflow.pytorch module provides an API for logging and loading PyTorch models. This module exports PyTorch models with the following flavors:

  • PyTorch (native) format – This is the main flavor that can be loaded back into PyTorch.
  • mlflow.pyfunc – Produced for use by generic pyfunc-based deployment tools and batch inference.
mlflow.pytorch.get_default_conda_env()
Returns: The default Conda environment for MLflow Models produced by calls to save_model() and log_model().
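A custom environment is often built by starting from the default one and appending extra packages. The sketch below assumes the environment is a plain dictionary shaped like the example shown under log_model() (keys name, channels, dependencies). The helper add_dependencies is hypothetical and not part of mlflow.pytorch, and the base dictionary is a stand-in; in practice it would come from mlflow.pytorch.get_default_conda_env(), assuming that returns a dictionary.

```python
import copy


def add_dependencies(base_env, extra_deps):
    """Return a copy of a Conda env dict with extra_deps appended.

    Hypothetical helper: base_env is assumed to be a dict with an
    optional 'dependencies' list, like the example dictionary in this page.
    """
    env = copy.deepcopy(base_env)  # leave the caller's dict untouched
    deps = env.setdefault("dependencies", [])
    for dep in extra_deps:
        if dep not in deps:  # skip exact duplicates
            deps.append(dep)
    return env


# Stand-in for the default environment (assumed shape, illustrative pins)
base = {
    "name": "mlflow-env",
    "channels": ["defaults"],
    "dependencies": ["python=3.7.0", "pytorch=0.4.1", "torchvision=0.2.1"],
}
custom = add_dependencies(base, ["numpy"])
print(custom["dependencies"])
# ['python=3.7.0', 'pytorch=0.4.1', 'torchvision=0.2.1', 'numpy']
```

Copying first means the default dictionary can be reused for other models without picking up their extra packages.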
mlflow.pytorch.load_model(model_uri, **kwargs)

Load a PyTorch model from a local file or a run.

Parameters:
  • model_uri

    The location, in URI format, of the MLflow model, for example:

    • /Users/me/path/to/local/model
    • relative/path/to/local/model
    • s3://my_bucket/path/to/model
    • runs:/<mlflow_run_id>/run-relative/path/to/model

    For more information about supported URI schemes, see Referencing Artifacts.

  • kwargs – Keyword arguments passed to torch.load().
Returns:

A PyTorch model.

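>>> import torch
>>> import mlflow
>>> import mlflow.pytorch
>>> # Set values: the run-relative artifact path and the ID of the run that logged the model
>>> model_path_dir = ...
>>> run_id = "96771d893a5e46159d9f3b49bf9013e2"
>>> pytorch_model = mlflow.pytorch.load_model("runs:/" + run_id + "/" + model_path_dir)
>>> # x_new_data: a torch.FloatTensor of new inputs to score
>>> y_pred = pytorch_model(x_new_data)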
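The URI forms listed for model_uri differ mainly by scheme. The sketch below builds and splits runs:/ URIs of the form runs:/<mlflow_run_id>/<run-relative path> and classifies the other examples by scheme using the standard library. The helpers runs_uri, split_runs_uri and uri_scheme are hypothetical illustrations, not MLflow's own URI parser, and treating a scheme-less string as a local path is an assumption that holds for POSIX paths only (a Windows drive letter would parse as a scheme).

```python
from urllib.parse import urlparse


def runs_uri(run_id, artifact_path):
    """Build runs:/<run_id>/<artifact_path> (hypothetical helper)."""
    return "runs:/{}/{}".format(run_id, artifact_path.strip("/"))


def split_runs_uri(model_uri):
    """Split a runs:/ URI into (run_id, run-relative artifact path)."""
    parsed = urlparse(model_uri)
    if parsed.scheme != "runs":
        raise ValueError("not a runs:/ URI: " + model_uri)
    run_id, _, artifact_path = parsed.path.lstrip("/").partition("/")
    return run_id, artifact_path


def uri_scheme(model_uri):
    """Return the URI scheme, or 'file' for scheme-less (local) paths."""
    return urlparse(model_uri).scheme or "file"


uri = runs_uri("96771d893a5e46159d9f3b49bf9013e2", "models")
print(uri)                  # runs:/96771d893a5e46159d9f3b49bf9013e2/models
print(split_runs_uri(uri))  # ('96771d893a5e46159d9f3b49bf9013e2', 'models')
print(uri_scheme("s3://my_bucket/path/to/model"))  # s3
```

Building the URI with a helper avoids the doubled or missing slashes that plain string concatenation can produce when the artifact path starts or ends with "/".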
mlflow.pytorch.log_model(pytorch_model, artifact_path, conda_env=None, code_paths=None, pickle_module=None, **kwargs)

Log a PyTorch model as an MLflow artifact for the current run.

Parameters:
  • pytorch_model

    PyTorch model to be saved. Must accept a single torch.FloatTensor as input and produce a single output tensor. Any code dependencies of the model’s class, including the class definition itself, should be included in one of the following locations:

    • The package(s) listed in the model’s Conda environment, specified by the conda_env parameter.
    • One or more of the files specified by the code_paths parameter.
  • artifact_path – Run-relative artifact path.
  • conda_env

    Either a dictionary representation of a Conda environment or the path to a Conda environment YAML file. If provided, this describes the environment this model should be run in. At minimum, it should specify the dependencies contained in get_default_conda_env(). If None, the environment returned by get_default_conda_env() is added to the model. The following is an example dictionary representation of a Conda environment:

    {
        'name': 'mlflow-env',
        'channels': ['defaults'],
        'dependencies': [
            'python=3.7.0',
            'pytorch=0.4.1',
            'torchvision=0.2.1'
        ]
    }
    
  • code_paths – A list of local filesystem paths to Python file dependencies (or directories containing file dependencies). These files are prepended to the system path when the model is loaded.
  • pickle_module – The module that PyTorch should use to serialize (“pickle”) the specified pytorch_model. This is passed as the pickle_module parameter to torch.save(). By default, this module is also used to deserialize (“unpickle”) the PyTorch model at load time.
  • kwargs – Keyword arguments passed to torch.save().
>>> import torch
>>> import mlflow
>>> import mlflow.pytorch
>>> # X data
>>> x_data = torch.Tensor([[1.0], [2.0], [3.0]])
>>> # Y data with its expected value: labels
>>> y_data = torch.Tensor([[2.0], [4.0], [6.0]])
>>> # Partial Model example modified from Sung Kim
>>> # https://github.com/hunkim/PyTorchZeroToAll
>>> class Model(torch.nn.Module):
...     def __init__(self):
...         super(Model, self).__init__()
...         self.linear = torch.nn.Linear(1, 1)  # One in and one out
...     def forward(self, x):
...         y_pred = self.linear(x)
...         return y_pred
>>> # our model
>>> model = Model()
>>> # reduction="sum" replaces the deprecated size_average=False
>>> criterion = torch.nn.MSELoss(reduction="sum")
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
>>> # Training loop
>>> for epoch in range(500):
...     # Forward pass: compute predicted y by passing x to the model
...     y_pred = model(x_data)
...     # Compute and print loss
...     loss = criterion(y_pred, y_data)
...     print(epoch, loss.item())
...     # Zero gradients, perform a backward pass, and update the weights
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
...
>>> # After training
>>> for hv in [4.0, 5.0, 6.0]:
...     hour_var = torch.Tensor([[hv]])
...     y_pred = model(hour_var)
...     print("predict (after training)", hv, y_pred.item())
...
>>> # log the model
>>> with mlflow.start_run() as run:
...     mlflow.log_param("epochs", 500)
...     mlflow.pytorch.log_model(model, "models")
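The conda_env description says a custom environment should at minimum cover the dependencies in get_default_conda_env(). A minimal sketch of checking that before logging, assuming both environments are dictionaries like the example above; missing_dependencies and package_names are hypothetical helpers that compare package names only, ignore version pins, and skip nested entries such as a {'pip': [...]} block.

```python
import re


def package_names(env):
    """Collect bare package names from an env dict's 'dependencies' list."""
    names = set()
    for dep in env.get("dependencies", []):
        if isinstance(dep, str):  # skip nested entries such as {'pip': [...]}
            # 'pytorch=0.4.1' -> 'pytorch', 'numpy>=1.15' -> 'numpy'
            names.add(re.split(r"[=<>!]", dep, maxsplit=1)[0].strip())
    return names


def missing_dependencies(custom_env, default_env):
    """Return default package names absent from custom_env, sorted."""
    return sorted(package_names(default_env) - package_names(custom_env))


# Same shape as the example dictionary above; pins are illustrative
default_env = {"dependencies": ["python=3.7.0", "pytorch=0.4.1", "torchvision=0.2.1"]}
custom_env = {
    "name": "my-env",
    "channels": ["defaults"],
    "dependencies": ["python=3.7.0", "pytorch=0.4.1"],
}
print(missing_dependencies(custom_env, default_env))  # ['torchvision']
```

Comparing names rather than full pins keeps the check from failing when a custom environment deliberately uses a different version of a default package.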
mlflow.pytorch.save_model(pytorch_model, path, conda_env=None, mlflow_model=<mlflow.models.Model object>, code_paths=None, pickle_module=None, **kwargs)

Save a PyTorch model to a path on the local file system.

Parameters:
  • pytorch_model

    PyTorch model to be saved. Must accept a single torch.FloatTensor as input and produce a single output tensor. Any code dependencies of the model’s class, including the class definition itself, should be included in one of the following locations:

    • The package(s) listed in the model’s Conda environment, specified by the conda_env parameter.
    • One or more of the files specified by the code_paths parameter.
  • path – Local path where the model is to be saved.
  • conda_env

    Either a dictionary representation of a Conda environment or the path to a Conda environment YAML file. If provided, this describes the environment this model should be run in. At minimum, it should specify the dependencies contained in get_default_conda_env(). If None, the environment returned by get_default_conda_env() is added to the model. The following is an example dictionary representation of a Conda environment:

    {
        'name': 'mlflow-env',
        'channels': ['defaults'],
        'dependencies': [
            'python=3.7.0',
            'pytorch=0.4.1',
            'torchvision=0.2.1'
        ]
    }
    
  • mlflow_model – The mlflow.models.Model this flavor is being added to.
  • code_paths – A list of local filesystem paths to Python file dependencies (or directories containing file dependencies). These files are prepended to the system path when the model is loaded.
  • pickle_module – The module that PyTorch should use to serialize (“pickle”) the specified pytorch_model. This is passed as the pickle_module parameter to torch.save(). By default, this module is also used to deserialize (“unpickle”) the PyTorch model at load time.
  • kwargs – Keyword arguments passed to torch.save().
>>> import torch
>>> import mlflow
>>> import mlflow.pytorch
>>> # Create the model (Model and x_data as defined in the log_model() example) and set values
>>> pytorch_model = Model()
>>> pytorch_model_path = ...
>>> # Train the model
>>> for epoch in range(500):
...     y_pred = pytorch_model(x_data)
...     ...
...
>>> # Save the model
>>> with mlflow.start_run() as run:
...     mlflow.log_param("epochs", 500)
...     mlflow.pytorch.save_model(pytorch_model, pytorch_model_path)
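The code_paths parameter states that the listed files are prepended to the system path when the model is loaded, which is what lets the model's class definition be imported at load time. The sketch below mimics only that path-prepending step with the standard library; it is not MLflow's implementation. The helper prepend_code_paths and the module name my_model_code are hypothetical, and the temporary directory stands in for a real code_paths entry.

```python
import importlib
import os
import sys
import tempfile


def prepend_code_paths(code_paths):
    """Insert each path at the front of sys.path (hypothetical helper)."""
    for path in reversed(code_paths):  # keeps the given order at the front
        sys.path.insert(0, os.path.abspath(path))


# Write a tiny dependency module into a temporary directory
code_dir = tempfile.mkdtemp()
with open(os.path.join(code_dir, "my_model_code.py"), "w") as f:
    f.write("def double(x):\n    return 2 * x\n")

prepend_code_paths([code_dir])
importlib.invalidate_caches()  # make sure the new path entry is scanned
my_model_code = importlib.import_module("my_model_code")
print(my_model_code.double(21))  # 42
```

Prepending rather than appending means a module shipped through code_paths takes precedence over an installed package with the same name.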