"""
The ``mlflow.openai`` module provides an API for logging and loading OpenAI models.
Credential management for OpenAI on Databricks
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. warning::
Specifying secrets for model serving with ``MLFLOW_OPENAI_SECRET_SCOPE`` is deprecated.
Use `secrets-based environment variables <https://docs.databricks.com/en/machine-learning/model-serving/store-env-variable-model-serving.html>`_
instead.
When this flavor logs a model on Databricks, it saves a YAML file with the following contents as
``openai.yaml`` if the ``MLFLOW_OPENAI_SECRET_SCOPE`` environment variable is set.
.. code-block:: yaml
OPENAI_API_BASE: {scope}:openai_api_base
OPENAI_API_KEY: {scope}:openai_api_key
OPENAI_API_KEY_PATH: {scope}:openai_api_key_path
OPENAI_API_TYPE: {scope}:openai_api_type
OPENAI_ORGANIZATION: {scope}:openai_organization
- ``{scope}`` is the value of the ``MLFLOW_OPENAI_SECRET_SCOPE`` environment variable.
- The keys are the environment variables that the ``openai-python`` package uses to
configure the API client.
- The values are the references to the secrets that store the values of the environment
variables.
When the logged model is served on Databricks, each secret will be resolved and set as the
corresponding environment variable. See https://docs.databricks.com/security/secrets/index.html
for how to set up secrets on Databricks.
"""
import itertools
import logging
import os
import warnings
from string import Formatter
from typing import Any, Dict, Optional, Set
import yaml
from packaging.version import Version
import mlflow
from mlflow import pyfunc
from mlflow.environment_variables import MLFLOW_OPENAI_SECRET_SCOPE
from mlflow.exceptions import MlflowException
from mlflow.models import Model, ModelInputExample, ModelSignature
from mlflow.models.model import MLMODEL_FILE_NAME
from mlflow.models.utils import _save_example
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.tracking._model_registry import DEFAULT_AWAIT_MAX_SLEEP_SECONDS
from mlflow.tracking.artifact_utils import _download_artifact_from_uri
from mlflow.types import ColSpec, Schema, TensorSpec
from mlflow.utils.annotations import experimental
from mlflow.utils.databricks_utils import (
check_databricks_secret_scope_access,
is_in_databricks_runtime,
)
from mlflow.utils.docstring_utils import LOG_MODEL_PARAM_DOCS, format_docstring
from mlflow.utils.environment import (
_CONDA_ENV_FILE_NAME,
_CONSTRAINTS_FILE_NAME,
_PYTHON_ENV_FILE_NAME,
_REQUIREMENTS_FILE_NAME,
_mlflow_conda_env,
_process_conda_env,
_process_pip_requirements,
_PythonEnv,
_validate_env_arguments,
)
from mlflow.utils.file_utils import write_to
from mlflow.utils.model_utils import (
_add_code_from_conf_to_system_path,
_get_flavor_configuration,
_validate_and_copy_code_paths,
_validate_and_prepare_target_save_path,
)
from mlflow.utils.openai_utils import (
REQUEST_URL_CHAT,
REQUEST_URL_COMPLETIONS,
REQUEST_URL_EMBEDDINGS,
_OAITokenHolder,
_OpenAIApiConfig,
_OpenAIEnvVar,
_validate_model_params,
)
from mlflow.utils.requirements_utils import _get_pinned_requirement
FLAVOR_NAME = "openai"
MODEL_FILENAME = "model.yaml"
_PYFUNC_SUPPORTED_TASKS = ("chat.completions", "embeddings", "completions")
_logger = logging.getLogger(__name__)
[docs]@experimental
def get_default_pip_requirements():
"""
Returns:
A list of default pip requirements for MLflow Models produced by this flavor.
Calls to :func:`save_model()` and :func:`log_model()` produce a pip environment
that, at minimum, contains these requirements.
"""
return list(map(_get_pinned_requirement, ["openai", "tiktoken", "tenacity"]))
[docs]@experimental
def get_default_conda_env():
"""
Returns:
The default Conda environment for MLflow Models produced by calls to
:func:`save_model()` and :func:`log_model()`.
"""
return _mlflow_conda_env(additional_pip_deps=get_default_pip_requirements())
def _get_obj_to_task_mapping():
import openai
if Version(_get_openai_package_version()).major < 1:
from openai import api_resources as ar
return {
ar.Audio: ar.Audio.OBJECT_NAME,
ar.ChatCompletion: ar.ChatCompletion.OBJECT_NAME,
ar.Completion: ar.Completion.OBJECT_NAME,
ar.Edit: ar.Edit.OBJECT_NAME,
ar.Deployment: ar.Deployment.OBJECT_NAME,
ar.Embedding: ar.Embedding.OBJECT_NAME,
ar.Engine: ar.Engine.OBJECT_NAME,
ar.File: ar.File.OBJECT_NAME,
ar.Image: ar.Image.OBJECT_NAME,
ar.FineTune: ar.FineTune.OBJECT_NAME,
ar.Model: ar.Model.OBJECT_NAME,
ar.Moderation: "moderations",
}
else:
return {
openai.audio: "audio",
openai.chat.completions: "chat.completions",
openai.completions: "completions",
openai.images.edit: "images.edit",
openai.embeddings: "embeddings",
openai.files: "files",
openai.images: "images",
openai.fine_tuning: "fine_tuning",
openai.moderations: "moderations",
openai.models: "models",
}
def _get_model_name(model):
import openai
if isinstance(model, str):
return model
if Version(_get_openai_package_version()).major < 1 and isinstance(model, openai.Model):
return model.id
raise mlflow.MlflowException(
f"Unsupported model type: {type(model)}", error_code=INVALID_PARAMETER_VALUE
)
def _get_task_name(task):
mapping = _get_obj_to_task_mapping()
if isinstance(task, str):
if task not in mapping.values():
raise mlflow.MlflowException(
f"Unsupported task: {task}", error_code=INVALID_PARAMETER_VALUE
)
return task
else:
task_name = mapping.get(task)
if task_name is None:
raise mlflow.MlflowException(
f"Unsupported task object: {task}", error_code=INVALID_PARAMETER_VALUE
)
return task_name
def _get_api_config() -> _OpenAIApiConfig:
"""Gets the parameters and configuration of the OpenAI API connected to."""
import openai
api_type = os.getenv(_OpenAIEnvVar.OPENAI_API_TYPE.value, openai.api_type)
api_version = os.getenv(_OpenAIEnvVar.OPENAI_API_VERSION.value, openai.api_version)
api_base = os.getenv(_OpenAIEnvVar.OPENAI_API_BASE.value, None)
engine = os.getenv(_OpenAIEnvVar.OPENAI_ENGINE.value, None)
deployment_id = os.getenv(_OpenAIEnvVar.OPENAI_DEPLOYMENT_NAME.value, None)
if api_type in ("azure", "azure_ad", "azuread"):
batch_size = 16
max_tokens_per_minute = 60_000
else:
# The maximum batch size is 2048:
# https://github.com/openai/openai-python/blob/b82a3f7e4c462a8a10fa445193301a3cefef9a4a/openai/embeddings_utils.py#L43
# We use a smaller batch size to be safe.
batch_size = 1024
max_tokens_per_minute = 90_000
return _OpenAIApiConfig(
api_type=api_type,
batch_size=batch_size,
max_requests_per_minute=3_500,
max_tokens_per_minute=max_tokens_per_minute,
api_base=api_base,
api_version=api_version,
engine=engine,
deployment_id=deployment_id,
)
def _get_openai_package_version():
import openai
try:
return openai.__version__
except AttributeError:
# openai < 0.27.5 doesn't have a __version__ attribute
return openai.version.VERSION
def _log_secrets_yaml(local_model_dir, scope):
with open(os.path.join(local_model_dir, "openai.yaml"), "w") as f:
yaml.safe_dump({e.value: f"{scope}:{e.secret_key}" for e in _OpenAIEnvVar}, f)
def _parse_format_fields(s) -> Set[str]:
"""Parses format fields from a given string, e.g. "Hello {name}" -> ["name"]."""
return {fn for _, fn, _, _ in Formatter().parse(s) if fn is not None}
def _get_input_schema(task, content):
if content:
formatter = _ContentFormatter(task, content)
variables = formatter.variables
if len(variables) == 1:
return Schema([ColSpec(type="string")])
elif len(variables) > 1:
return Schema([ColSpec(name=v, type="string") for v in variables])
else:
return Schema([ColSpec(type="string")])
else:
return Schema([ColSpec(type="string")])
[docs]@experimental
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def save_model(
model,
task,
path,
conda_env=None,
code_paths=None,
mlflow_model=None,
signature: ModelSignature = None,
input_example: ModelInputExample = None,
pip_requirements=None,
extra_pip_requirements=None,
metadata=None,
example_no_conversion=False,
**kwargs,
):
"""
Save an OpenAI model to a path on the local file system.
Args:
model: The OpenAI model name.
task: The task the model is performing, e.g., ``openai.chat.completions`` or
``'chat.completions'``.
path: Local path where the model is to be saved.
conda_env: {{ conda_env }}
code_paths: {{ code_paths }}
mlflow_model: :py:mod:`mlflow.models.Model` this flavor is being added to.
signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
from datasets with valid model input (e.g. the training dataset with target
column omitted) and valid model output (e.g. model predictions generated on
the training dataset), for example:
.. code-block:: python
from mlflow.models import infer_signature
train = df.drop_column("target_label")
predictions = ... # compute model predictions
signature = infer_signature(train, predictions)
input_example: {{ input_example }}
pip_requirements: {{ pip_requirements }}
extra_pip_requirements: {{ extra_pip_requirements }}
metadata: Custom metadata dictionary passed to the model and stored in the MLmodel file.
.. Note:: Experimental: This parameter may change or be removed in a future
release without warning.
example_no_conversion: {{ example_no_conversion }}
kwargs: Keyword arguments specific to the OpenAI task, such as the ``messages`` (see
:ref:`mlflow.openai.messages` for more details on this parameter)
or ``top_p`` value to use for chat completion.
.. code-block:: python
import mlflow
import openai
# Chat
mlflow.openai.save_model(
model="gpt-3.5-turbo",
task=openai.chat.completions,
messages=[{"role": "user", "content": "Tell me a joke."}],
path="model",
)
# Completions
mlflow.openai.save_model(
model="text-davinci-002",
task=openai.completions,
prompt="{text}. The general sentiment of the text is",
path="model",
)
# Embeddings
mlflow.openai.save_model(
model="text-embedding-ada-002",
task=openai.embeddings,
path="model",
)
"""
import numpy as np
_validate_env_arguments(conda_env, pip_requirements, extra_pip_requirements)
path = os.path.abspath(path)
_validate_and_prepare_target_save_path(path)
code_dir_subpath = _validate_and_copy_code_paths(code_paths, path)
task = _get_task_name(task)
if mlflow_model is None:
mlflow_model = Model()
if signature is not None:
if signature.params:
_validate_model_params(
task, kwargs, {p.name: p.default for p in signature.params.params}
)
mlflow_model.signature = signature
elif task == "chat.completions":
messages = kwargs.get("messages", [])
if messages and not (
all(isinstance(m, dict) for m in messages) and all(map(_is_valid_message, messages))
):
raise mlflow.MlflowException.invalid_parameter_value(
"If `messages` is provided, it must be a list of dictionaries with keys "
"'role' and 'content'."
)
mlflow_model.signature = ModelSignature(
inputs=_get_input_schema(task, messages),
outputs=Schema([ColSpec(type="string", name=None)]),
)
elif task == "completions":
prompt = kwargs.get("prompt")
mlflow_model.signature = ModelSignature(
inputs=_get_input_schema(task, prompt),
outputs=Schema([ColSpec(type="string", name=None)]),
)
elif task == "embeddings":
mlflow_model.signature = ModelSignature(
inputs=Schema([ColSpec(type="string", name=None)]),
outputs=Schema([TensorSpec(type=np.dtype("float64"), shape=(-1,))]),
)
if input_example is not None:
_save_example(mlflow_model, input_example, path, example_no_conversion)
if metadata is not None:
mlflow_model.metadata = metadata
model_data_path = os.path.join(path, MODEL_FILENAME)
model_dict = {
"model": _get_model_name(model),
"task": task,
**kwargs,
}
with open(model_data_path, "w") as f:
yaml.safe_dump(model_dict, f)
if task in _PYFUNC_SUPPORTED_TASKS:
pyfunc.add_to_model(
mlflow_model,
loader_module="mlflow.openai",
data=MODEL_FILENAME,
conda_env=_CONDA_ENV_FILE_NAME,
python_env=_PYTHON_ENV_FILE_NAME,
code=code_dir_subpath,
)
mlflow_model.add_flavor(
FLAVOR_NAME,
openai_version=_get_openai_package_version(),
data=MODEL_FILENAME,
code=code_dir_subpath,
)
mlflow_model.save(os.path.join(path, MLMODEL_FILE_NAME))
if is_in_databricks_runtime():
if scope := MLFLOW_OPENAI_SECRET_SCOPE.get():
url = "https://docs.databricks.com/en/machine-learning/model-serving/store-env-variable-model-serving.html"
warnings.warn(
"Specifying secrets for model serving with `MLFLOW_OPENAI_SECRET_SCOPE` is "
f"deprecated. Use secrets-based environment variables ({url}) instead.",
FutureWarning,
)
check_databricks_secret_scope_access(scope)
_log_secrets_yaml(path, scope)
if conda_env is None:
if pip_requirements is None:
default_reqs = get_default_pip_requirements()
inferred_reqs = mlflow.models.infer_pip_requirements(
path, FLAVOR_NAME, fallback=default_reqs
)
default_reqs = sorted(set(inferred_reqs).union(default_reqs))
else:
default_reqs = None
conda_env, pip_requirements, pip_constraints = _process_pip_requirements(
default_reqs,
pip_requirements,
extra_pip_requirements,
)
else:
conda_env, pip_requirements, pip_constraints = _process_conda_env(conda_env)
with open(os.path.join(path, _CONDA_ENV_FILE_NAME), "w") as f:
yaml.safe_dump(conda_env, stream=f, default_flow_style=False)
# Save `constraints.txt` if necessary
if pip_constraints:
write_to(os.path.join(path, _CONSTRAINTS_FILE_NAME), "\n".join(pip_constraints))
# Save `requirements.txt`
write_to(os.path.join(path, _REQUIREMENTS_FILE_NAME), "\n".join(pip_requirements))
_PythonEnv.current().to_yaml(os.path.join(path, _PYTHON_ENV_FILE_NAME))
[docs]@experimental
@format_docstring(LOG_MODEL_PARAM_DOCS.format(package_name=FLAVOR_NAME))
def log_model(
model,
task,
artifact_path,
conda_env=None,
code_paths=None,
registered_model_name=None,
signature: ModelSignature = None,
input_example: ModelInputExample = None,
await_registration_for=DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
pip_requirements=None,
extra_pip_requirements=None,
metadata=None,
example_no_conversion=False,
**kwargs,
):
"""
Log an OpenAI model as an MLflow artifact for the current run.
Args:
model: The OpenAI model name or reference instance, e.g.,
``openai.Model.retrieve("gpt-3.5-turbo")``.
task: The task the model is performing, e.g., ``openai.chat.completions`` or
``'chat.completions'``.
artifact_path: Run-relative artifact path.
conda_env: {{ conda_env }}
code_paths: {{ code_paths }}
registered_model_name: If given, create a model version under
``registered_model_name``, also creating a registered model if one
with the given name does not exist.
signature: :py:class:`ModelSignature <mlflow.models.ModelSignature>`
describes model input and output :py:class:`Schema <mlflow.types.Schema>`.
The model signature can be :py:func:`inferred <mlflow.models.infer_signature>`
from datasets with valid model input (e.g. the training dataset with target
column omitted) and valid model output (e.g. model predictions generated on
the training dataset), for example:
.. code-block:: python
from mlflow.models import infer_signature
train = df.drop_column("target_label")
predictions = ... # compute model predictions
signature = infer_signature(train, predictions)
input_example: {{ input_example }}
await_registration_for: Number of seconds to wait for the model version to finish
being created and is in ``READY`` status. By default, the function
waits for five minutes. Specify 0 or None to skip waiting.
pip_requirements: {{ pip_requirements }}
extra_pip_requirements: {{ extra_pip_requirements }}
metadata: Custom metadata dictionary passed to the model and stored in the MLmodel file.
.. Note:: Experimental: This parameter may change or be removed in a future
release without warning.
example_no_conversion: {{ example_no_conversion }}
kwargs: Keyword arguments specific to the OpenAI task, such as the ``messages`` (see
:ref:`mlflow.openai.messages` for more details on this parameter)
or ``top_p`` value to use for chat completion.
Returns:
A :py:class:`ModelInfo <mlflow.models.model.ModelInfo>` instance that contains the
metadata of the logged model.
.. code-block:: python
import mlflow
import openai
# Chat
with mlflow.start_run():
info = mlflow.openai.log_model(
model="gpt-3.5-turbo",
task=openai.chat.completions,
messages=[{"role": "user", "content": "Tell me a joke about {animal}."}],
artifact_path="model",
)
model = mlflow.pyfunc.load_model(info.model_uri)
df = pd.DataFrame({"animal": ["cats", "dogs"]})
print(model.predict(df))
# Embeddings
with mlflow.start_run():
info = mlflow.openai.log_model(
model="text-embedding-ada-002",
task=openai.embeddings,
artifact_path="embeddings",
)
model = mlflow.pyfunc.load_model(info.model_uri)
print(model.predict(["hello", "world"]))
"""
return Model.log(
artifact_path=artifact_path,
flavor=mlflow.openai,
registered_model_name=registered_model_name,
model=model,
task=task,
conda_env=conda_env,
code_paths=code_paths,
signature=signature,
input_example=input_example,
await_registration_for=await_registration_for,
pip_requirements=pip_requirements,
extra_pip_requirements=extra_pip_requirements,
metadata=metadata,
example_no_conversion=example_no_conversion,
**kwargs,
)
def _load_model(path):
with open(path) as f:
return yaml.safe_load(f)
def _is_valid_message(d):
return isinstance(d, dict) and "content" in d and "role" in d
class _ContentFormatter:
def __init__(self, task, template=None):
if task == "completions":
template = template or "{prompt}"
if not isinstance(template, str):
raise mlflow.MlflowException.invalid_parameter_value(
f"Template for task {task} expects type `str`, but got {type(template)}."
)
self.template = template
self.format_fn = self.format_prompt
self.variables = sorted(_parse_format_fields(self.template))
elif task == "chat.completions":
if not template:
template = [{"role": "user", "content": "{content}"}]
if not all(map(_is_valid_message, template)):
raise mlflow.MlflowException.invalid_parameter_value(
f"Template for task {task} expects type `dict` with keys 'content' "
f"and 'role', but got {type(template)}."
)
self.template = template.copy()
self.format_fn = self.format_chat
self.variables = sorted(
set(
itertools.chain.from_iterable(
_parse_format_fields(message.get("content"))
| _parse_format_fields(message.get("role"))
for message in self.template
)
)
)
if not self.variables:
self.template.append({"role": "user", "content": "{content}"})
self.variables.append("content")
else:
raise mlflow.MlflowException.invalid_parameter_value(
f"Task type ``{task}`` is not supported for formatting."
)
def format(self, **params):
if missing_params := set(self.variables) - set(params):
raise mlflow.MlflowException.invalid_parameter_value(
f"Expected parameters {self.variables} to be provided, "
f"only got {list(params)}, {list(missing_params)} are missing."
)
return self.format_fn(**params)
def format_prompt(self, **params):
return self.template.format(**{v: params[v] for v in self.variables})
def format_chat(self, **params):
format_args = {v: params[v] for v in self.variables}
return [
{
"role": message.get("role").format(**format_args),
"content": message.get("content").format(**format_args),
}
for message in self.template
]
def _first_string_column(pdf):
iter_str_cols = (c for c, v in pdf.iloc[0].items() if isinstance(v, str))
col = next(iter_str_cols, None)
if col is None:
raise mlflow.MlflowException.invalid_parameter_value(
f"Could not find a string column in the input data: {pdf.dtypes.to_dict()}"
)
return col
class _OpenAIWrapper:
def __init__(self, model):
task = model.pop("task")
if task not in _PYFUNC_SUPPORTED_TASKS:
raise mlflow.MlflowException.invalid_parameter_value(
f"Unsupported task: {task}. Supported tasks: {_PYFUNC_SUPPORTED_TASKS}."
)
self.model = model
self.task = task
self.api_config = _get_api_config()
self.api_token = _OAITokenHolder(self.api_config.api_type)
# If the same parameter exists in self.model & self.api_config,
# we use the parameter from self.model
self.request_configs = {}
for x in ["api_base", "api_version", "api_type", "engine", "deployment_id"]:
if x in self.model:
self.request_configs[x] = self.model.pop(x)
elif value := getattr(self.api_config, x):
self.request_configs[x] = value
if self.request_configs.get("api_type") in ("azure", "azure_ad", "azuread"):
deployment_id = self.request_configs.get("deployment_id")
if self.request_configs.get("engine"):
# Avoid using both parameters as they serve the same purpose
# Invalid inputs:
# - Wrong engine + correct/wrong deployment_id
# - No engine + wrong deployment_id
# Valid inputs:
# - Correct engine + correct/wrong deployment_id
# - No engine + correct deployment_id
if deployment_id is not None:
_logger.warning(
"Both engine and deployment_id are set. "
"Using engine as it takes precedence."
)
elif deployment_id is None:
raise MlflowException(
"Either engine or deployment_id must be set for Azure OpenAI API",
)
if self.task != "embeddings":
self._setup_completions()
def _setup_completions(self):
if self.task == "chat.completions":
self.template = self.model.get("messages", [])
else:
self.template = self.model.get("prompt")
self.formater = _ContentFormatter(self.task, self.template)
def format_completions(self, params_list):
return [self.formater.format(**params) for params in params_list]
def get_params_list(self, data):
if len(self.formater.variables) == 1:
variable = self.formater.variables[0]
if variable in data.columns:
return data[[variable]].to_dict(orient="records")
else:
first_string_column = _first_string_column(data)
return [{variable: s} for s in data[first_string_column]]
else:
return data[self.formater.variables].to_dict(orient="records")
def _construct_request_url(self, task_url, default_url):
api_type = self.request_configs.get("api_type")
api_base = self.request_configs.get("api_base")
if api_type in ("azure", "azure_ad", "azuread"):
api_version = self.request_configs.get("api_version")
deployment_id = self.request_configs.get("deployment_id")
return (
f"{api_base}/openai/deployments/{deployment_id}/"
f"{task_url}?api-version={api_version}"
)
return f"{api_base}/{task_url}" if api_base else default_url
def _predict_chat(self, data, params):
from mlflow.openai.api_request_parallel_processor import process_api_requests
_validate_model_params(self.task, self.model, params)
messages_list = self.format_completions(self.get_params_list(data))
requests = [{**self.model, **params, "messages": messages} for messages in messages_list]
request_url = self._construct_request_url("chat/completions", REQUEST_URL_CHAT)
results = process_api_requests(
requests,
request_url,
api_token=self.api_token,
max_requests_per_minute=self.api_config.max_requests_per_minute,
max_tokens_per_minute=self.api_config.max_tokens_per_minute,
)
return [r["choices"][0]["message"]["content"] for r in results]
def _predict_completions(self, data, params):
from mlflow.openai.api_request_parallel_processor import process_api_requests
_validate_model_params(self.task, self.model, params)
prompts_list = self.format_completions(self.get_params_list(data))
batch_size = params.pop("batch_size", self.api_config.batch_size)
_logger.debug(f"Requests are being batched by {batch_size} samples.")
requests = [
{
**self.model,
**params,
"prompt": prompts_list[i : i + batch_size],
}
for i in range(0, len(prompts_list), batch_size)
]
request_url = self._construct_request_url("completions", REQUEST_URL_COMPLETIONS)
results = process_api_requests(
requests,
request_url,
api_token=self.api_token,
max_requests_per_minute=self.api_config.max_requests_per_minute,
max_tokens_per_minute=self.api_config.max_tokens_per_minute,
)
return [row["text"] for batch in results for row in batch["choices"]]
def _predict_embeddings(self, data, params):
from mlflow.openai.api_request_parallel_processor import process_api_requests
_validate_model_params(self.task, self.model, params)
batch_size = params.pop("batch_size", self.api_config.batch_size)
_logger.debug(f"Requests are being batched by {batch_size} samples.")
first_string_column = _first_string_column(data)
texts = data[first_string_column].tolist()
requests = [
{
**self.model,
**params,
"input": texts[i : i + batch_size],
}
for i in range(0, len(texts), batch_size)
]
request_url = self._construct_request_url("embeddings", REQUEST_URL_EMBEDDINGS)
results = process_api_requests(
requests,
request_url,
api_token=self.api_token,
max_requests_per_minute=self.api_config.max_requests_per_minute,
max_tokens_per_minute=self.api_config.max_tokens_per_minute,
)
return [row["embedding"] for batch in results for row in batch["data"]]
def predict(self, data, params: Optional[Dict[str, Any]] = None):
"""
Args:
data: Model input data.
params: Additional parameters to pass to the model for inference.
.. Note:: Experimental: This parameter may change or be removed in a future
release without warning.
Returns:
Model predictions.
"""
self.api_token.validate()
if self.task == "chat.completions":
return self._predict_chat(data, params or {})
elif self.task == "completions":
return self._predict_completions(data, params or {})
elif self.task == "embeddings":
return self._predict_embeddings(data, params or {})
def _load_pyfunc(path):
"""Loads PyFunc implementation. Called by ``pyfunc.load_model``.
Args:
path: Local filesystem path to the MLflow Model with the ``openai`` flavor.
"""
return _OpenAIWrapper(_load_model(path))
[docs]@experimental
def load_model(model_uri, dst_path=None):
"""
Load an OpenAI model from a local file or a run.
Args:
model_uri: The location, in URI format, of the MLflow model. For example:
- ``/Users/me/path/to/local/model``
- ``relative/path/to/local/model``
- ``s3://my_bucket/path/to/model``
- ``runs:/<mlflow_run_id>/run-relative/path/to/model``
For more information about supported URI schemes, see
`Referencing Artifacts <https://www.mlflow.org/docs/latest/tracking.html#
artifact-locations>`_.
dst_path: The local filesystem path to which to download the model artifact.
This directory must already exist. If unspecified, a local output
path will be created.
Returns:
A dictionary representing the OpenAI model.
"""
local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)
flavor_conf = _get_flavor_configuration(local_model_path, FLAVOR_NAME)
_add_code_from_conf_to_system_path(local_model_path, flavor_conf)
model_data_path = os.path.join(local_model_path, flavor_conf.get("data", MODEL_FILENAME))
return _load_model(model_data_path)