from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, TypeVar, Union
from mlflow.genai.utils.enum_utils import StrEnum
if TYPE_CHECKING:
    from databricks.agents.review_app import label_schemas as _label_schemas
    _InputCategorical = _label_schemas.InputCategorical
    _InputCategoricalList = _label_schemas.InputCategoricalList
    _InputNumeric = _label_schemas.InputNumeric
    _InputText = _label_schemas.InputText
    _InputTextList = _label_schemas.InputTextList
    _LabelSchema = _label_schemas.LabelSchema
# Placeholder for the Databricks-side input object that a subclass converts to/from.
DatabricksInputType = TypeVar("DatabricksInputType")
# Bound to the concrete ``InputType`` subclass so ``_from_databricks_input`` can be
# typed as returning an instance of the class it is invoked on.
_InputType = TypeVar("_InputType", bound="InputType")


class InputType(ABC):
    """Base class for all input types.

    Subclasses must implement conversion both to and from the corresponding
    internal Databricks input type.
    """

    @abstractmethod
    def _to_databricks_input(self) -> DatabricksInputType:
        """Convert to the internal Databricks input type."""

    @classmethod
    @abstractmethod
    def _from_databricks_input(
        cls: "type[_InputType]", input_obj: DatabricksInputType
    ) -> _InputType:
        """Create from the internal Databricks input type.

        Args:
            input_obj: The Databricks input object to convert.

        Returns:
            An instance of the class this method is called on.
        """
@dataclass
class InputTextList(InputType):
    """Like `InputText`, but allows multiple entries.

    .. note::
        This functionality is only available in Databricks. Please run
        `pip install mlflow[databricks]` to use it.
    """

    max_length_each: Optional[int] = None
    """Maximum character length for each individual text entry. None means no limit."""

    max_count: Optional[int] = None
    """Maximum number of text entries allowed. None means no limit."""

    def _to_databricks_input(self) -> "_InputTextList":
        """Convert to the internal Databricks input type."""
        # Imported lazily: the Databricks package is only present with mlflow[databricks].
        from databricks.agents.review_app import label_schemas as _label_schemas

        return _label_schemas.InputTextList(
            max_length_each=self.max_length_each, max_count=self.max_count
        )

    @classmethod
    def _from_databricks_input(cls, input_obj: "_InputTextList") -> "InputTextList":
        """Create from the internal Databricks input type.

        Args:
            input_obj: Databricks input exposing ``max_length_each`` and ``max_count``.
        """
        return cls(max_length_each=input_obj.max_length_each, max_count=input_obj.max_count)
@dataclass
class InputText(InputType):
    """A free-form text box for collecting assessments from stakeholders.

    .. note::
        This functionality is only available in Databricks. Please run
        `pip install mlflow[databricks]` to use it.
    """

    max_length: Optional[int] = None
    """Maximum character length for the text input. None means no limit."""

    def _to_databricks_input(self) -> "_InputText":
        """Convert to the internal Databricks input type."""
        # Imported lazily: the Databricks package is only present with mlflow[databricks].
        from databricks.agents.review_app import label_schemas as _label_schemas

        return _label_schemas.InputText(max_length=self.max_length)

    @classmethod
    def _from_databricks_input(cls, input_obj: "_InputText") -> "InputText":
        """Create from the internal Databricks input type.

        Args:
            input_obj: Databricks input exposing ``max_length``.
        """
        return cls(max_length=input_obj.max_length)
class LabelSchemaType(StrEnum):
    """Type of label schema."""

    FEEDBACK = "feedback"
    EXPECTATION = "expectation"
@dataclass(frozen=True)
class LabelSchema:
    """A label schema for collecting input from stakeholders.

    .. note::
        This functionality is only available in Databricks. Please run
        `pip install mlflow[databricks]` to use it.
    """

    name: str
    """Unique name identifier for the label schema."""

    type: LabelSchemaType
    """Type of the label schema, either 'feedback' or 'expectation'."""

    title: str
    """Display title shown to stakeholders in the labeling review UI."""

    input: Union[InputCategorical, InputCategoricalList, InputText, InputTextList, InputNumeric]
    """
    Input type specification that defines how stakeholders will provide their assessment
    (e.g., dropdown, text box, numeric input)
    """

    instruction: Optional[str] = None
    """Optional detailed instructions shown to stakeholders for guidance."""

    enable_comment: bool = False
    """Whether to enable additional comment functionality for reviewers."""

    @classmethod
    def _from_databricks_label_schema(cls, schema: "_LabelSchema") -> "LabelSchema":
        """Convert from the internal Databricks label schema type.

        Args:
            schema: The Databricks label schema to convert.

        Returns:
            A ``LabelSchema`` whose ``input`` is the MLflow wrapper corresponding to
            ``schema.input``.

        Raises:
            TypeError: If ``schema.input`` is not one of the supported Databricks
                input types.
        """
        # Imported lazily: the Databricks package is only present with mlflow[databricks].
        from databricks.agents.review_app import label_schemas as _label_schemas

        # ``schema.input`` is a Databricks object; the conversion lives on the matching
        # MLflow wrapper class, so look that class up and pass the object to it.
        input_type_mapping = {
            _label_schemas.InputCategorical: InputCategorical,
            _label_schemas.InputCategoricalList: InputCategoricalList,
            _label_schemas.InputText: InputText,
            _label_schemas.InputTextList: InputTextList,
            _label_schemas.InputNumeric: InputNumeric,
        }
        databricks_input = schema.input
        mlflow_input_cls = input_type_mapping.get(type(databricks_input))
        if mlflow_input_cls is None:
            raise TypeError(
                f"Unsupported Databricks input type: {type(databricks_input).__name__}"
            )

        return cls(
            name=schema.name,
            # NOTE(review): passed through as-is; confirm it is LabelSchemaType-compatible.
            type=schema.type,
            title=schema.title,
            input=mlflow_input_cls._from_databricks_input(databricks_input),
            instruction=schema.instruction,
            enable_comment=schema.enable_comment,
        )