from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import ALREADY_EXISTS
from mlflow.utils.annotations import experimental
class DependenciesSchemasType(Enum):
    """
    Kinds of dependency schema that can be attached to a model.

    Each member's value is the string key under which schemas of that kind
    are stored and serialized.
    """

    # Schemas describing retrievers / vector search indexes.
    RETRIEVERS = "retrievers"
@experimental
def set_retriever_schema(
    *,
    primary_key: str,
    text_column: str,
    doc_uri: Optional[str] = None,
    other_columns: Optional[List[str]] = None,
    name: Optional[str] = "retriever",
):
    """
    Register the schema of a retriever so MLflow can interpret retrieved documents.

    After defining your vector store in a Python file or notebook, call
    set_retriever_schema() so that, when MLflow retrieves documents during
    model inference, MLflow can interpret the fields in each retrieved document and
    determine which fields correspond to the document text, document URI, etc.

    The schema is appended to a module-level list, so several retrievers can be
    registered as long as each one uses a distinct ``name``.

    Args:
        primary_key: The primary key of the retriever or vector index.
        text_column: The name of the text column to use for the embeddings.
        doc_uri: The name of the column that contains the document URI.
        other_columns: A list of other columns that are part of the vector index
            that need to be retrieved during trace logging. ``None`` is stored
            as an empty list.
        name: The name of the retriever tool or vector store index.

    Raises:
        MlflowException: With error code ``ALREADY_EXISTS`` if a retriever schema
            with the same ``name`` is already registered. All retriever schemas
            registered so far are cleared before the exception is raised.

    .. code-block:: Python
        :caption: Example

        from mlflow.models import set_retriever_schema

        set_retriever_schema(
            primary_key="chunk_id",
            text_column="chunk_text",
            doc_uri="doc_uri",
            other_columns=["title"],
        )
    """
    registry_key = DependenciesSchemasType.RETRIEVERS.value
    retriever_schema = globals().get(registry_key, [])

    # Check if a retriever schema with the same name already exists
    if any(schema["name"] == name for schema in retriever_schema):
        # reset if there is an error to clear the global state for next run
        _clear_retriever_schema()
        raise MlflowException(
            f"A retriever schema with the name '{name}' already exists.",
            error_code=ALREADY_EXISTS,
        )

    retriever_schema.append(
        {
            "primary_key": primary_key,
            "text_column": text_column,
            "doc_uri": doc_uri,
            "other_columns": other_columns or [],
            "name": name,
        }
    )
    # Re-assign so the list is stored even when it was freshly created above.
    globals()[registry_key] = retriever_schema
def _get_retriever_schema():
    """
    Convert the retriever schemas stored in this module's globals into objects.

    Returns:
        A list containing one ``RetrieverSchema`` per stored entry, or an
        empty list when nothing is stored.
    """
    registered = globals().get(DependenciesSchemasType.RETRIEVERS.value, [])
    # A missing or empty registry simply yields an empty result.
    return [
        RetrieverSchema(
            name=entry.get("name"),
            primary_key=entry.get("primary_key"),
            text_column=entry.get("text_column"),
            doc_uri=entry.get("doc_uri"),
            other_columns=entry.get("other_columns"),
        )
        for entry in (registered or ())
    ]
def _clear_retriever_schema():
    """
    Remove the retriever schemas stored in this module's globals, if any.
    """
    registry_key = DependenciesSchemasType.RETRIEVERS.value
    # pop() with a default so that clearing an empty registry is a no-op.
    globals().pop(registry_key, None)
def _clear_dependencies_schemas():
    """
    Reset every kind of dependency schema stored in this module.
    """
    # Retriever schemas are the only kind handled here.
    _clear_retriever_schema()
@contextmanager
def _get_dependencies_schemas():
    """
    Yield the currently stored dependency schemas, then clear them.

    The schemas are read once on entry and wrapped in a ``DependenciesSchemas``
    object. The stored state is cleared on exit, including when the body of
    the ``with`` block raises.
    """
    retrievers = _get_retriever_schema()
    snapshot = DependenciesSchemas(retriever_schemas=retrievers)
    try:
        yield snapshot
    finally:
        _clear_dependencies_schemas()
@dataclass
class Schema(ABC):
    """
    Abstract base class for a dependency schema.

    Args:
        type (DependenciesSchemasType): The kind of dependency schema.
    """

    type: DependenciesSchemasType

    @abstractmethod
    def to_dict(self):
        """
        Serialize this schema to a dictionary.

        Concrete subclasses are required to override this method.
        """

    @classmethod
    @abstractmethod
    def from_dict(cls, data: Dict[str, str]):
        """
        Build a schema instance from a dictionary.

        Concrete subclasses are required to override this method.
        """
@dataclass
class RetrieverSchema(Schema):
    """
    Schema describing the columns of a retriever / vector search index.

    The constructor is written by hand so that callers never pass ``type``;
    it is always ``DependenciesSchemasType.RETRIEVERS``. The attributes are
    nevertheless declared as dataclass fields (``init=False``) so that the
    generated ``__repr__`` and ``__eq__`` take them into account instead of
    looking at ``type`` alone.

    Args:
        name (str): The name of the vector search index schema.
        primary_key (str): The primary key for the index.
        text_column (str): The main text column for the index.
        doc_uri (Optional[str]): The document URI for the index.
        other_columns (Optional[List[str]]): Additional columns in the index.
            ``None`` is stored as an empty list.
    """

    # Declared for repr/eq only; the values are assigned in __init__ below.
    name: str = field(init=False)
    primary_key: str = field(init=False)
    text_column: str = field(init=False)
    doc_uri: Optional[str] = field(init=False)
    other_columns: List[str] = field(init=False)

    def __init__(
        self,
        name: str,
        primary_key: str,
        text_column: str,
        doc_uri: Optional[str] = None,
        other_columns: Optional[List[str]] = None,
    ):
        super().__init__(type=DependenciesSchemasType.RETRIEVERS)
        self.name = name
        self.primary_key = primary_key
        self.text_column = text_column
        self.doc_uri = doc_uri
        self.other_columns = other_columns or []

    def to_dict(self):
        """
        Serialize the schema.

        Returns:
            A dictionary with a single key, the schema type's value, mapped to
            a one-element list holding this schema's attributes.
        """
        return {
            self.type.value: [
                {
                    "name": self.name,
                    "primary_key": self.primary_key,
                    "text_column": self.text_column,
                    "doc_uri": self.doc_uri,
                    "other_columns": self.other_columns,
                }
            ]
        }

    @classmethod
    def from_dict(cls, data: Dict[str, str]):
        """
        Build a ``RetrieverSchema`` from a flat attribute dictionary.

        Args:
            data: Must contain ``name``, ``primary_key`` and ``text_column``;
                ``doc_uri`` and ``other_columns`` are optional.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        return cls(
            name=data["name"],
            primary_key=data["primary_key"],
            text_column=data["text_column"],
            doc_uri=data.get("doc_uri"),
            other_columns=data.get("other_columns", []),
        )
@dataclass
class DependenciesSchemas:
    """
    Container for the dependency schemas of a model.

    Args:
        retriever_schemas (List[RetrieverSchema]): The retriever schemas to
            serialize. Defaults to an empty list.
    """

    retriever_schemas: List[RetrieverSchema] = field(default_factory=list)

    def to_dict(self) -> Optional[Dict[str, Dict[str, List[Dict]]]]:
        """
        Serialize all schemas under a single ``"dependencies_schemas"`` key.

        Returns:
            ``None`` when there are no retriever schemas; otherwise a dictionary
            of the form ``{"dependencies_schemas": {"retrievers": [...]}}`` with
            one attribute dictionary per retriever schema.
        """
        if not self.retriever_schemas:
            return None

        retrievers_key = DependenciesSchemasType.RETRIEVERS.value
        return {
            "dependencies_schemas": {
                retrievers_key: [
                    # Each schema serializes to {key: [attrs]}; unwrap the single entry.
                    schema.to_dict()[retrievers_key][0]
                    for schema in self.retriever_schemas
                ],
            }
        }