mlflow.pytorch

MLflow integration for PyTorch.

This module provides functions for logging and loading PyTorch models. Logged models can be loaded back either as PyTorch models or as Python Function (PyFunc) models.

mlflow.pytorch.load_model(path, run_id=None, **kwargs)

Load a PyTorch model from a local file (if run_id is None) or from a run.

Parameters:
  • path – Local filesystem path or run-relative artifact path to the model saved by mlflow.pytorch.log_model().
  • run_id – Run ID. If provided, it is combined with path to identify the model.
  • kwargs – Keyword arguments passed to the torch.load method.

mlflow.pytorch.load_pyfunc(path, **kwargs)

Load the model as a PyFunc. The loaded PyFunc exposes a predict(pd.DataFrame) -> pd.DataFrame method. Given an input DataFrame of n rows and k float-valued columns, predict feeds a corresponding (n x k) torch.FloatTensor (or torch.cuda.FloatTensor) to the PyTorch model and returns the model’s predictions (the output tensor) as a single-column DataFrame.

Parameters:
  • path – Local filesystem path to the model saved by mlflow.pytorch.log_model().
  • kwargs – Keyword arguments passed to the torch.load method.
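
The predict contract described above can be illustrated without MLflow. The helper predict_like_pyfunc below is hypothetical and is not part of mlflow.pytorch; it only mirrors the documented behavior by converting an n x k float DataFrame into an (n x k) float32 tensor (a torch.FloatTensor on the CPU), running the model, and wrapping the output tensor in a DataFrame. It assumes pandas and PyTorch are installed and covers the CPU case only.

```python
import pandas as pd
import torch


def predict_like_pyfunc(model: torch.nn.Module, data: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper mirroring the documented PyFunc predict contract.

    Not MLflow's implementation; CPU only.
    """
    model.eval()
    with torch.no_grad():
        # n rows x k float columns -> (n x k) float32 tensor (a torch.FloatTensor)
        inputs = torch.tensor(data.values, dtype=torch.float32)
        outputs = model(inputs)
    # Wrap the output tensor in a DataFrame; keeping the input's row index
    # is a choice made for this sketch
    return pd.DataFrame(outputs.numpy(), index=data.index)


torch.manual_seed(0)
model = torch.nn.Linear(3, 1)  # k = 3 input columns, one output column
df = pd.DataFrame({"a": [1.0, 2.0], "b": [0.5, 0.0], "c": [3.0, 1.0]})

preds = predict_like_pyfunc(model, df)
print(preds.shape)  # (2, 1)
```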

mlflow.pytorch.log_model(pytorch_model, artifact_path, conda_env=None, **kwargs)

Log a PyTorch model as an MLflow artifact for the current run.

Parameters:
  • pytorch_model – PyTorch model to be saved. Must accept a single torch.FloatTensor as input and produce a single output tensor.
  • artifact_path – Run-relative artifact path.
  • conda_env – Path to a Conda environment file. If provided, this defines the environment for the model. At minimum, it should specify Python, PyTorch and MLflow with appropriate versions.
  • kwargs – Keyword arguments passed to the torch.save method.

mlflow.pytorch.save_model(pytorch_model, path, conda_env=None, mlflow_model=<mlflow.models.Model object>, **kwargs)

Save a PyTorch model to a path on the local file system.

Parameters:
  • pytorch_model – PyTorch model to be saved. Must accept a single torch.FloatTensor as input and produce a single output tensor.
  • path – Local path where the model is to be saved.
  • conda_env – Path to a Conda environment file. If provided, this describes the environment in which this model should run. At minimum, it should specify Python, PyTorch and MLflow with appropriate versions.
  • mlflow_model – MLflow model configuration to which this flavor is added.
  • kwargs – Keyword arguments passed to the torch.save method.