Deep Learning with MLflow 3.0
In this example, we demonstrate how to use MLflow 3.0 to track and evaluate deep learning models with a PyTorch-based Iris classifier.
The example showcases how to log checkpoints, link metrics to datasets, and rank models based on performance metrics.
With mlflow.search_logged_models() you can easily find the best model based on the metric value.
import pandas as pd
import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import mlflow
import mlflow.pytorch
from mlflow.entities import Dataset
# Helper function to prepare data
def prepare_data(df):
    """Convert a DataFrame into a ``(features, labels)`` pair of tensors.

    Every column except the last is treated as a feature and converted to
    ``float32``; the last column is treated as the label and converted to
    ``long``.
    """
    feature_values = df.iloc[:, :-1].values
    label_values = df.iloc[:, -1].values
    features = torch.tensor(feature_values, dtype=torch.float32)
    labels = torch.tensor(label_values, dtype=torch.long)
    return features, labels
# Helper function to compute accuracy
def compute_accuracy(model, X, y):
    """Return the fraction of rows in ``X`` whose predicted class equals ``y``.

    The predicted class of a row is the index of its largest model output
    along dimension 1. The forward pass runs with gradient tracking disabled.
    """
    with torch.no_grad():
        predicted = torch.max(model(X), dim=1).indices
        n_correct = (predicted == y).sum().item()
    return n_correct / y.size(0)
# Define a basic PyTorch classifier
class IrisClassifier(nn.Module):
    """Feed-forward classifier: linear layer, ReLU, linear layer.

    Args:
        input_size: Number of input features per sample.
        hidden_size: Width of the hidden layer.
        output_size: Number of output scores per sample.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Layers are registered in the order forward() applies them
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the unnormalized class scores for the batch ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
# Load the Iris dataset into a DataFrame whose last column is the class label
iris = load_iris()
iris_df = pd.DataFrame(iris.data, columns=iris.feature_names).assign(target=iris.target)
# Hold out 20% of the rows for testing; the fixed seed keeps the split reproducible
train_df, test_df = train_test_split(iris_df, test_size=0.2, random_state=42)
# Wrap the training split as an MLflow dataset and convert it to tensors
train_dataset = mlflow.data.from_pandas(train_df, name="train")
X_train, y_train = prepare_data(train_dataset.df)
# Network dimensions: one input per feature column, one output per Iris class
input_size, hidden_size, output_size = X_train.shape[1], 16, len(iris.target_names)
# Prefer the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Build the classifier on the chosen device, then compile it to TorchScript
scripted_model = torch.jit.script(
    IrisClassifier(input_size, hidden_size, output_size).to(device)
)
# Start a run to represent the training job
with mlflow.start_run() as run:
    # Load the training dataset with MLflow. We will link training metrics to this dataset.
    train_dataset: Dataset = mlflow.data.from_pandas(train_df, name="train")
    X_train, y_train = prepare_data(train_dataset.df)
    # Move the data to the model's device once, instead of on every epoch
    X_train, y_train = X_train.to(device), y_train.to(device)
    # Tensor.numpy() only works on CPU tensors, so copy to the CPU first
    # (a no-op when the data already lives there). Built once and reused
    # for every checkpoint.
    input_example = X_train.cpu().numpy()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(scripted_model.parameters(), lr=0.01)
    # range(101) so that epoch 100 is reached and checkpointed
    for epoch in range(101):
        # Full-batch training step: forward pass, loss, backward pass, update
        out = scripted_model(X_train)
        loss = criterion(out, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Log a checkpoint with metrics every 10 epochs
        if epoch % 10 == 0:
            # Each newly created LoggedModel checkpoint is linked with its name and step
            model_info = mlflow.pytorch.log_model(
                pytorch_model=scripted_model,
                name=f"torch-iris-{epoch}",
                step=epoch,
                input_example=input_example,
            )
            # log params to the run, LoggedModel inherits those params
            mlflow.log_params(
                params={
                    "n_layers": 3,
                    "activation": "ReLU",
                    "criterion": "CrossEntropyLoss",
                    "optimizer": "Adam",
                }
            )
            # Log metric on training dataset at step and link to LoggedModel
            mlflow.log_metric(
                key="accuracy",
                value=compute_accuracy(scripted_model, X_train, y_train),
                step=epoch,
                model_id=model_info.model_id,
                dataset=train_dataset,
            )
On the run page, you can see the logged models that were generated; their names follow the pattern torch-iris-<epoch>:

The logged models also show up in the Models tab of the experiment, including their dataset, parameters and metrics:

Use mlflow.search_logged_models() to search the logged models attached to the run, ordering them by the accuracy metric value to easily fetch the best and worst models:
# Restrict the search to checkpoints whose source run is this run and sort
# them by the "accuracy" metric in descending order; output_format="list"
# is requested so the result can be indexed below
ranked_checkpoints = mlflow.search_logged_models(
    filter_string=f"source_run_id='{run.info.run_id}'",
    order_by=[{"field_name": "metrics.accuracy", "ascending": False}],
    output_format="list",
)
# With a descending sort, the first entry is the most accurate checkpoint
best_checkpoint = ranked_checkpoints[0]
print(f"Best model: {best_checkpoint}")
print(best_checkpoint.metrics)
# Best model: <LoggedModel: artifact_location='file:///Users/serena.ruan/Documents/repos/mlflow-3-doc/mlruns/0/models/41bd5a16-25a6-447b-90e0-0f7b7e5cb6cf/artifacts', creation_timestamp=1743734069924, experiment_id='0', last_updated_timestamp=1743734075018, metrics=[<Metric: dataset_digest='1f1c13b5', dataset_name='train', key='accuracy', model_id='41bd5a16-25a6-447b-90e0-0f7b7e5cb6cf', run_id='12f143a7fda1461e9240d7ffad4ea5bd', step=100, timestamp=1743734075029, value=0.975>], model_id='41bd5a16-25a6-447b-90e0-0f7b7e5cb6cf', model_type='', model_uri='models:/41bd5a16-25a6-447b-90e0-0f7b7e5cb6cf', name='torch-iris-100', params={'activation': 'ReLU',
#  'criterion': 'CrossEntropyLoss',
#  'n_layers': '3',
#  'optimizer': 'Adam'}, source_run_id='12f143a7fda1461e9240d7ffad4ea5bd', status=<LoggedModelStatus.READY: 'READY'>, status_message='', tags={'mlflow.source.git.commit': '7324c807f07a1766d4b951733e3d723504b4576e',
#  'mlflow.source.name': 'a.py',
#  'mlflow.source.type': 'LOCAL',
#  'mlflow.user': 'serena.ruan'}>
# [<Metric: dataset_digest='1f1c13b5', dataset_name='train', key='accuracy', model_id='41bd5a16-25a6-447b-90e0-0f7b7e5cb6cf', run_id='12f143a7fda1461e9240d7ffad4ea5bd', step=100, timestamp=1743734075029, value=0.975>]
# The ranking is sorted by accuracy in descending order, so the last entry
# is the least accurate checkpoint
worst_checkpoint = ranked_checkpoints[-1]
print(f"Worst model: {worst_checkpoint}")
print(worst_checkpoint.metrics)
# Worst model: <LoggedModel: artifact_location='file:///Users/serena.ruan/Documents/repos/mlflow-3-doc/mlruns/0/models/0d789084-9a3b-4b85-9d43-6a148c014b7e/artifacts', creation_timestamp=1743734016290, experiment_id='0', last_updated_timestamp=1743734022728, metrics=[<Metric: dataset_digest='1f1c13b5', dataset_name='train', key='accuracy', model_id='0d789084-9a3b-4b85-9d43-6a148c014b7e', run_id='12f143a7fda1461e9240d7ffad4ea5bd', step=0, timestamp=1743734022737, value=0.3>], model_id='0d789084-9a3b-4b85-9d43-6a148c014b7e', model_type='', model_uri='models:/0d789084-9a3b-4b85-9d43-6a148c014b7e', name='torch-iris-0', params={}, source_run_id='12f143a7fda1461e9240d7ffad4ea5bd', status=<LoggedModelStatus.READY: 'READY'>, status_message='', tags={'mlflow.source.git.commit': '7324c807f07a1766d4b951733e3d723504b4576e',
#  'mlflow.source.name': 'a.py',
#  'mlflow.source.type': 'LOCAL',
#  'mlflow.user': 'serena.ruan'}>
# [<Metric: dataset_digest='1f1c13b5', dataset_name='train', key='accuracy', model_id='0d789084-9a3b-4b85-9d43-6a148c014b7e', run_id='12f143a7fda1461e9240d7ffad4ea5bd', step=0, timestamp=1743734022737, value=0.3>]
Artifacts of the model can be viewed on the Artifacts tab of the model page:
