from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.protos.service_pb2 import FallbackConfig as ProtoFallbackConfig
from mlflow.protos.service_pb2 import FallbackStrategy as ProtoFallbackStrategy
from mlflow.protos.service_pb2 import (
GatewayEndpoint as ProtoGatewayEndpoint,
)
from mlflow.protos.service_pb2 import (
GatewayEndpointBinding as ProtoGatewayEndpointBinding,
)
from mlflow.protos.service_pb2 import (
GatewayEndpointModelConfig as ProtoGatewayEndpointModelConfig,
)
from mlflow.protos.service_pb2 import (
GatewayEndpointModelMapping as ProtoGatewayEndpointModelMapping,
)
from mlflow.protos.service_pb2 import (
GatewayModelDefinition as ProtoGatewayModelDefinition,
)
from mlflow.protos.service_pb2 import GatewayModelLinkageType as ProtoGatewayModelLinkageType
from mlflow.protos.service_pb2 import RoutingStrategy as ProtoRoutingStrategy
class GatewayResourceType(str, Enum):
    """MLflow resource types that are allowed to use gateway endpoints.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    SCORER = "scorer"
class RoutingStrategy(str, Enum):
    """Routing strategy for gateway endpoints.

    Each member's value is the name of the corresponding
    ``ProtoRoutingStrategy`` enum entry, which is what the conversion
    helpers below rely on.
    """

    REQUEST_BASED_TRAFFIC_SPLIT = "REQUEST_BASED_TRAFFIC_SPLIT"

    @classmethod
    def from_proto(cls, proto: ProtoRoutingStrategy) -> RoutingStrategy | None:
        """Convert a proto enum value into a ``RoutingStrategy``.

        Args:
            proto: Proto enum value to convert.

        Returns:
            The member whose value equals the proto enum name, or ``None``
            when the lookup raises ``ValueError`` (e.g. the proto value is
            the unspecified one and has no member here).
        """
        try:
            return cls(ProtoRoutingStrategy.Name(proto))
        except ValueError:
            # unspecified in proto is treated as None
            return None

    def to_proto(self) -> ProtoRoutingStrategy:
        """Return the proto enum value whose name equals this member's value."""
        return ProtoRoutingStrategy.Value(self.value)
class FallbackStrategy(str, Enum):
    """Fallback strategy for routing.

    Each member's value is the name of the corresponding
    ``ProtoFallbackStrategy`` enum entry, which is what the conversion
    helpers below rely on.
    """

    SEQUENTIAL = "SEQUENTIAL"

    @classmethod
    def from_proto(cls, proto: ProtoFallbackStrategy) -> FallbackStrategy | None:
        """Convert a proto enum value into a ``FallbackStrategy``.

        Args:
            proto: Proto enum value to convert.

        Returns:
            The member whose value equals the proto enum name, or ``None``
            when the lookup raises ``ValueError`` (e.g. the proto value is
            the unspecified one and has no member here).
        """
        try:
            return cls(ProtoFallbackStrategy.Name(proto))
        except ValueError:
            # unspecified in proto is treated as None
            return None

    def to_proto(self) -> ProtoFallbackStrategy:
        """Return the proto enum value whose name equals this member's value."""
        return ProtoFallbackStrategy.Value(self.value)
class GatewayModelLinkageType(str, Enum):
    """Type of linkage between endpoint and model definition.

    Each member's value is the name of the corresponding
    ``ProtoGatewayModelLinkageType`` enum entry, which is what the
    conversion helpers below rely on.
    """

    PRIMARY = "PRIMARY"
    FALLBACK = "FALLBACK"

    @classmethod
    def from_proto(cls, proto: ProtoGatewayModelLinkageType) -> GatewayModelLinkageType | None:
        """Convert a proto enum value into a ``GatewayModelLinkageType``.

        Args:
            proto: Proto enum value to convert.

        Returns:
            The member whose value equals the proto enum name, or ``None``
            when the lookup raises ``ValueError`` (e.g. the proto value is
            the unspecified one and has no member here).
        """
        try:
            return cls(ProtoGatewayModelLinkageType.Name(proto))
        except ValueError:
            # unspecified in proto is treated as None
            return None

    def to_proto(self) -> ProtoGatewayModelLinkageType:
        """Return the proto enum value whose name equals this member's value."""
        return ProtoGatewayModelLinkageType.Value(self.value)
@dataclass
class FallbackConfig(_MlflowObject):
    """
    Configuration for fallback routing strategy.

    Defines how requests should be routed across multiple models when using
    fallback routing. Fallback models are defined via GatewayEndpointModelMapping
    with linkage_type=FALLBACK and ordered by fallback_order.

    Args:
        strategy: The fallback strategy to use (e.g., FallbackStrategy.SEQUENTIAL).
        max_attempts: Maximum number of fallback models to try (None = try all).
    """

    strategy: FallbackStrategy | None = None
    max_attempts: int | None = None

    def to_proto(self) -> ProtoFallbackConfig:
        """Build a ``ProtoFallbackConfig``; fields that are ``None`` stay unset."""
        proto = ProtoFallbackConfig()
        if self.strategy is not None:
            proto.strategy = self.strategy.to_proto()
        if self.max_attempts is not None:
            proto.max_attempts = self.max_attempts
        return proto

    @classmethod
    def from_proto(cls, proto: ProtoFallbackConfig) -> "FallbackConfig":
        """Build a ``FallbackConfig`` from its proto message.

        Fields that are not set on the message are returned as ``None`` so
        that a ``to_proto`` / ``from_proto`` round trip preserves them.
        """
        strategy = (
            FallbackStrategy.from_proto(proto.strategy) if proto.HasField("strategy") else None
        )
        # Check presence explicitly: reading an unset scalar yields the proto
        # default value, which would be indistinguishable from a real limit
        # and would lose the "None = try all" meaning.
        max_attempts = proto.max_attempts if proto.HasField("max_attempts") else None
        return cls(
            strategy=strategy,
            max_attempts=max_attempts,
        )
@dataclass
class GatewayEndpointModelConfig(_MlflowObject):
    """
    Settings used to attach one model definition to an endpoint.

    Bundles the model definition ID with its linkage type, routing weight
    and fallback order.

    Args:
        model_definition_id: ID of the model definition to attach.
        linkage_type: Type of linkage (PRIMARY or FALLBACK).
        weight: Routing weight for traffic distribution (default 1.0).
        fallback_order: Order for fallback attempts (only for FALLBACK linkages,
            None for PRIMARY).
    """

    model_definition_id: str
    linkage_type: GatewayModelLinkageType
    weight: float = 1.0
    fallback_order: int | None = None

    def to_proto(self) -> ProtoGatewayEndpointModelConfig:
        """Serialize to a proto message; ``fallback_order`` is left unset when ``None``."""
        message = ProtoGatewayEndpointModelConfig()
        message.model_definition_id = self.model_definition_id
        message.linkage_type = self.linkage_type.to_proto()
        message.weight = self.weight
        order = self.fallback_order
        if order is not None:
            message.fallback_order = order
        return message

    @classmethod
    def from_proto(cls, proto: ProtoGatewayEndpointModelConfig) -> "GatewayEndpointModelConfig":
        """Deserialize from a proto message.

        An unset ``weight`` falls back to 1.0 and an unset ``fallback_order``
        becomes ``None``.
        """
        linkage = GatewayModelLinkageType.from_proto(proto.linkage_type)

        if proto.HasField("weight"):
            weight = proto.weight
        else:
            weight = 1.0

        if proto.HasField("fallback_order"):
            order = proto.fallback_order
        else:
            order = None

        return cls(
            model_definition_id=proto.model_definition_id,
            linkage_type=linkage,
            weight=weight,
            fallback_order=order,
        )
@dataclass
class GatewayModelDefinition(_MlflowObject):
    """
    A reusable LLM model configuration.

    A model definition can be shared by several endpoints, which keeps model
    configuration and API credentials managed in one place.

    Args:
        model_definition_id: Unique identifier for this model definition.
        name: User-friendly name for identification and reuse.
        secret_id: ID of the secret containing authentication credentials (None if orphaned).
        secret_name: Name of the secret for display/reference purposes (None if orphaned).
        provider: LLM provider (e.g., "openai", "anthropic", "cohere", "bedrock").
        model_name: Provider-specific model identifier (e.g., "gpt-4o", "claude-3-5-sonnet").
        created_at: Timestamp (milliseconds) when the model definition was created.
        last_updated_at: Timestamp (milliseconds) when the model definition was last updated.
        created_by: User ID who created the model definition.
        last_updated_by: User ID who last updated the model definition.
    """

    model_definition_id: str
    name: str
    secret_id: str | None
    secret_name: str | None
    provider: str
    model_name: str
    created_at: int
    last_updated_at: int
    created_by: str | None = None
    last_updated_by: str | None = None

    def to_proto(self):
        """Serialize to ``ProtoGatewayModelDefinition``; ``None`` fields are left unset."""
        message = ProtoGatewayModelDefinition()
        message.model_definition_id = self.model_definition_id
        message.name = self.name

        # Secret references are optional: copy only the ones that are present.
        for attr in ("secret_id", "secret_name"):
            current = getattr(self, attr)
            if current is not None:
                setattr(message, attr, current)

        message.provider = self.provider
        message.model_name = self.model_name
        message.created_at = self.created_at
        message.last_updated_at = self.last_updated_at

        # Audit user fields are optional as well.
        for attr in ("created_by", "last_updated_by"):
            current = getattr(self, attr)
            if current is not None:
                setattr(message, attr, current)

        return message

    @classmethod
    def from_proto(cls, proto):
        """Deserialize from a proto message, mapping falsy optional strings to ``None``."""
        secret_id = proto.secret_id or None
        secret_name = proto.secret_name or None
        creator = proto.created_by or None
        updater = proto.last_updated_by or None
        return cls(
            model_definition_id=proto.model_definition_id,
            name=proto.name,
            secret_id=secret_id,
            secret_name=secret_name,
            provider=proto.provider,
            model_name=proto.model_name,
            created_at=proto.created_at,
            last_updated_at=proto.last_updated_at,
            created_by=creator,
            last_updated_by=updater,
        )
@dataclass
class GatewayEndpointModelMapping(_MlflowObject):
    """
    Link between an endpoint and a model definition.

    A junction entity connecting endpoints to model definitions; it also
    carries the routing settings (weight, linkage type, fallback order) for
    that link.

    Args:
        mapping_id: Unique identifier for this mapping.
        endpoint_id: ID of the endpoint.
        model_definition_id: ID of the model definition.
        model_definition: The full model definition, or None when not populated.
        weight: Routing weight for traffic distribution.
        linkage_type: Type of linkage (PRIMARY or FALLBACK).
        fallback_order: Zero-indexed order for fallback attempts (only for FALLBACK linkages).
        created_at: Timestamp (milliseconds) when the mapping was created.
        created_by: User ID who created the mapping.
    """

    mapping_id: str
    endpoint_id: str
    model_definition_id: str
    model_definition: GatewayModelDefinition | None
    weight: float
    linkage_type: GatewayModelLinkageType
    fallback_order: int | None
    created_at: int
    created_by: str | None = None

    def to_proto(self):
        """Serialize to ``ProtoGatewayEndpointModelMapping``; ``None`` fields are left unset."""
        message = ProtoGatewayEndpointModelMapping()
        message.mapping_id = self.mapping_id
        message.endpoint_id = self.endpoint_id
        message.model_definition_id = self.model_definition_id

        definition = self.model_definition
        if definition is not None:
            # Nested messages cannot be assigned directly; copy into the sub-message.
            message.model_definition.CopyFrom(definition.to_proto())

        message.weight = self.weight
        message.linkage_type = self.linkage_type.to_proto()

        order = self.fallback_order
        if order is not None:
            message.fallback_order = order

        message.created_at = self.created_at

        creator = self.created_by
        if creator is not None:
            message.created_by = creator

        return message

    @classmethod
    def from_proto(cls, proto):
        """Deserialize from a proto message.

        An absent ``model_definition`` or ``fallback_order`` becomes ``None``,
        and a falsy ``created_by`` is mapped to ``None``.
        """
        if proto.HasField("model_definition"):
            definition = GatewayModelDefinition.from_proto(proto.model_definition)
        else:
            definition = None

        linkage = GatewayModelLinkageType.from_proto(proto.linkage_type)

        if proto.HasField("fallback_order"):
            order = proto.fallback_order
        else:
            order = None

        return cls(
            mapping_id=proto.mapping_id,
            endpoint_id=proto.endpoint_id,
            model_definition_id=proto.model_definition_id,
            model_definition=definition,
            weight=proto.weight,
            linkage_type=linkage,
            fallback_order=order,
            created_at=proto.created_at,
            created_by=proto.created_by or None,
        )
@dataclass
class GatewayEndpointTag(_MlflowObject):
    """
    A key-value tag attached to a gateway endpoint.

    Tags are used for categorization, filtering, and metadata storage for endpoints.

    Args:
        key: Tag key (documented limit: 250 characters; not enforced by this class).
        value: Tag value (documented limit: 5000 characters; not enforced by this
            class). May be None.
    """

    key: str
    value: str | None

    def to_proto(self):
        """Serialize to the ``GatewayEndpointTag`` proto; a ``None`` value is left unset."""
        from mlflow.protos.service_pb2 import GatewayEndpointTag as ProtoGatewayEndpointTag

        message = ProtoGatewayEndpointTag()
        message.key = self.key
        tag_value = self.value
        if tag_value is not None:
            message.value = tag_value
        return message

    @classmethod
    def from_proto(cls, proto):
        """Deserialize from a proto message; a falsy ``value`` is mapped to ``None``."""
        tag_value = proto.value or None
        return cls(key=proto.key, value=tag_value)
@dataclass
class GatewayEndpoint(_MlflowObject):
    """
    Represents an LLM gateway endpoint with its associated model configurations.

    Args:
        endpoint_id: Unique identifier for this endpoint.
        name: User-friendly name for the endpoint (optional).
        created_at: Timestamp (milliseconds) when the endpoint was created.
        last_updated_at: Timestamp (milliseconds) when the endpoint was last updated.
        model_mappings: List of model mappings bound to this endpoint.
        tags: List of tags associated with this endpoint.
        created_by: User ID who created the endpoint.
        last_updated_by: User ID who last updated the endpoint.
        routing_strategy: Routing strategy for the endpoint
            (e.g., RoutingStrategy.REQUEST_BASED_TRAFFIC_SPLIT), or None.
        fallback_config: Fallback configuration entity, or None when the
            endpoint has none.
    """

    endpoint_id: str
    name: str | None
    created_at: int
    last_updated_at: int
    model_mappings: list[GatewayEndpointModelMapping] = field(default_factory=list)
    tags: list["GatewayEndpointTag"] = field(default_factory=list)
    created_by: str | None = None
    last_updated_by: str | None = None
    routing_strategy: RoutingStrategy | None = None
    fallback_config: FallbackConfig | None = None

    def to_proto(self):
        """Serialize to ``ProtoGatewayEndpoint``.

        ``name``, ``created_by`` and ``last_updated_by`` are always written,
        using an empty string when the attribute is falsy. ``routing_strategy``
        and ``fallback_config`` are written only when they are not ``None``.
        """
        proto = ProtoGatewayEndpoint()
        proto.endpoint_id = self.endpoint_id
        proto.name = self.name or ""
        proto.created_at = self.created_at
        proto.last_updated_at = self.last_updated_at
        proto.model_mappings.extend([m.to_proto() for m in self.model_mappings])
        proto.tags.extend([t.to_proto() for t in self.tags])
        proto.created_by = self.created_by or ""
        proto.last_updated_by = self.last_updated_by or ""
        if self.routing_strategy is not None:
            # Delegate to the enum's own converter instead of duplicating it.
            proto.routing_strategy = self.routing_strategy.to_proto()
        if self.fallback_config is not None:
            proto.fallback_config.CopyFrom(self.fallback_config.to_proto())
        return proto

    @classmethod
    def from_proto(cls, proto):
        """Deserialize from a proto message.

        Empty strings on ``name``, ``created_by`` and ``last_updated_by`` are
        mapped to ``None``. A ``routing_strategy`` that is unset, or set to a
        value with no ``RoutingStrategy`` member, becomes ``None``.
        """
        routing_strategy = None
        if proto.HasField("routing_strategy"):
            # RoutingStrategy.from_proto returns None for values it does not
            # know (such as the unspecified one) instead of raising ValueError.
            routing_strategy = RoutingStrategy.from_proto(proto.routing_strategy)

        fallback_config = None
        if proto.HasField("fallback_config"):
            fallback_config = FallbackConfig.from_proto(proto.fallback_config)

        return cls(
            endpoint_id=proto.endpoint_id,
            name=proto.name or None,
            created_at=proto.created_at,
            last_updated_at=proto.last_updated_at,
            model_mappings=[
                GatewayEndpointModelMapping.from_proto(m) for m in proto.model_mappings
            ],
            tags=[GatewayEndpointTag.from_proto(t) for t in proto.tags],
            created_by=proto.created_by or None,
            last_updated_by=proto.last_updated_by or None,
            routing_strategy=routing_strategy,
            fallback_config=fallback_config,
        )
@dataclass
class GatewayEndpointBinding(_MlflowObject):
    """
    Binding between an endpoint and an MLflow resource.

    Bindings record which MLflow resources (e.g., scorer jobs) are configured
    to use which endpoints. Each binding is uniquely identified by the
    composite key (endpoint_id, resource_type, resource_id).

    Args:
        endpoint_id: ID of the endpoint this binding references.
        resource_type: Type of MLflow resource (e.g., "scorer").
        resource_id: ID of the specific resource instance.
        created_at: Timestamp (milliseconds) when the binding was created.
        last_updated_at: Timestamp (milliseconds) when the binding was last updated.
        created_by: User ID who created the binding.
        last_updated_by: User ID who last updated the binding.
        display_name: Human-readable display name for the resource (e.g., scorer name).
    """

    endpoint_id: str
    resource_type: GatewayResourceType
    resource_id: str
    created_at: int
    last_updated_at: int
    created_by: str | None = None
    last_updated_by: str | None = None
    display_name: str | None = None

    def to_proto(self):
        """Serialize to ``ProtoGatewayEndpointBinding``; ``None`` optional fields stay unset."""
        message = ProtoGatewayEndpointBinding()
        message.endpoint_id = self.endpoint_id
        # The proto field receives the enum's string value, not the enum member.
        message.resource_type = self.resource_type.value
        message.resource_id = self.resource_id
        message.created_at = self.created_at
        message.last_updated_at = self.last_updated_at

        for attr in ("created_by", "last_updated_by", "display_name"):
            current = getattr(self, attr)
            if current is not None:
                setattr(message, attr, current)

        return message

    @classmethod
    def from_proto(cls, proto):
        """Deserialize from a proto message.

        ``resource_type`` is converted with ``GatewayResourceType(...)``;
        falsy optional strings are mapped to ``None``.
        """
        kind = GatewayResourceType(proto.resource_type)
        creator = proto.created_by or None
        updater = proto.last_updated_by or None
        label = proto.display_name or None
        return cls(
            endpoint_id=proto.endpoint_id,
            resource_type=kind,
            resource_id=proto.resource_id,
            created_at=proto.created_at,
            last_updated_at=proto.last_updated_at,
            created_by=creator,
            last_updated_by=updater,
            display_name=label,
        )