from tensorflow import keras
from tensorflow.keras.callbacks import Callback
from mlflow import log_metrics, log_params, log_text
from mlflow.utils.autologging_utils import ExceptionSafeClass
from mlflow.utils.checkpoint_utils import MlflowModelCheckpointCallbackBase
class MlflowCallback(keras.callbacks.Callback, metaclass=ExceptionSafeClass):
    """Callback for logging Tensorflow training metrics to MLflow.

    This callback logs model information at training start, and logs training metrics every epoch
    or every n steps (defined by the user) to MLflow. Metrics reported at the end of an evaluation
    pass are logged with a ``validation_`` prefix.

    Args:
        log_every_epoch: bool, If True, log metrics every epoch. If False, log metrics every n
            steps.
        log_every_n_steps: int, log metrics every n steps, where the step is the optimizer's
            iteration count. Must be `None` if `log_every_epoch=True`, must be set if
            `log_every_epoch=False`, and must be a positive integer when set.

    Raises:
        ValueError: If `log_every_epoch` and `log_every_n_steps` are an invalid combination, or
            if `log_every_n_steps` is not positive.

    .. code-block:: python
        :caption: Example

        from tensorflow import keras
        import tensorflow as tf
        import mlflow
        import numpy as np

        # Prepare data for a 2-class classification.
        data = tf.random.uniform([8, 28, 28, 3])
        label = tf.convert_to_tensor(np.random.randint(2, size=8))

        model = keras.Sequential(
            [
                keras.Input([28, 28, 3]),
                keras.layers.Flatten(),
                keras.layers.Dense(2),
            ]
        )

        model.compile(
            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            optimizer=keras.optimizers.Adam(0.001),
            metrics=[keras.metrics.SparseCategoricalAccuracy()],
        )

        with mlflow.start_run():
            model.fit(
                data,
                label,
                batch_size=4,
                epochs=2,
                callbacks=[mlflow.tensorflow.MlflowCallback()],
            )
    """

    def __init__(self, log_every_epoch=True, log_every_n_steps=None):
        # Run the base Keras `Callback` initializer so whatever state it sets up exists on this
        # instance before our own attributes are added.
        super().__init__()
        self.log_every_epoch = log_every_epoch
        self.log_every_n_steps = log_every_n_steps

        if log_every_epoch and log_every_n_steps is not None:
            raise ValueError(
                "`log_every_n_steps` must be None if `log_every_epoch=True`, received "
                f"`log_every_epoch={log_every_epoch}` and `log_every_n_steps={log_every_n_steps}`."
            )

        if not log_every_epoch and log_every_n_steps is None:
            raise ValueError(
                "`log_every_n_steps` must be specified if `log_every_epoch=False`, received "
                f"`log_every_epoch={log_every_epoch}` and `log_every_n_steps={log_every_n_steps}`."
            )

        # `on_batch_end` computes `iteration % log_every_n_steps`, so 0 would divide by zero and a
        # negative value is meaningless as a logging frequency.
        if log_every_n_steps is not None and log_every_n_steps <= 0:
            raise ValueError(
                "`log_every_n_steps` must be a positive integer, received "
                f"`log_every_n_steps={log_every_n_steps}`."
            )

    def on_train_begin(self, logs=None):
        """Log model architecture and optimizer configuration when training begins.

        Each optimizer config entry is logged as an MLflow param named ``opt_<key>``, and the
        text produced by ``model.summary()`` is logged as the artifact ``model_summary.txt``.
        """
        config = self.model.optimizer.get_config()
        log_params({f"opt_{k}": v for k, v in config.items()})

        model_summary = []

        def print_fn(line, *args, **kwargs):
            # Accept and ignore any extra arguments `summary()` may forward; keep only the text.
            model_summary.append(line)

        self.model.summary(print_fn=print_fn)
        summary = "\n".join(model_summary)
        log_text(summary, artifact_file="model_summary.txt")

    def on_epoch_end(self, epoch, logs=None):
        """Log metrics at the end of each epoch, using the epoch index as the MLflow step.

        Does nothing when per-epoch logging is disabled or Keras supplies no logs.
        """
        if not self.log_every_epoch or logs is None:
            return
        log_metrics(logs, step=epoch, synchronous=False)

    def on_batch_end(self, batch, logs=None):
        """Log metrics at the end of each batch with user specified frequency.

        The MLflow step is the optimizer's iteration count, and metrics are logged only when
        that count is a multiple of `log_every_n_steps`.
        """
        if self.log_every_n_steps is None or logs is None:
            return
        current_iteration = int(self.model.optimizer.iterations.numpy())
        if current_iteration % self.log_every_n_steps == 0:
            log_metrics(logs, step=current_iteration, synchronous=False)

    def on_test_end(self, logs=None):
        """Log validation metrics at validation end.

        Every metric key is prefixed with ``validation_``. No explicit step is passed to
        `log_metrics`.
        """
        if logs is None:
            return
        metrics = {"validation_" + k: v for k, v in logs.items()}
        log_metrics(metrics, synchronous=False)
class MlflowModelCheckpointCallback(Callback, MlflowModelCheckpointCallbackBase):
    """Callback for automatic Keras model checkpointing to MLflow.

    Args:
        monitor: In automatic model checkpointing, the metric name to monitor if
            you set `save_best_only` to True.
        save_best_only: If True, automatic model checkpointing only saves when
            the model is considered the "best" model according to the quantity
            monitored and previous checkpoint model is overwritten.
        mode: one of {"min", "max"}. In automatic model checkpointing,
            if save_best_only=True, the decision to overwrite the current save file is made
            based on either the maximization or the minimization of the monitored quantity.
        save_weights_only: In automatic model checkpointing, if True, then
            only the model's weights will be saved (via ``model.save_weights``). Otherwise,
            the whole model is saved (via ``model.save``).
        save_freq: `"epoch"` or integer. When using `"epoch"`, the callback
            saves the model after each epoch. When using integer, the callback
            saves the model at end of this many batches. Note that if the saving isn't
            aligned to epochs, the monitored metric may potentially be less reliable (it
            could reflect as little as 1 batch, since the metrics get reset
            every epoch). Defaults to `"epoch"`.

    .. code-block:: python
        :caption: Example

        from tensorflow import keras
        import tensorflow as tf
        import mlflow
        import numpy as np
        from mlflow.tensorflow import MlflowModelCheckpointCallback

        # Prepare data for a 2-class classification.
        data = tf.random.uniform([8, 28, 28, 3])
        label = tf.convert_to_tensor(np.random.randint(2, size=8))

        model = keras.Sequential(
            [
                keras.Input([28, 28, 3]),
                keras.layers.Flatten(),
                keras.layers.Dense(2),
            ]
        )

        model.compile(
            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            optimizer=keras.optimizers.Adam(0.001),
            metrics=[keras.metrics.SparseCategoricalAccuracy()],
        )

        mlflow_checkpoint_callback = MlflowModelCheckpointCallback(
            monitor="sparse_categorical_accuracy",
            mode="max",
            save_best_only=True,
            save_weights_only=False,
            save_freq="epoch",
        )

        with mlflow.start_run() as run:
            model.fit(
                data,
                label,
                batch_size=4,
                epochs=2,
                callbacks=[mlflow_checkpoint_callback],
            )
    """

    def __init__(
        self,
        monitor="val_loss",
        mode="min",
        save_best_only=True,
        save_weights_only=False,
        save_freq="epoch",
    ):
        Callback.__init__(self)
        MlflowModelCheckpointCallbackBase.__init__(
            self,
            # NOTE(review): recent Keras versions require `save_weights` file paths to end in
            # `.weights.h5`; confirm this suffix works with the installed Keras when
            # `save_weights_only=True`.
            checkpoint_file_suffix=".h5",
            monitor=monitor,
            mode=mode,
            save_best_only=save_best_only,
            save_weights_only=save_weights_only,
            save_freq=save_freq,
        )
        # Not read anywhere in this class; presumably kept for interface parity with other
        # checkpoint callbacks -- verify before removing.
        self.trainer = None
        # Epoch index reported by the most recent `on_epoch_begin` call.
        self.current_epoch = None
        # Index of the last batch reported to `on_train_batch_end`. It starts at -1 ("no batch
        # seen yet") so the first call counts `batch + 1` batches even when
        # `steps_per_execution > 1` makes the first reported index greater than 0.
        self._last_batch_seen = -1
        # Total number of training batches processed so far.
        self.global_step = 0
        # Value of `global_step` at the most recent batch-frequency checkpoint.
        self.global_step_last_saving = 0

    def save_checkpoint(self, filepath: str):
        """Save the model to ``filepath``, overwriting any existing file.

        Only the weights are written when `save_weights_only` is True; otherwise the whole
        model is saved.
        """
        if self.save_weights_only:
            self.model.save_weights(filepath, overwrite=True)
        else:
            self.model.save(filepath, overwrite=True)

    def on_epoch_begin(self, epoch, logs=None):
        """Record the current epoch index for use when checkpointing."""
        self.current_epoch = epoch

    def on_train_batch_end(self, batch, logs=None):
        """Advance the global step and checkpoint every `save_freq` batches (integer mode)."""
        # Note that `on_train_batch_end` might be invoked by every N train steps,
        # (controlled by `steps_per_execution` argument in `model.compile` method).
        # the following logic is similar to
        # https://github.com/keras-team/keras/blob/e6e62405fa1b4444102601636d871610d91e5783/keras/callbacks/model_checkpoint.py#L212
        if batch <= self._last_batch_seen:
            # The batch index did not increase, so a new epoch started; `batch` is zero-based.
            add_batches = batch + 1
        else:
            add_batches = batch - self._last_batch_seen
        self._last_batch_seen = batch
        self.global_step += add_batches

        if not isinstance(self.save_freq, int):
            return
        if self.global_step - self.global_step_last_saving >= self.save_freq:
            self.check_and_save_checkpoint_if_needed(
                current_epoch=self.current_epoch,
                global_step=self.global_step,
                # `logs` defaults to None; treat that as "no metrics" instead of crashing.
                metric_dict={k: float(v) for k, v in (logs or {}).items()},
            )
            self.global_step_last_saving = self.global_step

    def on_epoch_end(self, epoch, logs=None):
        """Run the checkpoint check at the end of every epoch when ``save_freq="epoch"``."""
        if self.save_freq != "epoch":
            return
        self.check_and_save_checkpoint_if_needed(
            current_epoch=self.current_epoch,
            global_step=self.global_step,
            # `logs` defaults to None; treat that as "no metrics" instead of crashing.
            metric_dict={k: float(v) for k, v in (logs or {}).items()},
        )