Litellm bump (#2195)

* ran bump-pydantic

* replace root_validator with model_validator

* mostly working. some alternate assistant error. changed root_validator and typing_extensions

* working generation chat. changed type

* replacing .dict with .model_dump

* argument needed to bring model_dump up to parity with dict()

* fix a few remaining issues -- working with llama and gpt

* updating requirements file

* more requirement updates

* more requirement updates

* fix to make search work

* return type fix:

* halfway types change

* fixes for mypy and pydantic:

* endpoint fix

* fix pydantic protected namespaces

* it works!

* removed unnecessary None initializations

* better logging

* changed default values to empty lists

* mypy fixes

* fixed array defaulting

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
This commit is contained in:
josvdw 2024-08-27 17:00:27 -07:00 committed by GitHub
parent 657d2050a5
commit 50c17438d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
52 changed files with 230 additions and 223 deletions

View File

@ -13,7 +13,7 @@ from danswer.server.manage.models import UserPreferences
def set_no_auth_user_preferences(
store: DynamicConfigStore, preferences: UserPreferences
) -> None:
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())
def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:

View File

@ -35,11 +35,12 @@ class QADocsResponse(RetrievalDocs):
applied_time_cutoff: datetime | None
recency_bias_multiplier: float
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().dict(*args, **kwargs) # type: ignore
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
initial_dict["applied_time_cutoff"] = (
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
)
return initial_dict

View File

@ -813,4 +813,4 @@ def stream_chat_message(
is_connected=is_connected,
)
for obj in objects:
yield get_json_line(obj.dict())
yield get_json_line(obj.model_dump())

View File

@ -1,4 +1,4 @@
from typing import TypedDict
from typing_extensions import TypedDict # noreorder
from pydantic import BaseModel

View File

@ -125,7 +125,7 @@ def update_gmail_credential_access_tokens(
) -> OAuthCredentials | None:
app_credentials = get_google_app_gmail_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.dict(),
app_credentials.model_dump(),
scopes=SCOPES,
redirect_uri=_build_frontend_gmail_redirect(),
)

View File

@ -106,7 +106,7 @@ def update_credential_access_tokens(
) -> OAuthCredentials | None:
app_credentials = get_google_app_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.dict(),
app_credentials.model_dump(),
scopes=SCOPES,
redirect_uri=_build_frontend_google_drive_redirect(),
)

View File

@ -3,6 +3,7 @@ from typing import List
from typing import Optional
from pydantic import BaseModel
from pydantic import Field
class Message(BaseModel):
@ -18,11 +19,11 @@ class Message(BaseModel):
sender_realm_str: str
subject: str
topic_links: Optional[List[Any]] = None
last_edit_timestamp: Optional[int] = None
edit_history: Any
last_edit_timestamp: Optional[int]
edit_history: Any = None
reactions: List[Any]
submessages: List[Any]
flags: List[str] = []
flags: List[str] = Field(default_factory=list)
display_recipient: Optional[str] = None
type: Optional[str] = None
stream_id: int
@ -39,4 +40,4 @@ class GetMessagesResponse(BaseModel):
found_newest: Optional[bool] = None
history_limited: Optional[bool] = None
anchor: Optional[str] = None
messages: List[Message] = []
messages: List[Message] = Field(default_factory=list)

View File

@ -26,8 +26,11 @@ from danswer.db.models import User__UserGroup
from danswer.db.models import UserRole
from danswer.server.features.document_set.models import DocumentSetCreationRequest
from danswer.server.features.document_set.models import DocumentSetUpdateRequest
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
@ -233,9 +236,9 @@ def insert_document_set(
)
db_session.commit()
except:
except Exception as e:
db_session.rollback()
raise
logger.error(f"Error creating document set: {e}")
return new_document_set_row, ds_cc_pairs

View File

@ -46,10 +46,10 @@ def upsert_cloud_embedding_provider(
.first()
)
if existing_provider:
for key, value in provider.dict().items():
for key, value in provider.model_dump().items():
setattr(existing_provider, key, value)
else:
new_provider = CloudEmbeddingProviderModel(**provider.dict())
new_provider = CloudEmbeddingProviderModel(**provider.model_dump())
db_session.add(new_provider)
existing_provider = new_provider
db_session.commit()

View File

@ -5,7 +5,7 @@ from typing import Any
from typing import Literal
from typing import NotRequired
from typing import Optional
from typing import TypedDict
from typing_extensions import TypedDict # noreorder
from uuid import UUID
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID

View File

@ -1,7 +1,7 @@
import base64
from enum import Enum
from typing import NotRequired
from typing import TypedDict
from typing_extensions import TypedDict # noreorder
from pydantic import BaseModel

View File

@ -156,7 +156,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
title_embed_dict[title] = title_embedding
new_embedded_chunk = IndexChunk(
**chunk.dict(),
**chunk.model_dump(),
embeddings=ChunkEmbedding(
full_embedding=chunk_embeddings[0],
mini_chunk_embeddings=chunk_embeddings[1:],

View File

@ -3,6 +3,7 @@ from functools import partial
from typing import Protocol
from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
@ -40,9 +41,7 @@ logger = setup_logger()
class DocumentBatchPrepareContext(BaseModel):
updatable_docs: list[Document]
id_to_db_doc_map: dict[str, DBDocument]
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
class IndexingPipelineProtocol(Protocol):

View File

@ -1,6 +1,7 @@
from typing import TYPE_CHECKING
from pydantic import BaseModel
from pydantic import Field
from danswer.access.models import DocumentAccess
from danswer.connectors.models import Document
@ -24,9 +25,8 @@ class BaseChunk(BaseModel):
chunk_id: int
blurb: str # The first sentence(s) of the first Section of the chunk
content: str
source_links: dict[
int, str
] | None # Holds the link and the offsets into the raw Chunk text
# Holds the link and the offsets into the raw Chunk text
source_links: dict[int, str] | None
section_continuation: bool # True if this Chunk's start is not at the start of a Section
@ -47,7 +47,7 @@ class DocAwareChunk(BaseChunk):
mini_chunk_texts: list[str] | None
large_chunk_reference_ids: list[int] = []
large_chunk_reference_ids: list[int] = Field(default_factory=list)
def to_short_descriptor(self) -> str:
"""Used when logging the identity of a chunk"""
@ -85,7 +85,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
document_sets: set[str],
boost: int,
) -> "DocMetadataAwareIndexChunk":
index_chunk_data = index_chunk.dict()
index_chunk_data = index_chunk.model_dump()
return cls(
**index_chunk_data,
access=access,
@ -102,6 +102,9 @@ class EmbeddingModelDetail(BaseModel):
provider_type: EmbeddingProvider | None = None
api_key: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
@classmethod
def from_db_model(
cls,
@ -123,6 +126,9 @@ class IndexingSetting(EmbeddingModelDetail):
index_name: str | None
multipass_indexing: bool
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
@classmethod
def from_db_model(cls, search_settings: "SearchSettings") -> "IndexingSetting":
return cls(

View File

@ -1,6 +1,5 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import TYPE_CHECKING
from langchain.schema.messages import AIMessage
@ -8,8 +7,9 @@ from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import root_validator
from pydantic import model_validator
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.configs.constants import MessageType
@ -117,22 +117,19 @@ class AnswerStyleConfig(BaseModel):
default_factory=DocumentPruningConfig
)
@root_validator
def check_quotes_and_citation(cls, values: dict[str, Any]) -> dict[str, Any]:
citation_config = values.get("citation_config")
quotes_config = values.get("quotes_config")
if citation_config is None and quotes_config is None:
@model_validator(mode="after")
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
if self.citation_config is None and self.quotes_config is None:
raise ValueError(
"One of `citation_config` or `quotes_config` must be provided"
)
if citation_config is not None and quotes_config is not None:
if self.citation_config is not None and self.quotes_config is not None:
raise ValueError(
"Only one of `citation_config` or `quotes_config` must be provided"
)
return values
return self
class PromptConfig(BaseModel):
@ -160,6 +157,4 @@ class PromptConfig(BaseModel):
include_citations=model.include_citations,
)
# needed so that this can be passed into lru_cache funcs
class Config:
frozen = True
model_config = ConfigDict(frozen=True)

View File

@ -62,13 +62,13 @@ def _convert_litellm_message_to_langchain_message(
litellm_message: litellm.Message,
) -> BaseMessage:
# Extracting the basic attributes from the litellm message
content = litellm_message.content
content = litellm_message.content or ""
role = litellm_message.role
# Handling function calls and tool calls if present
tool_calls = (
cast(
list[litellm.utils.ChatCompletionMessageToolCall],
list[litellm.ChatCompletionMessageToolCall],
litellm_message.tool_calls,
)
if hasattr(litellm_message, "tool_calls")
@ -87,7 +87,7 @@ def _convert_litellm_message_to_langchain_message(
"args": json.loads(tool_call.function.arguments),
"id": tool_call.id,
}
for tool_call in tool_calls
for tool_call in (tool_calls if tool_calls else [])
],
)
elif role == "system":
@ -296,9 +296,11 @@ class DefaultMultiLLM(LLM):
response = cast(
litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
)
return _convert_litellm_message_to_langchain_message(
response.choices[0].message
)
choice = response.choices[0]
if hasattr(choice, "message"):
return _convert_litellm_message_to_langchain_message(choice.message)
else:
raise ValueError("Unexpected response choice type")
def _stream_implementation(
self,
@ -314,7 +316,10 @@ class DefaultMultiLLM(LLM):
return
output = None
response = self._completion(prompt, tools, tool_choice, True)
response = cast(
litellm.CustomStreamWrapper,
self._completion(prompt, tools, tool_choice, True),
)
try:
for part in response:
if len(part["choices"]) == 0:

View File

@ -21,9 +21,12 @@ class LLMConfig(BaseModel):
model_provider: str
model_name: str
temperature: float
api_key: str | None
api_base: str | None
api_version: str | None
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
def log_prompt(prompt: LanguageModelInput) -> None:

View File

@ -11,6 +11,9 @@ class LLMOverride(BaseModel):
model_version: str | None = None
temperature: float | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
class PromptOverride(BaseModel):
system_prompt: str | None = None

View File

@ -110,7 +110,7 @@ class EmbeddingModel:
def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
def _make_request() -> EmbedResponse:
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
self.embed_server_endpoint, json=embed_request.model_dump()
)
try:
response.raise_for_status()
@ -255,7 +255,7 @@ class RerankingModel:
)
response = requests.post(
self.rerank_server_endpoint, json=rerank_request.dict()
self.rerank_server_endpoint, json=rerank_request.model_dump()
)
response.raise_for_status()
@ -288,7 +288,7 @@ class QueryAnalysisModel:
)
response = requests.post(
self.intent_server_endpoint, json=intent_request.dict()
self.intent_server_endpoint, json=intent_request.model_dump()
)
response.raise_for_status()

View File

@ -326,7 +326,7 @@ def stream_search_answer(
db_session=session,
)
for obj in objects:
yield get_json_line(obj.dict())
yield get_json_line(obj.model_dump())
def get_search_answer(

View File

@ -1,8 +1,6 @@
from typing import Any
from pydantic import BaseModel
from pydantic import Field
from pydantic import root_validator
from pydantic import model_validator
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerContexts
@ -21,7 +19,7 @@ class QueryRephrase(BaseModel):
class ThreadMessage(BaseModel):
message: str
sender: str | None
sender: str | None = None
role: MessageType = MessageType.USER
@ -45,21 +43,16 @@ class DirectQARequest(ChunkContext):
# If True, skips generative an AI response to the search query
skip_gen_ai_answer_generation: bool = False
@root_validator
def check_chain_of_thought_and_prompt_id(
cls, values: dict[str, Any]
) -> dict[str, Any]:
chain_of_thought = values.get("chain_of_thought")
prompt_id = values.get("prompt_id")
if chain_of_thought and prompt_id is not None:
@model_validator(mode="after")
def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest":
if self.chain_of_thought and self.prompt_id is not None:
raise ValueError(
"If chain_of_thought is True, prompt_id must be None"
"The chain of thought prompt is only for question "
"answering and does not accept customizing."
)
return values
return self
class OneShotQAResponse(BaseModel):

View File

@ -2,7 +2,9 @@ from datetime import datetime
from typing import Any
from pydantic import BaseModel
from pydantic import validator
from pydantic import ConfigDict
from pydantic import Field
from pydantic import field_validator
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource
@ -112,7 +114,8 @@ class ChunkContext(BaseModel):
chunks_below: int | None = None
full_doc: bool = False
@validator("chunks_above", "chunks_below", pre=True, each_item=False)
@field_validator("chunks_above", "chunks_below")
@classmethod
def check_non_negative(cls, value: int, field: Any) -> int:
if value is not None and value < 0:
raise ValueError(f"{field.name} must be non-negative")
@ -137,9 +140,7 @@ class SearchRequest(ChunkContext):
hybrid_alpha: float | None = None
rerank_settings: RerankingDetails | None = None
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(arbitrary_types_allowed=True)
class SearchQuery(ChunkContext):
@ -163,9 +164,7 @@ class SearchQuery(ChunkContext):
num_hits: int = NUM_RETURNED_HITS
offset: int = 0
class Config:
frozen = True
model_config = ConfigDict(frozen=True)
class RetrievalDetails(ChunkContext):
@ -209,7 +208,7 @@ class InferenceChunk(BaseChunk):
updated_at: datetime | None
primary_owners: list[str] | None = None
secondary_owners: list[str] | None = None
large_chunk_reference_ids: list[int] = []
large_chunk_reference_ids: list[int] = Field(default_factory=list)
@property
def unique_id(self) -> str:
@ -268,7 +267,7 @@ class InferenceChunkUncleaned(InferenceChunk):
# Assumes the cleaning has already been applied and just needs to translate to the right type
inference_chunk_data = {
k: v
for k, v in self.dict().items()
for k, v in self.model_dump().items()
if k
not in ["metadata_suffix"] # May be other fields to throw out in the future
}
@ -288,7 +287,7 @@ class SearchDoc(BaseModel):
document_id: str
chunk_ind: int
semantic_identifier: str
link: str | None
link: str | None = None
blurb: str
source_type: DocumentSource
boost: int
@ -297,7 +296,7 @@ class SearchDoc(BaseModel):
# be `True` when doing an admin search
hidden: bool
metadata: dict[str, str | list[str]]
score: float | None
score: float | None = None
is_relevant: bool | None = None
relevance_explanation: str | None = None
# Matched sections in the doc. Uses Vespa syntax e.g. <hi>TEXT</hi>
@ -305,13 +304,13 @@ class SearchDoc(BaseModel):
# ["<hi>the</hi> <hi>answer</hi> is 42", "the answer is <hi>42</hi>""]
match_highlights: list[str]
# when the doc was last updated
updated_at: datetime | None
primary_owners: list[str] | None
secondary_owners: list[str] | None
updated_at: datetime | None = None
primary_owners: list[str] | None = None
secondary_owners: list[str] | None = None
is_internet: bool = False
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().dict(*args, **kwargs) # type: ignore
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(*args, **kwargs) # type: ignore
initial_dict["updated_at"] = (
self.updated_at.isoformat() if self.updated_at else None
)
@ -329,7 +328,7 @@ class SavedSearchDoc(SearchDoc):
"""IMPORTANT: careful using this and not providing a db_doc_id If db_doc_id is not
provided, it won't be able to actually fetch the saved doc and info later on. So only skip
providing this if the SavedSearchDoc will not be used in the future"""
search_doc_data = search_doc.dict()
search_doc_data = search_doc.model_dump()
search_doc_data["score"] = search_doc_data.get("score") or 0.0
return cls(**search_doc_data, db_doc_id=db_doc_id)

View File

@ -74,7 +74,7 @@ def stream_query_answerability(
QueryValidationResponse(
reasoning="Query Answerability Evaluation feature is turned off",
answerable=True,
).dict()
).model_dump()
)
return
@ -85,7 +85,7 @@ def stream_query_answerability(
QueryValidationResponse(
reasoning="Generative AI is turned off - skipping check",
answerable=True,
).dict()
).model_dump()
)
return
messages = get_query_validation_messages(user_query)
@ -107,7 +107,7 @@ def stream_query_answerability(
remaining = model_output[reason_ind + len(THOUGHT_PAT.upper()) :]
if remaining:
yield get_json_line(
DanswerAnswerPiece(answer_piece=remaining).dict()
DanswerAnswerPiece(answer_piece=remaining).model_dump()
)
continue
@ -116,7 +116,7 @@ def stream_query_answerability(
if hold_answerable == ANSWERABLE_PAT.upper()[: len(hold_answerable)]:
continue
yield get_json_line(
DanswerAnswerPiece(answer_piece=hold_answerable).dict()
DanswerAnswerPiece(answer_piece=hold_answerable).model_dump()
)
hold_answerable = ""
@ -124,11 +124,13 @@ def stream_query_answerability(
answerable = extract_answerability_bool(model_output)
yield get_json_line(
QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict()
QueryValidationResponse(
reasoning=reasoning, answerable=answerable
).model_dump()
)
except Exception as e:
# exception is logged in the answer_question method, no need to re-log
error = StreamingError(error=str(e))
yield get_json_line(error.dict())
yield get_json_line(error.model_dump())
logger.exception("Failed to validate Query")
return

View File

@ -5,7 +5,7 @@ from danswer.connectors.models import DocumentBase
class IngestionDocument(BaseModel):
document: DocumentBase
cc_pair_id: int | None
cc_pair_id: int | None = None
class IngestionResult(BaseModel):
@ -16,4 +16,4 @@ class IngestionResult(BaseModel):
class DocMinimalInfo(BaseModel):
document_id: str
semantic_id: str
link: str | None
link: str | None = None

View File

@ -170,7 +170,7 @@ def associate_credential_to_connector(
connector_id=connector_id,
credential_id=credential_id,
cc_pair_name=metadata.name,
is_public=metadata.is_public,
is_public=metadata.is_public or True,
groups=metadata.groups,
)

View File

@ -3,6 +3,7 @@ from typing import Any
from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
from danswer.configs.constants import DocumentSource
@ -40,14 +41,15 @@ class ConnectorBase(BaseModel):
source: DocumentSource
input_type: InputType
connector_specific_config: dict[str, Any]
refresh_freq: int | None # In seconds, None for one time index with no refresh
prune_freq: int | None
indexing_start: datetime | None
# In seconds, None for one time index with no refresh
refresh_freq: int | None = None
prune_freq: int | None = None
indexing_start: datetime | None = None
class ConnectorUpdateRequest(ConnectorBase):
is_public: bool | None = None
groups: list[int] | None = None
groups: list[int] = Field(default_factory=list)
class ConnectorSnapshot(ConnectorBase):
@ -93,7 +95,7 @@ class CredentialBase(BaseModel):
source: DocumentSource
name: str | None = None
curator_public: bool = False
groups: list[int] = []
groups: list[int] = Field(default_factory=list)
class CredentialSnapshot(CredentialBase):
@ -254,21 +256,21 @@ class ConnectorCredentialPairIdentifier(BaseModel):
class ConnectorCredentialPairMetadata(BaseModel):
name: str | None
is_public: bool
groups: list[int] | None
name: str | None = None
is_public: bool | None = None
groups: list[int] = Field(default_factory=list)
class ConnectorCredentialPairDescriptor(BaseModel):
id: int
name: str | None
name: str | None = None
connector: ConnectorSnapshot
credential: CredentialSnapshot
class RunConnectorRequest(BaseModel):
connector_id: int
credential_ids: list[int] | None
credential_ids: list[int] | None = None
from_beginning: bool = False

View File

@ -1,6 +1,7 @@
from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.server.documents.models import ConnectorCredentialPairDescriptor
@ -14,8 +15,8 @@ class DocumentSetCreationRequest(BaseModel):
cc_pair_ids: list[int]
is_public: bool
# For Private Document Sets, who should be able to access these
users: list[UUID] | None = None
groups: list[int] | None = None
users: list[UUID] = Field(default_factory=list)
groups: list[int] = Field(default_factory=list)
class DocumentSetUpdateRequest(BaseModel):

View File

@ -19,7 +19,7 @@ class FolderCreationRequest(BaseModel):
class FolderUpdateRequest(BaseModel):
folder_name: str | None
folder_name: str | None = None
class FolderChatSessionRequest(BaseModel):

View File

@ -1,6 +1,7 @@
from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from danswer.db.models import Persona
from danswer.db.models import StarterMessage
@ -31,8 +32,8 @@ class CreatePersonaRequest(BaseModel):
llm_model_version_override: str | None = None
starter_messages: list[StarterMessage] | None = None
# For Private Personas, who should be able to access these
users: list[UUID] | None = None
groups: list[int] | None = None
users: list[UUID] = Field(default_factory=list)
groups: list[int] = Field(default_factory=list)
icon_color: str | None = None
icon_shape: int | None = None
uploaded_image_id: str | None = None # New field for uploaded image

View File

@ -26,14 +26,14 @@ admin_router = APIRouter(prefix="/admin/tool")
class CustomToolCreate(BaseModel):
name: str
description: str | None
description: str | None = None
definition: dict[str, Any]
class CustomToolUpdate(BaseModel):
name: str | None
description: str | None
definition: dict[str, Any] | None
name: str | None = None
description: str | None = None
definition: dict[str, Any] | None = None
def _validate_tool_definition(definition: dict[str, Any]) -> None:

View File

@ -1,6 +1,7 @@
from typing import TYPE_CHECKING
from pydantic import BaseModel
from pydantic import Field
from danswer.llm.llm_provider_options import fetch_models_for_provider
@ -56,26 +57,26 @@ class LLMProviderDescriptor(BaseModel):
class LLMProvider(BaseModel):
name: str
provider: str
api_key: str | None
api_base: str | None
api_version: str | None
custom_config: dict[str, str] | None
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
default_model_name: str
fast_default_model_name: str | None
fast_default_model_name: str | None = None
is_public: bool = True
groups: list[int] | None = None
display_model_names: list[str] | None
groups: list[int] = Field(default_factory=list)
display_model_names: list[str] | None = None
class LLMProviderUpsertRequest(LLMProvider):
# should only be used for a "custom" provider
# for default providers, the built-in model names are used
model_names: list[str] | None
model_names: list[str] | None = None
class FullLLMProvider(LLMProvider):
id: int
is_default_provider: bool | None
is_default_provider: bool | None = None
model_names: list[str]
@classmethod

View File

@ -1,10 +1,11 @@
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
from pydantic import BaseModel
from pydantic import root_validator
from pydantic import validator
from pydantic import ConfigDict
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from danswer.auth.schemas import UserRole
from danswer.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
@ -39,8 +40,8 @@ class AuthTypeResponse(BaseModel):
class UserPreferences(BaseModel):
chosen_assistants: list[int] | None
default_model: str | None
chosen_assistants: list[int] | None = None
default_model: str | None = None
class UserInfo(BaseModel):
@ -158,7 +159,8 @@ class StandardAnswerCreationRequest(BaseModel):
answer: str
categories: list[int]
@validator("categories", pre=True)
@field_validator("categories", mode="before")
@classmethod
def validate_categories(cls, value: list[int]) -> list[int]:
if len(value) < 1:
raise ValueError(
@ -170,9 +172,7 @@ class StandardAnswerCreationRequest(BaseModel):
class SlackBotTokens(BaseModel):
bot_token: str
app_token: str
class Config:
frozen = True
model_config = ConfigDict(frozen=True)
class SlackBotConfigCreationRequest(BaseModel):
@ -180,23 +180,24 @@ class SlackBotConfigCreationRequest(BaseModel):
# in the future, `document_sets` will probably be replaced
# by an optional `PersonaSnapshot` object. Keeping it like this
# for now for simplicity / speed of development
document_sets: list[int] | None
document_sets: list[int] | None = None
persona_id: (
int | None
) # NOTE: only one of `document_sets` / `persona_id` should be set
) = None # NOTE: only one of `document_sets` / `persona_id` should be set
channel_names: list[str]
respond_tag_only: bool = False
respond_to_bots: bool = False
enable_auto_filters: bool = False
# If no team members, assume respond in the channel to everyone
respond_member_group_list: list[str] = []
answer_filters: list[AllowedAnswerFilters] = []
respond_member_group_list: list[str] = Field(default_factory=list)
answer_filters: list[AllowedAnswerFilters] = Field(default_factory=list)
# list of user emails
follow_up_tags: list[str] | None = None
response_type: SlackBotResponseType
standard_answer_categories: list[int] = []
standard_answer_categories: list[int] = Field(default_factory=list)
@validator("answer_filters", pre=True)
@field_validator("answer_filters", mode="before")
@classmethod
def validate_filters(cls, value: list[str]) -> list[str]:
if any(test not in VALID_SLACK_FILTERS for test in value):
raise ValueError(
@ -204,14 +205,12 @@ class SlackBotConfigCreationRequest(BaseModel):
)
return value
@root_validator
def validate_document_sets_and_persona_id(
cls, values: dict[str, Any]
) -> dict[str, Any]:
if values.get("document_sets") and values.get("persona_id"):
@model_validator(mode="after")
def validate_document_sets_and_persona_id(self) -> "SlackBotConfigCreationRequest":
if self.document_sets and self.persona_id:
raise ValueError("Only one of `document_sets` / `persona_id` should be set")
return values
return self
class SlackBotConfig(BaseModel):

View File

@ -323,7 +323,7 @@ def verify_user_logged_in(
class ChosenDefaultModelRequest(BaseModel):
default_model: str | None
default_model: str | None = None
@router.patch("/user/default-model")

View File

@ -4,7 +4,6 @@ from typing import TypeVar
from uuid import UUID
from pydantic import BaseModel
from pydantic.generics import GenericModel
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserStatus
@ -13,7 +12,7 @@ from danswer.auth.schemas import UserStatus
DataT = TypeVar("DataT")
class StatusResponse(GenericModel, Generic[DataT]):
class StatusResponse(BaseModel, Generic[DataT]):
success: bool
message: Optional[str] = None
data: Optional[DataT] = None

View File

@ -2,7 +2,7 @@ from datetime import datetime
from typing import Any
from pydantic import BaseModel
from pydantic import root_validator
from pydantic import model_validator
from danswer.chat.models import RetrievalDocs
from danswer.configs.constants import DocumentSource
@ -54,16 +54,11 @@ class ChatFeedbackRequest(BaseModel):
feedback_text: str | None = None
predefined_feedback: str | None = None
@root_validator
def check_is_positive_or_feedback_text(cls: BaseModel, values: dict) -> dict:
is_positive, feedback_text = values.get("is_positive"), values.get(
"feedback_text"
)
if is_positive is None and feedback_text is None:
@model_validator(mode="after")
def check_is_positive_or_feedback_text(self) -> "ChatFeedbackRequest":
if self.is_positive is None and self.feedback_text is None:
raise ValueError("Empty feedback received.")
return values
return self
"""
@ -112,18 +107,13 @@ class CreateChatMessageRequest(ChunkContext):
# used for seeded chats to kick off the generation of an AI answer
use_existing_user_message: bool = False
@root_validator
def check_search_doc_ids_or_retrieval_options(cls: BaseModel, values: dict) -> dict:
search_doc_ids, retrieval_options = values.get("search_doc_ids"), values.get(
"retrieval_options"
)
if search_doc_ids is None and retrieval_options is None:
@model_validator(mode="after")
def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
if self.search_doc_ids is None and self.retrieval_options is None:
raise ValueError(
"Either search_doc_ids or retrieval_options must be provided, but not both or neither."
)
return values
return self
class ChatMessageIdentifier(BaseModel):
@ -149,7 +139,7 @@ class ChatSessionDetails(BaseModel):
persona_id: int
time_created: str
shared_status: ChatSessionSharedStatus
folder_id: int | None
folder_id: int | None = None
current_alternate_model: str | None = None
@ -162,36 +152,36 @@ class SearchFeedbackRequest(BaseModel):
document_id: str
document_rank: int
click: bool
search_feedback: SearchFeedbackType | None
search_feedback: SearchFeedbackType | None = None
@root_validator
def check_click_or_search_feedback(cls: BaseModel, values: dict) -> dict:
click, feedback = values.get("click"), values.get("search_feedback")
@model_validator(mode="after")
def check_click_or_search_feedback(self) -> "SearchFeedbackRequest":
click, feedback = self.click, self.search_feedback
if click is False and feedback is None:
raise ValueError("Empty feedback received.")
return values
return self
class ChatMessageDetail(BaseModel):
message_id: int
parent_message: int | None
latest_child_message: int | None
parent_message: int | None = None
latest_child_message: int | None = None
message: str
rephrased_query: str | None
context_docs: RetrievalDocs | None
rephrased_query: str | None = None
context_docs: RetrievalDocs | None = None
message_type: MessageType
time_sent: datetime
alternate_assistant_id: str | None
overridden_model: str | None
alternate_assistant_id: int | None = None
# Dict mapping citation number to db_doc_id
chat_session_id: int | None = None
citations: dict[int, int] | None
citations: dict[int, int] | None = None
files: list[FileDescriptor]
tool_calls: list[ToolCallFinalResult]
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().dict(*args, **kwargs) # type: ignore
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
initial_dict["time_sent"] = self.time_sent.isoformat()
return initial_dict
@ -226,7 +216,3 @@ class AdminSearchRequest(BaseModel):
class AdminSearchResponse(BaseModel):
documents: list[SearchDoc]
class DanswerAnswer(BaseModel):
answer: str | None

View File

@ -64,7 +64,7 @@ def fetch_settings(
needs_reindexing = False
return UserSettings(
**general_settings.dict(),
**general_settings.model_dump(),
notifications=user_notifications,
needs_reindexing=needs_reindexing
)

View File

@ -12,10 +12,10 @@ def load_settings() -> Settings:
settings = Settings(**cast(dict, dynamic_config_store.load(KV_SETTINGS_KEY)))
except ConfigNotFoundError:
settings = Settings()
dynamic_config_store.store(KV_SETTINGS_KEY, settings.dict())
dynamic_config_store.store(KV_SETTINGS_KEY, settings.model_dump())
return settings
def store_settings(settings: Settings) -> None:
get_dynamic_config_store().store(KV_SETTINGS_KEY, settings.dict())
get_dynamic_config_store().store(KV_SETTINGS_KEY, settings.model_dump())

View File

@ -1,6 +1,6 @@
import os
from typing import Type
from typing import TypedDict
from typing_extensions import TypedDict # noreorder
from sqlalchemy import not_
from sqlalchemy import or_

View File

@ -254,6 +254,6 @@ class ImageGenerationTool(Tool):
list[ImageGenerationResponse], args[0].response
)
return [
image_generation_response.dict()
image_generation_response.model_dump()
for image_generation_response in image_generation_responses
]

View File

@ -187,7 +187,7 @@ class InternetSearchTool(Tool):
self, *args: ToolResponse
) -> str | list[str | dict[str, Any]]:
search_response = cast(InternetSearchResponse, args[0].response)
return json.dumps(search_response.dict())
return json.dumps(search_response.model_dump())
def _perform_search(self, query: str) -> InternetSearchResponse:
response = self.client.get(
@ -230,4 +230,4 @@ class InternetSearchTool(Tool):
def final_result(self, *args: ToolResponse) -> JSON_ro:
search_response = cast(InternetSearchResponse, args[0].response)
return search_response.dict()
return search_response.model_dump()

View File

@ -4,10 +4,12 @@ from typing import Any
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.tool import ToolCall
from langchain_core.messages.tool import ToolMessage
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModel__v1
from danswer.natural_language_processing.utils import BaseTokenizer
# Langchain has their own version of pydantic which is version 1
def build_tool_message(
tool_call: ToolCall, tool_content: str | list[str | dict[str, Any]]
@ -19,7 +21,7 @@ def build_tool_message(
)
class ToolCallSummary(BaseModel):
class ToolCallSummary(BaseModel__v1):
tool_call_request: AIMessage
tool_call_result: ToolMessage

View File

@ -1,12 +1,12 @@
from typing import Any
from pydantic import BaseModel
from pydantic import root_validator
from pydantic import model_validator
class ToolResponse(BaseModel):
id: str | None = None
response: Any
response: Any = None
class ToolCallKickoff(BaseModel):
@ -19,12 +19,10 @@ class ToolRunnerResponse(BaseModel):
tool_response: ToolResponse | None = None
tool_message_content: str | list[str | dict[str, Any]] | None = None
@root_validator
def validate_tool_runner_response(
cls, values: dict[str, ToolResponse | str]
) -> dict[str, ToolResponse | str]:
@model_validator(mode="after")
def validate_tool_runner_response(self) -> "ToolRunnerResponse":
fields = ["tool_response", "tool_message_content", "tool_run_kickoff"]
provided = sum(1 for field in fields if values.get(field) is not None)
provided = sum(1 for field in fields if getattr(self, field) is not None)
if provided != 1:
raise ValueError(
@ -32,8 +30,10 @@ class ToolRunnerResponse(BaseModel):
"or 'tool_run_kickoff' must be provided"
)
return values
return self
class ToolCallFinalResult(ToolCallKickoff):
tool_result: Any # we would like to use JSON_ro, but can't due to its recursive nature
tool_result: Any = (
None # we would like to use JSON_ro, but can't due to its recursive nature
)

View File

@ -30,13 +30,13 @@ def load_settings() -> EnterpriseSettings:
)
except ConfigNotFoundError:
settings = EnterpriseSettings()
dynamic_config_store.store(KV_ENTERPRISE_SETTINGS_KEY, settings.dict())
dynamic_config_store.store(KV_ENTERPRISE_SETTINGS_KEY, settings.model_dump())
return settings
def store_settings(settings: EnterpriseSettings) -> None:
get_dynamic_config_store().store(KV_ENTERPRISE_SETTINGS_KEY, settings.dict())
get_dynamic_config_store().store(KV_ENTERPRISE_SETTINGS_KEY, settings.model_dump())
_CUSTOM_ANALYTICS_SECRET_KEY = os.environ.get("CUSTOM_ANALYTICS_SECRET_KEY")

View File

@ -17,7 +17,7 @@ class StandardAnswerRequest(BaseModel):
class StandardAnswerResponse(BaseModel):
standard_answers: list[StandardAnswer] = []
standard_answers: list[StandardAnswer] = Field(default_factory=list)
class DocumentSearchRequest(ChunkContext):

View File

@ -376,7 +376,7 @@ def get_query_history_as_csv(
# Create an in-memory text stream
stream = io.StringIO()
writer = csv.DictWriter(
stream, fieldnames=list(QuestionAnswerPairSnapshot.__fields__.keys())
stream, fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys())
)
writer.writeheader()
for row in question_answer_pairs:

View File

@ -29,20 +29,20 @@ langchain==0.1.17
langchain-community==0.0.36
langchain-core==0.1.50
langchain-text-splitters==0.0.1
litellm==1.37.7
litellm==1.43.18
llama-index==0.9.45
Mako==1.2.4
msal==1.26.0
nltk==3.8.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==1.14.3
openai==1.41.1
openpyxl==3.1.2
playwright==1.41.2
psutil==5.9.5
psycopg2-binary==2.9.9
pycryptodome==3.19.1
pydantic==1.10.13
pydantic==2.8.2
PyGithub==1.58.2
python-dateutil==2.8.2
python-gitlab==3.9.0
@ -64,7 +64,7 @@ slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.15
starlette==0.36.3
supervisor==4.2.5
tiktoken==0.4.0
tiktoken==0.7.0
timeago==1.0.16
transformers==4.39.2
uvicorn==0.21.1

View File

@ -3,8 +3,8 @@ einops==0.8.0
fastapi==0.109.2
google-cloud-aiplatform==1.58.0
numpy==1.26.4
openai==1.14.3
pydantic==1.10.13
openai==1.41.1
pydantic==2.8.2
retry==0.9.2
safetensors==0.4.2
sentence-transformers==2.6.1

View File

@ -10,14 +10,17 @@ Embedding = list[float]
class EmbedRequest(BaseModel):
texts: list[str]
# Can be none for cloud embedding model requests, error handling logic exists for other cases
model_name: str | None
model_name: str | None = None
max_context_length: int
normalize_embeddings: bool
api_key: str | None
provider_type: EmbeddingProvider | None
api_key: str | None = None
provider_type: EmbeddingProvider | None = None
text_type: EmbedTextType
manual_query_prefix: str | None
manual_passage_prefix: str | None
manual_query_prefix: str | None = None
manual_passage_prefix: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
class EmbedResponse(BaseModel):
@ -28,8 +31,11 @@ class RerankRequest(BaseModel):
query: str
documents: list[str]
model_name: str
provider_type: RerankerProvider | None
api_key: str | None
provider_type: RerankerProvider | None = None
api_key: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
class RerankResponse(BaseModel):

View File

@ -14,7 +14,7 @@ class DocumentSetClient:
) -> int:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/document-set",
json=doc_set_creation_request.dict(),
json=doc_set_creation_request.model_dump(),
)
response.raise_for_status()
return cast(int, response.json())

View File

@ -31,7 +31,7 @@ class LLMProvider(BaseModel):
custom_config=None,
fast_default_model_name=None,
is_public=True,
groups=None,
groups=[],
display_model_names=None,
model_names=None,
)

View File

@ -12,7 +12,7 @@ class UserGroupClient:
def create_user_group(user_group_creation_request: UserGroupCreate) -> int:
response = requests.post(
f"{API_SERVER_URL}/manage/admin/user-group",
json=user_group_creation_request.dict(),
json=user_group_creation_request.model_dump(),
)
response.raise_for_status()
return cast(int, response.json()["id"])

View File

@ -57,7 +57,7 @@ def get_answer_from_query(
"Content-Type": "application/json",
}
body = new_message_request.dict()
body = new_message_request.model_dump()
body["user"] = None
try:
response_json = requests.post(url, headers=headers, json=body).json()
@ -122,7 +122,7 @@ def create_cc_pair(env_name: str, connector_id: int, credential_id: int) -> None
env_name, f"/manage/connector/{connector_id}/credential/{credential_id}"
)
body = {"name": "zip_folder_contents", "is_public": True}
body = {"name": "zip_folder_contents", "is_public": True, "groups": []}
print("body:", body)
response = requests.put(url, headers=GENERAL_HEADERS, json=body)
if response.status_code == 200:
@ -167,7 +167,7 @@ def create_connector(env_name: str, file_paths: list[str]) -> int:
indexing_start=None,
)
body = connector.dict()
body = connector.model_dump()
response = requests.post(url, headers=GENERAL_HEADERS, json=body)
if response.status_code == 200:
return response.json()["id"]