mlflow.pytorch
MLflow integration for PyTorch.
Manages logging and loading PyTorch models; logged models can be loaded back as PyTorch models or as Python Function models.
mlflow.pytorch.load_model(path, run_id=None, **kwargs)
Load a PyTorch model from a local path (if run_id is None) or from a run.
Parameters:
- path – Local filesystem path, or run-relative artifact path, to the model saved by mlflow.pytorch.log_model().
- run_id – Run ID. If provided, it is combined with path to identify the model.
- kwargs – Keyword arguments to pass to the torch.load method.
mlflow.pytorch.load_pyfunc(path, **kwargs)
Load the model as a PyFunc. The loaded PyFunc exposes a predict(pd.DataFrame) -> pd.DataFrame method that, given an input DataFrame of n rows and k float-valued columns, feeds a corresponding (n x k) torch.FloatTensor (or torch.cuda.FloatTensor) as input to the PyTorch model. predict returns the model's predictions (output tensor) in a single-column DataFrame.
Parameters:
- path – Local filesystem path to the model saved by mlflow.pytorch.log_model().
- kwargs – Keyword arguments to pass to the torch.load method.
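The predict contract described above can be sketched without MLflow at all. The helper predict_like_pyfunc below is hypothetical: it only mirrors the documented DataFrame-to-tensor-to-DataFrame flow and is not MLflow's actual implementation. It assumes the model returns one value per input row, which is what yields a single-column result.

```python
import pandas as pd
import torch


def predict_like_pyfunc(model: torch.nn.Module, df: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper mirroring the documented predict() contract.

    Assumes the model returns one value per input row.
    """
    # Run on whatever device the model's parameters live on, so a CUDA model
    # receives a torch.cuda.FloatTensor and a CPU model a torch.FloatTensor.
    params = list(model.parameters())
    device = params[0].device if params else torch.device("cpu")

    # n rows x k float-valued columns -> an (n x k) float32 tensor.
    inputs = torch.tensor(df.values, dtype=torch.float32, device=device)

    model.eval()  # put dropout/batch norm into inference behavior
    with torch.no_grad():
        outputs = model(inputs)

    # Wrap the (n x 1) output tensor in a single-column DataFrame that keeps
    # the input's row index.
    return pd.DataFrame(outputs.cpu().numpy(), index=df.index)


torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
df = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0], "b": [0.5] * 4, "c": [0.0, 1.0, 0.0, 1.0]})
preds = predict_like_pyfunc(model, df)
print(preds.shape)  # (4, 1)
```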
mlflow.pytorch.log_model(pytorch_model, artifact_path, conda_env=None, **kwargs)
Log a PyTorch model as an MLflow artifact for the current run.
Parameters:
- pytorch_model – PyTorch model to be saved. Must accept a single torch.FloatTensor as input and produce a single output tensor.
- artifact_path – Run-relative artifact path.
- conda_env – Path to a Conda environment file. If provided, this defines the environment for the model. At minimum, it should specify Python, PyTorch, and MLflow with appropriate versions.
- kwargs – Keyword arguments to pass to the torch.save method.
mlflow.pytorch.save_model(pytorch_model, path, conda_env=None, mlflow_model=<mlflow.models.Model object>, **kwargs)
Save a PyTorch model to a path on the local file system.
Parameters:
- pytorch_model – PyTorch model to be saved. Must accept a single torch.FloatTensor as input and produce a single output tensor.
- path – Local path where the model is to be saved.
- conda_env – Path to a Conda environment file. If provided, this describes the environment in which this model should run. At minimum, it should specify Python, PyTorch, and MLflow with appropriate versions.
- mlflow_model – MLflow model config this flavor is being added to.
- kwargs – Keyword arguments to pass to the torch.save method.