The mlflow.tensorflow module provides an API for logging and loading TensorFlow models as mlflow.pyfunc models.

You must first export your model in the TensorFlow SavedModel format yourself, then pass the export directory to log_saved_model(saved_model_dir). To load the model and predict with it, call model = pyfunc.load_pyfunc(saved_model_dir), then call prediction = model.predict(df), where df is a pandas DataFrame; the prediction is returned as a pandas DataFrame.
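The load-and-predict steps can be sketched as follows. This is a minimal sketch, assuming the MLflow release this page documents (where mlflow.pyfunc.load_pyfunc is available); the helper name predict_from_saved_model is hypothetical and not part of MLflow.

```python
def predict_from_saved_model(saved_model_dir, input_df):
    """Load a TensorFlow model logged via log_saved_model and score a DataFrame.

    Hypothetical helper; assumes an MLflow release where
    mlflow.pyfunc.load_pyfunc is available, as described on this page.
    """
    # Imported inside the function so the helper can be defined without MLflow installed.
    from mlflow import pyfunc

    # Load the model as a generic pyfunc model (no training APIs are exposed).
    model = pyfunc.load_pyfunc(saved_model_dir)
    # predict() takes a pandas DataFrame and returns predictions as a pandas DataFrame.
    return model.predict(input_df)
```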

The loaded mlflow.pyfunc model does not expose any APIs for model training.

mlflow.tensorflow.log_saved_model(saved_model_dir, signature_def_key, artifact_path)

Log a TensorFlow model as an MLflow artifact for the current run.

  • saved_model_dir – Directory where the TensorFlow model is saved.
  • signature_def_key – The key of the signature definition (SignatureDef) to use when loading the model again. See SignatureDefs in SavedModel for TensorFlow Serving for details.
  • artifact_path – Path (within the artifact directory for the current run) to which artifacts of the model are saved.
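A minimal sketch of calling log_saved_model inside an explicit run, assuming the MLflow release this page documents and a SavedModel directory you have already exported. The helper log_tf_saved_model and its default artifact path "model" are hypothetical; "serving_default" is TensorFlow's default serving signature key, and whether your export uses it is an assumption.

```python
def log_tf_saved_model(saved_model_dir, artifact_path="model",
                       signature_def_key="serving_default"):
    """Log an exported SavedModel directory as an MLflow model in a new run.

    Hypothetical helper; assumes the log_saved_model signature shown on this page.
    """
    # Imported inside the function so the helper can be defined without MLflow installed.
    import mlflow
    import mlflow.tensorflow

    # Start a run explicitly so the model is logged to this run's artifact directory.
    with mlflow.start_run():
        # "serving_default" and "model" are assumed defaults; adjust to your export.
        mlflow.tensorflow.log_saved_model(
            saved_model_dir=saved_model_dir,
            signature_def_key=signature_def_key,
            artifact_path=artifact_path,
        )
```

Wrapping the call in with mlflow.start_run() makes the run boundary explicit instead of relying on whichever run happens to be active.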