danswer/backend/shared_configs/model_server_models.py
josvdw 50c17438d5
Litellm bump (#2195)
* ran bump-pydantic

* replace root_validator with model_validator

* mostly working. some alternate assistant error. changed root_validator and typing_extensions

* working generation chat. changed type

* replacing .dict with .model_dump

* argument needed to bring model_dump up to parity with dict()

* fix a few remaining issues -- working with Llama and GPT

* updating requirements file

* more requirement updates

* more requirement updates

* fix to make search work

* return type fix:

* halfway types change

* fixes for mypy and pydantic:

* endpoint fix

* fix pydantic protected namespaces

* it works!

* removed unnecessary None initializations

* better logging

* changed default values to empty lists

* mypy fixes

* fixed array defaulting

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-08-28 00:00:27 +00:00

56 lines
1.4 KiB
Python

from pydantic import BaseModel
from pydantic import ConfigDict

from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
# Type alias for a single embedding vector: a flat list of floats.
Embedding = list[float]
class EmbedRequest(BaseModel):
    """Request payload asking for a batch of texts to be embedded.

    Fields without a default (``texts``, ``max_context_length``,
    ``normalize_embeddings``, ``text_type``) are required; the model name,
    API key, provider type and manual prefixes are optional and default to
    ``None``.
    """

    texts: list[str]
    # Can be None for cloud embedding model requests; error handling logic
    # exists for other cases.
    model_name: str | None = None
    max_context_length: int
    normalize_embeddings: bool
    api_key: str | None = None
    provider_type: EmbeddingProvider | None = None
    text_type: EmbedTextType
    manual_query_prefix: str | None = None
    manual_passage_prefix: str | None = None

    # Disable pydantic's "model_" protected namespace; otherwise the
    # `model_name` field above conflicts with it. ConfigDict is a TypedDict,
    # so this is the same dict at runtime but is checked by type checkers.
    model_config = ConfigDict(protected_namespaces=())
class EmbedResponse(BaseModel):
    """Response payload carrying the computed embedding vectors."""

    # One `Embedding` (list of floats) per entry; presumably in the same order
    # as the `texts` of the originating request -- confirm against the server.
    embeddings: list[Embedding]
class RerankRequest(BaseModel):
    """Request payload asking for ``documents`` to be scored against ``query``.

    ``query``, ``documents`` and ``model_name`` are required; the provider type
    and API key are optional and default to ``None``.
    """

    query: str
    documents: list[str]
    model_name: str
    provider_type: RerankerProvider | None = None
    api_key: str | None = None

    # Disable pydantic's "model_" protected namespace; otherwise the
    # `model_name` field above conflicts with it. ConfigDict is a TypedDict,
    # so this is the same dict at runtime but is checked by type checkers.
    model_config = ConfigDict(protected_namespaces=())
class RerankResponse(BaseModel):
    """Response payload carrying the rerank scores."""

    # Presumably one score per document of the originating RerankRequest,
    # in the same order -- confirm against the server implementation.
    scores: list[float]
class IntentRequest(BaseModel):
    """Request payload for classifying the intent of a query.

    Both thresholds are required floats.
    """

    query: str

    # Threshold applied to the sequence-classification output.
    semantic_percent_threshold: float

    # Threshold applied to the token-classification output.
    keyword_percent_threshold: float
class IntentResponse(BaseModel):
    """Response payload with the query-intent classification result."""

    # Whether the query was classified as keyword-style.
    is_keyword: bool
    # Keywords extracted from the query (may be an empty list).
    keywords: list[str]