Source code for mlflow.entities.gateway_guardrail

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum

from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.entities.scorer import ScorerVersion
from mlflow.protos.service_pb2 import GatewayGuardrail as ProtoGatewayGuardrail
from mlflow.protos.service_pb2 import GatewayGuardrailConfig as ProtoGatewayGuardrailConfig
from mlflow.protos.service_pb2 import GuardrailAction as ProtoGuardrailAction
from mlflow.protos.service_pb2 import GuardrailStage as ProtoGuardrailStage
from mlflow.utils.workspace_utils import resolve_entity_workspace_name


[docs]class GuardrailStage(str, Enum): BEFORE = "BEFORE" AFTER = "AFTER" def __str__(self) -> str: return self.value
[docs] @classmethod def from_proto(cls, proto: ProtoGuardrailStage) -> GuardrailStage: return cls(ProtoGuardrailStage.Name(proto))
[docs] def to_proto(self) -> ProtoGuardrailStage: return ProtoGuardrailStage.Value(self.value)
[docs]class GuardrailAction(str, Enum): VALIDATION = "VALIDATION" SANITIZATION = "SANITIZATION" def __str__(self) -> str: return self.value
[docs] @classmethod def from_proto(cls, proto: ProtoGuardrailAction) -> GuardrailAction: return cls(ProtoGuardrailAction.Name(proto))
[docs] def to_proto(self) -> ProtoGuardrailAction: return ProtoGuardrailAction.Value(self.value)
[docs]@dataclass class GatewayGuardrail(_MlflowObject): guardrail_id: str name: str scorer: ScorerVersion stage: GuardrailStage action: GuardrailAction created_at: int last_updated_at: int action_endpoint_name: str | None = None created_by: str | None = None last_updated_by: str | None = None workspace: str | None = None def __post_init__(self): self.workspace = resolve_entity_workspace_name(self.workspace) if isinstance(self.stage, str): self.stage = GuardrailStage(self.stage) if isinstance(self.action, str): self.action = GuardrailAction(self.action)
[docs] def to_proto(self): proto = ProtoGatewayGuardrail() proto.guardrail_id = self.guardrail_id proto.name = self.name proto.scorer.CopyFrom(self.scorer.to_proto()) proto.stage = self.stage.to_proto() proto.action = self.action.to_proto() if self.action_endpoint_name: proto.action_endpoint_id = self.action_endpoint_name proto.created_by = self.created_by or "" proto.created_at = self.created_at proto.last_updated_by = self.last_updated_by or "" proto.last_updated_at = self.last_updated_at return proto
[docs] @classmethod def from_proto(cls, proto): return cls( guardrail_id=proto.guardrail_id, name=proto.name, scorer=ScorerVersion.from_proto(proto.scorer), stage=GuardrailStage.from_proto(proto.stage), action=GuardrailAction.from_proto(proto.action), action_endpoint_name=proto.action_endpoint_id or None, created_by=proto.created_by or None, created_at=proto.created_at, last_updated_by=proto.last_updated_by or None, last_updated_at=proto.last_updated_at, )
[docs]@dataclass class GatewayGuardrailConfig(_MlflowObject): """Junction between a guardrail and a gateway endpoint, with ordering.""" endpoint_id: str guardrail_id: str execution_order: int | None created_at: int guardrail: GatewayGuardrail | None = None created_by: str | None = None workspace: str | None = None
[docs] def to_proto(self): proto = ProtoGatewayGuardrailConfig() proto.endpoint_id = self.endpoint_id proto.guardrail_id = self.guardrail_id if self.execution_order is not None: proto.execution_order = self.execution_order if self.guardrail is not None: proto.guardrail.CopyFrom(self.guardrail.to_proto()) proto.created_by = self.created_by or "" proto.created_at = self.created_at return proto
[docs] @classmethod def from_proto(cls, proto): guardrail = None if proto.HasField("guardrail"): guardrail = GatewayGuardrail.from_proto(proto.guardrail) return cls( endpoint_id=proto.endpoint_id, guardrail_id=proto.guardrail_id, execution_order=proto.execution_order if proto.HasField("execution_order") else None, guardrail=guardrail, created_at=proto.created_at, created_by=proto.created_by or None, )