import pathlib
from enum import Enum
import json
import logging
import os
from pathlib import Path
from pydantic import validator, root_validator, parse_obj_as, ValidationError
from pydantic.json import pydantic_encoder
from typing import Optional, Union, List, Dict, Any
import yaml
from mlflow.exceptions import MlflowException
from mlflow.gateway.base_models import ConfigModel, ResponseModel
from mlflow.gateway.utils import is_valid_endpoint_name, check_configuration_route_name_collisions
from mlflow.gateway.constants import MLFLOW_GATEWAY_ROUTE_BASE, MLFLOW_QUERY_SUFFIX
_logger = logging.getLogger(__name__)
[docs]class Provider(str, Enum):
OPENAI = "openai"
ANTHROPIC = "anthropic"
# Note: Databricks Model Serving is only supported on Databricks
DATABRICKS_MODEL_SERVING = "databricks-model-serving"
COHERE = "cohere"
[docs] @classmethod
def values(cls):
return {p.value for p in cls}
[docs]class RouteType(str, Enum):
LLM_V1_COMPLETIONS = "llm/v1/completions"
LLM_V1_CHAT = "llm/v1/chat"
LLM_V1_EMBEDDINGS = "llm/v1/embeddings"
[docs]class CohereConfig(ConfigModel):
cohere_api_key: str
# pylint: disable=no-self-argument
[docs] @validator("cohere_api_key", pre=True)
def validate_cohere_api_key(cls, value):
return _resolve_api_key_from_input(value)
[docs]class OpenAIAPIType(str, Enum):
OPENAI = "openai"
AZURE = "azure"
AZUREAD = "azuread"
@classmethod
def _missing_(cls, value):
"""
Implements case-insensitive matching of API type strings
"""
for api_type in cls:
if api_type.value == value.lower():
return api_type
raise MlflowException.invalid_parameter_value(f"Invalid OpenAI API type '{value}'")
[docs]class OpenAIConfig(ConfigModel):
openai_api_key: str
openai_api_type: OpenAIAPIType = OpenAIAPIType.OPENAI
openai_api_base: Optional[str] = None
openai_api_version: Optional[str] = None
openai_deployment_name: Optional[str] = None
openai_organization: Optional[str] = None
# pylint: disable=no-self-argument
[docs] @validator("openai_api_key", pre=True)
def validate_openai_api_key(cls, value):
return _resolve_api_key_from_input(value)
[docs] @root_validator(pre=False)
def validate_field_compatibility(cls, config: Dict[str, Any]):
api_type = config.get("openai_api_type")
if api_type == OpenAIAPIType.OPENAI:
if config.get("openai_deployment_name") is not None:
raise MlflowException.invalid_parameter_value(
f"OpenAI route configuration can only specify a value for "
f"'openai_deployment_name' if 'openai_api_type' is '{OpenAIAPIType.AZURE}' "
f"or '{OpenAIAPIType.AZUREAD}'. Found type: '{api_type}'"
)
if config.get("openai_api_base") is None:
config["openai_api_base"] = "https://api.openai.com/v1"
elif api_type in (OpenAIAPIType.AZURE, OpenAIAPIType.AZUREAD):
if config.get("openai_organization") is not None:
raise MlflowException.invalid_parameter_value(
f"OpenAI route configuration can only specify a value for "
f"'openai_organization' if 'openai_api_type' is '{OpenAIAPIType.OPENAI}'"
)
base_url = config.get("openai_api_base")
deployment_name = config.get("openai_deployment_name")
api_version = config.get("openai_api_version")
if (base_url, deployment_name, api_version).count(None) > 0:
raise MlflowException.invalid_parameter_value(
f"OpenAI route configuration must specify 'openai_api_base', "
f"'openai_deployment_name', and 'openai_api_version' if 'openai_api_type' is "
f"'{OpenAIAPIType.AZURE}' or '{OpenAIAPIType.AZUREAD}'."
)
else:
raise MlflowException.invalid_parameter_value(f"Invalid OpenAI API type '{api_type}'")
return config
[docs]class AnthropicConfig(ConfigModel):
anthropic_api_key: str
# pylint: disable=no-self-argument
[docs] @validator("anthropic_api_key", pre=True)
def validate_anthropic_api_key(cls, value):
return _resolve_api_key_from_input(value)
config_types = {
Provider.COHERE: CohereConfig,
Provider.OPENAI: OpenAIConfig,
Provider.ANTHROPIC: AnthropicConfig,
}
[docs]class ModelInfo(ResponseModel):
name: Optional[str] = None
provider: Provider
def _resolve_api_key_from_input(api_key_input):
"""
Resolves the provided API key.
Input formats accepted:
- Path to a file as a string which will have the key loaded from it
- environment variable name that stores the api key
- the api key itself
"""
if not isinstance(api_key_input, str):
raise MlflowException.invalid_parameter_value(
"The api key provided is not a string. Please provide either an environment "
"variable key, a path to a file containing the api key, or the api key itself"
)
# try reading as an environment variable
if api_key_input.startswith("$"):
env_var_name = api_key_input[1:]
if env_var := os.getenv(env_var_name):
return env_var
else:
raise MlflowException.invalid_parameter_value(
f"Environment variable {env_var_name!r} is not set"
)
# try reading from a local path
file = pathlib.Path(api_key_input)
if file.is_file():
return file.read_text()
# if the key itself is passed, return
return api_key_input
# pylint: disable=no-self-argument
[docs]class Model(ConfigModel):
name: Optional[str] = None
provider: Union[str, Provider]
config: Optional[
Union[
CohereConfig,
OpenAIConfig,
AnthropicConfig,
]
] = None
[docs] @validator("provider", pre=True)
def validate_provider(cls, value):
if isinstance(value, Provider):
return value
if value.upper() in Provider.__members__:
return Provider[value.upper()]
raise MlflowException.invalid_parameter_value(f"The provider '{value}' is not supported.")
[docs] @validator("config", pre=True)
def validate_config(cls, config, values):
if provider := values.get("provider"):
config_type = config_types[provider]
return config_type(**config)
raise MlflowException.invalid_parameter_value(
"A provider must be provided for each gateway route."
)
# pylint: disable=no-self-argument
[docs]class RouteConfig(ConfigModel):
name: str
route_type: RouteType
model: Model
[docs] @validator("name")
def validate_endpoint_name(cls, route_name):
if not is_valid_endpoint_name(route_name):
raise MlflowException.invalid_parameter_value(
"The route name provided contains disallowed characters for a url endpoint. "
f"'{route_name}' is invalid. Names cannot contain spaces or any non "
"alphanumeric characters other than hyphen and underscore."
)
return route_name
[docs] @validator("model", pre=True)
def validate_model(cls, model):
if model:
model_instance = Model(**model)
if model_instance.provider in Provider.values() and model_instance.config is None:
raise MlflowException.invalid_parameter_value(
"A config must be supplied when setting a provider. The provider entry for "
f"{model_instance.provider} is incorrect."
)
return model
[docs] @validator("route_type", pre=True)
def validate_route_type(cls, value):
if value in RouteType._value2member_map_:
return value
raise MlflowException.invalid_parameter_value(f"The route_type '{value}' is not supported.")
[docs] def to_route(self) -> "Route":
return Route(
name=self.name,
route_type=self.route_type,
model=ModelInfo(
name=self.model.name,
provider=self.model.provider,
),
route_url=f"{MLFLOW_GATEWAY_ROUTE_BASE}{self.name}{MLFLOW_QUERY_SUFFIX}",
)
[docs]class Route(ResponseModel):
name: str
route_type: RouteType
model: ModelInfo
route_url: str
[docs] class Config:
schema_extra = {
"example": {
"name": "openai-completions",
"route_type": "llm/v1/completions",
"model": {
"name": "gpt-3.5-turbo",
"provider": "openai",
},
"route_url": "/gateway/routes/completions/invocations",
}
}
[docs]class GatewayConfig(ConfigModel):
routes: List[RouteConfig]
def _load_route_config(path: Union[str, Path]) -> GatewayConfig:
"""
Reads the gateway configuration yaml file from the storage location and returns an instance
of the configuration RouteConfig class
"""
if isinstance(path, str):
path = Path(path)
try:
configuration = yaml.safe_load(path.read_text())
except Exception as e:
raise MlflowException.invalid_parameter_value(
f"The file at {path} is not a valid yaml file"
) from e
check_configuration_route_name_collisions(configuration)
try:
return parse_obj_as(GatewayConfig, configuration)
except ValidationError as e:
raise MlflowException.invalid_parameter_value(
f"The gateway configuration is invalid: {e}"
) from e
def _save_route_config(config: GatewayConfig, path: Union[str, Path]) -> None:
if isinstance(path, str):
path = Path(path)
path.write_text(yaml.safe_dump(json.loads(json.dumps(config.dict(), default=pydantic_encoder))))
def _validate_config(config_path: str) -> GatewayConfig:
if not os.path.exists(config_path):
raise MlflowException.invalid_parameter_value(f"{config_path} does not exist")
try:
return _load_route_config(config_path)
except ValidationError as e:
raise MlflowException.invalid_parameter_value(f"Invalid gateway configuration: {e}") from e