diff --git a/backend/danswer/auth/noauth_user.py b/backend/danswer/auth/noauth_user.py index 11b04bf40..9520ef41c 100644 --- a/backend/danswer/auth/noauth_user.py +++ b/backend/danswer/auth/noauth_user.py @@ -13,7 +13,7 @@ from danswer.server.manage.models import UserPreferences def set_no_auth_user_preferences( store: DynamicConfigStore, preferences: UserPreferences ) -> None: - store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.dict()) + store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump()) def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences: diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index 1828d5250..6d12d68df 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -35,11 +35,12 @@ class QADocsResponse(RetrievalDocs): applied_time_cutoff: datetime | None recency_bias_multiplier: float - def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore - initial_dict = super().dict(*args, **kwargs) # type: ignore + def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore + initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore initial_dict["applied_time_cutoff"] = ( self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None ) + return initial_dict diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 3fb94714c..2eea2cfc2 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -813,4 +813,4 @@ def stream_chat_message( is_connected=is_connected, ) for obj in objects: - yield get_json_line(obj.dict()) + yield get_json_line(obj.model_dump()) diff --git a/backend/danswer/chat/tools.py b/backend/danswer/chat/tools.py index 717cead63..11b405929 100644 --- a/backend/danswer/chat/tools.py +++ b/backend/danswer/chat/tools.py @@ -1,4 +1,4 @@ -from typing import TypedDict +from typing_extensions import TypedDict # noreorder from pydantic import BaseModel diff --git a/backend/danswer/connectors/gmail/connector_auth.py b/backend/danswer/connectors/gmail/connector_auth.py index e518d5a50..ad80d1e1e 100644 --- a/backend/danswer/connectors/gmail/connector_auth.py +++ b/backend/danswer/connectors/gmail/connector_auth.py @@ -125,7 +125,7 @@ def update_gmail_credential_access_tokens( ) -> OAuthCredentials | None: app_credentials = get_google_app_gmail_cred() flow = InstalledAppFlow.from_client_config( - app_credentials.dict(), + app_credentials.model_dump(), scopes=SCOPES, redirect_uri=_build_frontend_gmail_redirect(), ) diff --git a/backend/danswer/connectors/google_drive/connector_auth.py b/backend/danswer/connectors/google_drive/connector_auth.py index ae08eca7f..0f47727e6 100644 --- a/backend/danswer/connectors/google_drive/connector_auth.py +++ b/backend/danswer/connectors/google_drive/connector_auth.py @@ -106,7 +106,7 @@ def update_credential_access_tokens( ) -> OAuthCredentials | None: app_credentials = get_google_app_cred() flow = InstalledAppFlow.from_client_config( - app_credentials.dict(), + app_credentials.model_dump(), scopes=SCOPES, redirect_uri=_build_frontend_google_drive_redirect(), ) diff --git a/backend/danswer/connectors/zulip/schemas.py b/backend/danswer/connectors/zulip/schemas.py index 76a1cb5cc..385272cb4 100644 --- a/backend/danswer/connectors/zulip/schemas.py +++ b/backend/danswer/connectors/zulip/schemas.py @@ -3,6 +3,7 @@ from typing import List from typing import Optional from pydantic import BaseModel +from pydantic import Field class Message(BaseModel): @@ -18,11 +19,11 @@ class Message(BaseModel): sender_realm_str: str subject: str topic_links: Optional[List[Any]] = None - last_edit_timestamp: Optional[int] = None - edit_history: Any + last_edit_timestamp: Optional[int] + edit_history: Any = None reactions: List[Any] submessages: List[Any] - flags: List[str] = [] + flags: List[str] = Field(default_factory=list) display_recipient: Optional[str] = None type: Optional[str] = None stream_id: int @@ -39,4 +40,4 @@ class GetMessagesResponse(BaseModel): found_newest: Optional[bool] = None history_limited: Optional[bool] = None anchor: Optional[str] = None - messages: List[Message] = [] + messages: List[Message] = Field(default_factory=list) diff --git a/backend/danswer/db/document_set.py b/backend/danswer/db/document_set.py index 29227fbcb..2de61a491 100644 --- a/backend/danswer/db/document_set.py +++ b/backend/danswer/db/document_set.py @@ -26,8 +26,11 @@ from danswer.db.models import User__UserGroup from danswer.db.models import UserRole from danswer.server.features.document_set.models import DocumentSetCreationRequest from danswer.server.features.document_set.models import DocumentSetUpdateRequest +from danswer.utils.logger import setup_logger from danswer.utils.variable_functionality import fetch_versioned_implementation +logger = setup_logger() + def _add_user_filters( stmt: Select, user: User | None, get_editable: bool = True @@ -233,9 +236,9 @@ def insert_document_set( ) db_session.commit() - except: + except Exception as e: db_session.rollback() - raise + logger.error(f"Error creating document set: {e}") return new_document_set_row, ds_cc_pairs diff --git a/backend/danswer/db/llm.py b/backend/danswer/db/llm.py index 821006cfa..152cb1305 100644 --- a/backend/danswer/db/llm.py +++ b/backend/danswer/db/llm.py @@ -46,10 +46,10 @@ def upsert_cloud_embedding_provider( .first() ) if existing_provider: - for key, value in provider.dict().items(): + for key, value in provider.model_dump().items(): setattr(existing_provider, key, value) else: - new_provider = CloudEmbeddingProviderModel(**provider.dict()) + new_provider = CloudEmbeddingProviderModel(**provider.model_dump()) db_session.add(new_provider) existing_provider = new_provider db_session.commit() diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 82c2247c6..3cdec3239 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -5,7 +5,7 @@ from typing import Any from typing import Literal from typing import NotRequired from typing import Optional -from typing import TypedDict +from typing_extensions import TypedDict # noreorder from uuid import UUID from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID diff --git a/backend/danswer/file_store/models.py b/backend/danswer/file_store/models.py index f26fa4ca5..d944a2fd2 100644 --- a/backend/danswer/file_store/models.py +++ b/backend/danswer/file_store/models.py @@ -1,7 +1,7 @@ import base64 from enum import Enum from typing import NotRequired -from typing import TypedDict +from typing_extensions import TypedDict # noreorder from pydantic import BaseModel diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py index bde7cce03..f7d8f4e74 100644 --- a/backend/danswer/indexing/embedder.py +++ b/backend/danswer/indexing/embedder.py @@ -156,7 +156,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder): title_embed_dict[title] = title_embedding new_embedded_chunk = IndexChunk( - **chunk.dict(), + **chunk.model_dump(), embeddings=ChunkEmbedding( full_embedding=chunk_embeddings[0], mini_chunk_embeddings=chunk_embeddings[1:], diff --git a/backend/danswer/indexing/indexing_pipeline.py b/backend/danswer/indexing/indexing_pipeline.py index 335def573..3517b5576 100644 --- a/backend/danswer/indexing/indexing_pipeline.py +++ b/backend/danswer/indexing/indexing_pipeline.py @@ -3,6 +3,7 @@ from functools import partial from typing import Protocol from pydantic import BaseModel +from pydantic import ConfigDict from sqlalchemy.orm import Session from danswer.access.access import get_access_for_documents @@ -40,9 +41,7 @@ logger = setup_logger() class DocumentBatchPrepareContext(BaseModel): updatable_docs: list[Document] id_to_db_doc_map: dict[str, DBDocument] - - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) class IndexingPipelineProtocol(Protocol): diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py index 83ab42c35..b23de0eb4 100644 --- a/backend/danswer/indexing/models.py +++ b/backend/danswer/indexing/models.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING from pydantic import BaseModel +from pydantic import Field from danswer.access.models import DocumentAccess from danswer.connectors.models import Document @@ -24,9 +25,8 @@ class BaseChunk(BaseModel): chunk_id: int blurb: str # The first sentence(s) of the first Section of the chunk content: str - source_links: dict[ - int, str - ] | None # Holds the link and the offsets into the raw Chunk text + # Holds the link and the offsets into the raw Chunk text + source_links: dict[int, str] | None section_continuation: bool # True if this Chunk's start is not at the start of a Section @@ -47,7 +47,7 @@ class DocAwareChunk(BaseChunk): mini_chunk_texts: list[str] | None - large_chunk_reference_ids: list[int] = [] + large_chunk_reference_ids: list[int] = Field(default_factory=list) def to_short_descriptor(self) -> str: """Used when logging the identity of a chunk""" @@ -85,7 +85,7 @@ class DocMetadataAwareIndexChunk(IndexChunk): document_sets: set[str], boost: int, ) -> "DocMetadataAwareIndexChunk": - index_chunk_data = index_chunk.dict() + index_chunk_data = index_chunk.model_dump() return cls( **index_chunk_data, access=access, @@ -102,6 +102,9 @@ class EmbeddingModelDetail(BaseModel): provider_type: EmbeddingProvider | None = None api_key: str | None = None + # This disables the "model_" protected namespace for pydantic + model_config = {"protected_namespaces": ()} + @classmethod def from_db_model( cls, @@ -123,6 +126,9 @@ class IndexingSetting(EmbeddingModelDetail): index_name: str | None multipass_indexing: bool + # This disables the "model_" protected namespace for pydantic + model_config = {"protected_namespaces": ()} + @classmethod def from_db_model(cls, search_settings: "SearchSettings") -> "IndexingSetting": return cls( diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py index 3d05a08c4..fb5fa9c31 100644 --- a/backend/danswer/llm/answering/models.py +++ b/backend/danswer/llm/answering/models.py @@ -1,6 +1,5 @@ from collections.abc import Callable from collections.abc import Iterator -from typing import Any from typing import TYPE_CHECKING from langchain.schema.messages import AIMessage @@ -8,8 +7,9 @@ from langchain.schema.messages import BaseMessage from langchain.schema.messages import HumanMessage from langchain.schema.messages import SystemMessage from pydantic import BaseModel +from pydantic import ConfigDict from pydantic import Field -from pydantic import root_validator +from pydantic import model_validator from danswer.chat.models import AnswerQuestionStreamReturn from danswer.configs.constants import MessageType @@ -117,22 +117,19 @@ class AnswerStyleConfig(BaseModel): default_factory=DocumentPruningConfig ) - @root_validator - def check_quotes_and_citation(cls, values: dict[str, Any]) -> dict[str, Any]: - citation_config = values.get("citation_config") - quotes_config = values.get("quotes_config") - - if citation_config is None and quotes_config is None: + @model_validator(mode="after") + def check_quotes_and_citation(self) -> "AnswerStyleConfig": + if self.citation_config is None and self.quotes_config is None: raise ValueError( "One of `citation_config` or `quotes_config` must be provided" ) - if citation_config is not None and quotes_config is not None: + if self.citation_config is not None and self.quotes_config is not None: raise ValueError( "Only one of `citation_config` or `quotes_config` must be provided" ) - return values + return self class PromptConfig(BaseModel): @@ -160,6 +157,4 @@ class PromptConfig(BaseModel): include_citations=model.include_citations, ) - # needed so that this can be passed into lru_cache funcs - class Config: - frozen = True + model_config = ConfigDict(frozen=True) diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py index 3a489a6c1..b0e7d8034 100644 --- a/backend/danswer/llm/chat_llm.py +++ b/backend/danswer/llm/chat_llm.py @@ -62,13 +62,13 @@ def _convert_litellm_message_to_langchain_message( litellm_message: litellm.Message, ) -> BaseMessage: # Extracting the basic attributes from the litellm message - content = litellm_message.content + content = litellm_message.content or "" role = litellm_message.role # Handling function calls and tool calls if present tool_calls = ( cast( - list[litellm.utils.ChatCompletionMessageToolCall], + list[litellm.ChatCompletionMessageToolCall], litellm_message.tool_calls, ) if hasattr(litellm_message, "tool_calls") @@ -87,7 +87,7 @@ def _convert_litellm_message_to_langchain_message( "args": json.loads(tool_call.function.arguments), "id": tool_call.id, } - for tool_call in tool_calls + for tool_call in (tool_calls if tool_calls else []) ], ) elif role == "system": @@ -296,9 +296,11 @@ class DefaultMultiLLM(LLM): response = cast( litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False) ) - return _convert_litellm_message_to_langchain_message( - response.choices[0].message - ) + choice = response.choices[0] + if hasattr(choice, "message"): + return _convert_litellm_message_to_langchain_message(choice.message) + else: + raise ValueError("Unexpected response choice type") def _stream_implementation( self, @@ -314,7 +316,10 @@ class DefaultMultiLLM(LLM): return output = None - response = self._completion(prompt, tools, tool_choice, True) + response = cast( + litellm.CustomStreamWrapper, + self._completion(prompt, tools, tool_choice, True), + ) try: for part in response: if len(part["choices"]) == 0: diff --git a/backend/danswer/llm/interfaces.py b/backend/danswer/llm/interfaces.py index 63bd45ba7..5e39792c3 100644 --- a/backend/danswer/llm/interfaces.py +++ b/backend/danswer/llm/interfaces.py @@ -21,9 +21,12 @@ class LLMConfig(BaseModel): model_provider: str model_name: str temperature: float - api_key: str | None - api_base: str | None - api_version: str | None + api_key: str | None = None + api_base: str | None = None + api_version: str | None = None + + # This disables the "model_" protected namespace for pydantic + model_config = {"protected_namespaces": ()} def log_prompt(prompt: LanguageModelInput) -> None: diff --git a/backend/danswer/llm/override_models.py b/backend/danswer/llm/override_models.py index 1ecb3192f..08e425891 100644 --- a/backend/danswer/llm/override_models.py +++ b/backend/danswer/llm/override_models.py @@ -11,6 +11,9 @@ class LLMOverride(BaseModel): model_version: str | None = None temperature: float | None = None + # This disables the "model_" protected namespace for pydantic + model_config = {"protected_namespaces": ()} + class PromptOverride(BaseModel): system_prompt: str | None = None diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py index 6d807fb72..b7835c4e9 100644 --- a/backend/danswer/natural_language_processing/search_nlp_models.py +++ b/backend/danswer/natural_language_processing/search_nlp_models.py @@ -110,7 +110,7 @@ class EmbeddingModel: def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse: def _make_request() -> EmbedResponse: response = requests.post( - self.embed_server_endpoint, json=embed_request.dict() + self.embed_server_endpoint, json=embed_request.model_dump() ) try: response.raise_for_status() @@ -255,7 +255,7 @@ class RerankingModel: ) response = requests.post( - self.rerank_server_endpoint, json=rerank_request.dict() + self.rerank_server_endpoint, json=rerank_request.model_dump() ) response.raise_for_status() @@ -288,7 +288,7 @@ class QueryAnalysisModel: ) response = requests.post( - self.intent_server_endpoint, json=intent_request.dict() + self.intent_server_endpoint, json=intent_request.model_dump() ) response.raise_for_status() diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 823057880..a5a0fe0da 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -326,7 +326,7 @@ def stream_search_answer( db_session=session, ) for obj in objects: - yield get_json_line(obj.dict()) + yield get_json_line(obj.model_dump()) def get_search_answer( diff --git a/backend/danswer/one_shot_answer/models.py b/backend/danswer/one_shot_answer/models.py index c36a6a3ca..d7e819756 100644 --- a/backend/danswer/one_shot_answer/models.py +++ b/backend/danswer/one_shot_answer/models.py @@ -1,8 +1,6 @@ -from typing import Any - from pydantic import BaseModel from pydantic import Field -from pydantic import root_validator +from pydantic import model_validator from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerContexts @@ -21,7 +19,7 @@ class QueryRephrase(BaseModel): class ThreadMessage(BaseModel): message: str - sender: str | None + sender: str | None = None role: MessageType = MessageType.USER @@ -45,21 +43,16 @@ class DirectQARequest(ChunkContext): # If True, skips generative an AI response to the search query skip_gen_ai_answer_generation: bool = False - @root_validator - def check_chain_of_thought_and_prompt_id( - cls, values: dict[str, Any] - ) -> dict[str, Any]: - chain_of_thought = values.get("chain_of_thought") - prompt_id = values.get("prompt_id") - - if chain_of_thought and prompt_id is not None: + @model_validator(mode="after") + def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest": + if self.chain_of_thought and self.prompt_id is not None: raise ValueError( "If chain_of_thought is True, prompt_id must be None" "The chain of thought prompt is only for question " "answering and does not accept customizing." ) - return values + return self class OneShotQAResponse(BaseModel): diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index 1af43ed82..576d1503b 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -2,7 +2,9 @@ from datetime import datetime from typing import Any from pydantic import BaseModel -from pydantic import validator +from pydantic import ConfigDict +from pydantic import Field +from pydantic import field_validator from danswer.configs.chat_configs import NUM_RETURNED_HITS from danswer.configs.constants import DocumentSource @@ -112,7 +114,8 @@ class ChunkContext(BaseModel): chunks_below: int | None = None full_doc: bool = False - @validator("chunks_above", "chunks_below", pre=True, each_item=False) + @field_validator("chunks_above", "chunks_below") + @classmethod def check_non_negative(cls, value: int, field: Any) -> int: if value is not None and value < 0: raise ValueError(f"{field.name} must be non-negative") @@ -137,9 +140,7 @@ class SearchRequest(ChunkContext): hybrid_alpha: float | None = None rerank_settings: RerankingDetails | None = None evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED - - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) class SearchQuery(ChunkContext): @@ -163,9 +164,7 @@ class SearchQuery(ChunkContext): num_hits: int = NUM_RETURNED_HITS offset: int = 0 - - class Config: - frozen = True + model_config = ConfigDict(frozen=True) class RetrievalDetails(ChunkContext): @@ -209,7 +208,7 @@ class InferenceChunk(BaseChunk): updated_at: datetime | None primary_owners: list[str] | None = None secondary_owners: list[str] | None = None - large_chunk_reference_ids: list[int] = [] + large_chunk_reference_ids: list[int] = Field(default_factory=list) @property def unique_id(self) -> str: @@ -268,7 +267,7 @@ class InferenceChunkUncleaned(InferenceChunk): # Assumes the cleaning has already been applied and just needs to translate to the right type inference_chunk_data = { k: v - for k, v in self.dict().items() + for k, v in self.model_dump().items() if k not in ["metadata_suffix"] # May be other fields to throw out in the future } @@ -288,7 +287,7 @@ class SearchDoc(BaseModel): document_id: str chunk_ind: int semantic_identifier: str - link: str | None + link: str | None = None blurb: str source_type: DocumentSource boost: int @@ -297,7 +296,7 @@ class SearchDoc(BaseModel): # be `True` when doing an admin search hidden: bool metadata: dict[str, str | list[str]] - score: float | None + score: float | None = None is_relevant: bool | None = None relevance_explanation: str | None = None # Matched sections in the doc. Uses Vespa syntax e.g. TEXT @@ -305,13 +304,13 @@ class SearchDoc(BaseModel): # ["the answer is 42", "the answer is 42""] match_highlights: list[str] # when the doc was last updated - updated_at: datetime | None - primary_owners: list[str] | None - secondary_owners: list[str] | None + updated_at: datetime | None = None + primary_owners: list[str] | None = None + secondary_owners: list[str] | None = None is_internet: bool = False - def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore - initial_dict = super().dict(*args, **kwargs) # type: ignore + def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore + initial_dict = super().model_dump(*args, **kwargs) # type: ignore initial_dict["updated_at"] = ( self.updated_at.isoformat() if self.updated_at else None ) @@ -329,7 +328,7 @@ class SavedSearchDoc(SearchDoc): """IMPORTANT: careful using this and not providing a db_doc_id If db_doc_id is not provided, it won't be able to actually fetch the saved doc and info later on. So only skip providing this if the SavedSearchDoc will not be used in the future""" - search_doc_data = search_doc.dict() + search_doc_data = search_doc.model_dump() search_doc_data["score"] = search_doc_data.get("score") or 0.0 return cls(**search_doc_data, db_doc_id=db_doc_id) diff --git a/backend/danswer/secondary_llm_flows/query_validation.py b/backend/danswer/secondary_llm_flows/query_validation.py index bbc1ef412..2ee428f00 100644 --- a/backend/danswer/secondary_llm_flows/query_validation.py +++ b/backend/danswer/secondary_llm_flows/query_validation.py @@ -74,7 +74,7 @@ def stream_query_answerability( QueryValidationResponse( reasoning="Query Answerability Evaluation feature is turned off", answerable=True, - ).dict() + ).model_dump() ) return @@ -85,7 +85,7 @@ def stream_query_answerability( QueryValidationResponse( reasoning="Generative AI is turned off - skipping check", answerable=True, - ).dict() + ).model_dump() ) return messages = get_query_validation_messages(user_query) @@ -107,7 +107,7 @@ def stream_query_answerability( remaining = model_output[reason_ind + len(THOUGHT_PAT.upper()) :] if remaining: yield get_json_line( - DanswerAnswerPiece(answer_piece=remaining).dict() + DanswerAnswerPiece(answer_piece=remaining).model_dump() ) continue @@ -116,7 +116,7 @@ def stream_query_answerability( if hold_answerable == ANSWERABLE_PAT.upper()[: len(hold_answerable)]: continue yield get_json_line( - DanswerAnswerPiece(answer_piece=hold_answerable).dict() + DanswerAnswerPiece(answer_piece=hold_answerable).model_dump() ) hold_answerable = "" @@ -124,11 +124,13 @@ def stream_query_answerability( answerable = extract_answerability_bool(model_output) yield get_json_line( - QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict() + QueryValidationResponse( + reasoning=reasoning, answerable=answerable + ).model_dump() ) except Exception as e: # exception is logged in the answer_question method, no need to re-log error = StreamingError(error=str(e)) - yield get_json_line(error.dict()) + yield get_json_line(error.model_dump()) logger.exception("Failed to validate Query") return diff --git a/backend/danswer/server/danswer_api/models.py b/backend/danswer/server/danswer_api/models.py index 8a534c3e3..17d6a32c0 100644 --- a/backend/danswer/server/danswer_api/models.py +++ b/backend/danswer/server/danswer_api/models.py @@ -5,7 +5,7 @@ from danswer.connectors.models import DocumentBase class IngestionDocument(BaseModel): document: DocumentBase - cc_pair_id: int | None + cc_pair_id: int | None = None class IngestionResult(BaseModel): @@ -16,4 +16,4 @@ class IngestionResult(BaseModel): class DocMinimalInfo(BaseModel): document_id: str semantic_id: str - link: str | None + link: str | None = None diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py index dd28eee1a..69ae99163 100644 --- a/backend/danswer/server/documents/cc_pair.py +++ b/backend/danswer/server/documents/cc_pair.py @@ -170,7 +170,7 @@ def associate_credential_to_connector( connector_id=connector_id, credential_id=credential_id, cc_pair_name=metadata.name, - is_public=metadata.is_public, + is_public=metadata.is_public or True, groups=metadata.groups, ) diff --git a/backend/danswer/server/documents/models.py b/backend/danswer/server/documents/models.py index ad23fd3b9..ba011afc1 100644 --- a/backend/danswer/server/documents/models.py +++ b/backend/danswer/server/documents/models.py @@ -3,6 +3,7 @@ from typing import Any from uuid import UUID from pydantic import BaseModel +from pydantic import Field from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX from danswer.configs.constants import DocumentSource @@ -40,14 +41,15 @@ class ConnectorBase(BaseModel): source: DocumentSource input_type: InputType connector_specific_config: dict[str, Any] - refresh_freq: int | None # In seconds, None for one time index with no refresh - prune_freq: int | None - indexing_start: datetime | None + # In seconds, None for one time index with no refresh + refresh_freq: int | None = None + prune_freq: int | None = None + indexing_start: datetime | None = None class ConnectorUpdateRequest(ConnectorBase): is_public: bool | None = None - groups: list[int] | None = None + groups: list[int] = Field(default_factory=list) class ConnectorSnapshot(ConnectorBase): @@ -93,7 +95,7 @@ class CredentialBase(BaseModel): source: DocumentSource name: str | None = None curator_public: bool = False - groups: list[int] = [] + groups: list[int] = Field(default_factory=list) class CredentialSnapshot(CredentialBase): @@ -254,21 +256,21 @@ class ConnectorCredentialPairIdentifier(BaseModel): class ConnectorCredentialPairMetadata(BaseModel): - name: str | None - is_public: bool - groups: list[int] | None + name: str | None = None + is_public: bool | None = None + groups: list[int] = Field(default_factory=list) class ConnectorCredentialPairDescriptor(BaseModel): id: int - name: str | None + name: str | None = None connector: ConnectorSnapshot credential: CredentialSnapshot class RunConnectorRequest(BaseModel): connector_id: int - credential_ids: list[int] | None + credential_ids: list[int] | None = None from_beginning: bool = False diff --git a/backend/danswer/server/features/document_set/models.py b/backend/danswer/server/features/document_set/models.py index 05ada42c8..55f337654 100644 --- a/backend/danswer/server/features/document_set/models.py +++ b/backend/danswer/server/features/document_set/models.py @@ -1,6 +1,7 @@ from uuid import UUID from pydantic import BaseModel +from pydantic import Field from danswer.db.models import DocumentSet as DocumentSetDBModel from danswer.server.documents.models import ConnectorCredentialPairDescriptor @@ -14,8 +15,8 @@ class DocumentSetCreationRequest(BaseModel): cc_pair_ids: list[int] is_public: bool # For Private Document Sets, who should be able to access these - users: list[UUID] | None = None - groups: list[int] | None = None + users: list[UUID] = Field(default_factory=list) + groups: list[int] = Field(default_factory=list) class DocumentSetUpdateRequest(BaseModel): diff --git a/backend/danswer/server/features/folder/models.py b/backend/danswer/server/features/folder/models.py index d665fd919..d7b161414 100644 --- a/backend/danswer/server/features/folder/models.py +++ b/backend/danswer/server/features/folder/models.py @@ -19,7 +19,7 @@ class FolderCreationRequest(BaseModel): class FolderUpdateRequest(BaseModel): - folder_name: str | None + folder_name: str | None = None class FolderChatSessionRequest(BaseModel): diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py index 5e2dec982..777ef2037 100644 --- a/backend/danswer/server/features/persona/models.py +++ b/backend/danswer/server/features/persona/models.py @@ -1,6 +1,7 @@ from uuid import UUID from pydantic import BaseModel +from pydantic import Field from danswer.db.models import Persona from danswer.db.models import StarterMessage @@ -31,8 +32,8 @@ class CreatePersonaRequest(BaseModel): llm_model_version_override: str | None = None starter_messages: list[StarterMessage] | None = None # For Private Personas, who should be able to access these - users: list[UUID] | None = None - groups: list[int] | None = None + users: list[UUID] = Field(default_factory=list) + groups: list[int] = Field(default_factory=list) icon_color: str | None = None icon_shape: int | None = None uploaded_image_id: str | None = None # New field for uploaded image diff --git a/backend/danswer/server/features/tool/api.py b/backend/danswer/server/features/tool/api.py index b1f57a1a9..9635a2765 100644 --- a/backend/danswer/server/features/tool/api.py +++ b/backend/danswer/server/features/tool/api.py @@ -26,14 +26,14 @@ admin_router = APIRouter(prefix="/admin/tool") class CustomToolCreate(BaseModel): name: str - description: str | None + description: str | None = None definition: dict[str, Any] class CustomToolUpdate(BaseModel): - name: str | None - description: str | None - definition: dict[str, Any] | None + name: str | None = None + description: str | None = None + definition: dict[str, Any] | None = None def _validate_tool_definition(definition: dict[str, Any]) -> None: diff --git a/backend/danswer/server/manage/llm/models.py b/backend/danswer/server/manage/llm/models.py index ff89d8afc..3ef669710 100644 --- a/backend/danswer/server/manage/llm/models.py +++ b/backend/danswer/server/manage/llm/models.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING from pydantic import BaseModel +from pydantic import Field from danswer.llm.llm_provider_options import fetch_models_for_provider @@ -56,26 +57,26 @@ class LLMProviderDescriptor(BaseModel): class LLMProvider(BaseModel): name: str provider: str - api_key: str | None - api_base: str | None - api_version: str | None - custom_config: dict[str, str] | None + api_key: str | None = None + api_base: str | None = None + api_version: str | None = None + custom_config: dict[str, str] | None = None default_model_name: str - fast_default_model_name: str | None + fast_default_model_name: str | None = None is_public: bool = True - groups: list[int] | None = None - display_model_names: list[str] | None + groups: list[int] = Field(default_factory=list) + display_model_names: list[str] | None = None class LLMProviderUpsertRequest(LLMProvider): # should only be used for a "custom" provider # for default providers, the built-in model names are used - model_names: list[str] | None + model_names: list[str] | None = None class FullLLMProvider(LLMProvider): id: int - is_default_provider: bool | None + is_default_provider: bool | None = None model_names: list[str] @classmethod diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py index ccc7f7757..160c90bdb 100644 --- a/backend/danswer/server/manage/models.py +++ b/backend/danswer/server/manage/models.py @@ -1,10 +1,11 @@ from datetime import datetime -from typing import Any from typing import TYPE_CHECKING from pydantic import BaseModel -from pydantic import root_validator -from pydantic import validator +from pydantic import ConfigDict +from pydantic import Field +from pydantic import field_validator +from pydantic import model_validator from danswer.auth.schemas import UserRole from danswer.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY @@ -39,8 +40,8 @@ class AuthTypeResponse(BaseModel): class UserPreferences(BaseModel): - chosen_assistants: list[int] | None - default_model: str | None + chosen_assistants: list[int] | None = None + default_model: str | None = None class UserInfo(BaseModel): @@ -158,7 +159,8 @@ class StandardAnswerCreationRequest(BaseModel): answer: str categories: list[int] - @validator("categories", pre=True) + @field_validator("categories", mode="before") + @classmethod def validate_categories(cls, value: list[int]) -> list[int]: if len(value) < 1: raise ValueError( @@ -170,9 +172,7 @@ class StandardAnswerCreationRequest(BaseModel): class SlackBotTokens(BaseModel): bot_token: str app_token: str - - class Config: - frozen = True + model_config = ConfigDict(frozen=True) class SlackBotConfigCreationRequest(BaseModel): @@ -180,23 +180,24 @@ class SlackBotConfigCreationRequest(BaseModel): # in the future, `document_sets` will probably be replaced # by an optional `PersonaSnapshot` object. Keeping it like this # for now for simplicity / speed of development - document_sets: list[int] | None + document_sets: list[int] | None = None persona_id: ( int | None - ) # NOTE: only one of `document_sets` / `persona_id` should be set + ) = None # NOTE: only one of `document_sets` / `persona_id` should be set channel_names: list[str] respond_tag_only: bool = False respond_to_bots: bool = False enable_auto_filters: bool = False # If no team members, assume respond in the channel to everyone - respond_member_group_list: list[str] = [] - answer_filters: list[AllowedAnswerFilters] = [] + respond_member_group_list: list[str] = Field(default_factory=list) + answer_filters: list[AllowedAnswerFilters] = Field(default_factory=list) # list of user emails follow_up_tags: list[str] | None = None response_type: SlackBotResponseType - standard_answer_categories: list[int] = [] + standard_answer_categories: list[int] = Field(default_factory=list) - @validator("answer_filters", pre=True) + @field_validator("answer_filters", mode="before") + @classmethod def validate_filters(cls, value: list[str]) -> list[str]: if any(test not in VALID_SLACK_FILTERS for test in value): raise ValueError( @@ -204,14 +205,12 @@ class SlackBotConfigCreationRequest(BaseModel): ) return value - @root_validator - def validate_document_sets_and_persona_id( - cls, values: dict[str, Any] - ) -> dict[str, Any]: - if values.get("document_sets") and values.get("persona_id"): + @model_validator(mode="after") + def validate_document_sets_and_persona_id(self) -> "SlackBotConfigCreationRequest": + if self.document_sets and self.persona_id: raise ValueError("Only one of `document_sets` / `persona_id` should be set") - return values + return self class SlackBotConfig(BaseModel): diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index 16ebc3e2f..d2fd981b5 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -323,7 +323,7 @@ def verify_user_logged_in( class ChosenDefaultModelRequest(BaseModel): - default_model: str | None + default_model: str | None = None @router.patch("/user/default-model") diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index fa70189f1..9c78851eb 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -4,7 +4,6 @@ from typing import TypeVar from uuid import UUID from pydantic import BaseModel -from pydantic.generics import GenericModel from danswer.auth.schemas import UserRole from danswer.auth.schemas import UserStatus @@ -13,7 +12,7 @@ from danswer.auth.schemas import UserStatus DataT = TypeVar("DataT") -class StatusResponse(GenericModel, Generic[DataT]): +class StatusResponse(BaseModel, Generic[DataT]): success: bool message: Optional[str] = None data: Optional[DataT] = None diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 8a2c89af5..55d1094ea 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import Any from pydantic import BaseModel -from pydantic import root_validator +from pydantic import model_validator from danswer.chat.models import RetrievalDocs from danswer.configs.constants import DocumentSource @@ -54,16 +54,11 @@ class ChatFeedbackRequest(BaseModel): feedback_text: str | None = None predefined_feedback: str | None = None - @root_validator - def check_is_positive_or_feedback_text(cls: BaseModel, values: dict) -> dict: - is_positive, feedback_text = values.get("is_positive"), values.get( - "feedback_text" - ) - - if is_positive is None and feedback_text is None: + @model_validator(mode="after") + def check_is_positive_or_feedback_text(self) -> "ChatFeedbackRequest": + if self.is_positive is None and self.feedback_text is None: raise ValueError("Empty feedback received.") - - return values + return self """ @@ -112,18 +107,13 @@ class CreateChatMessageRequest(ChunkContext): # used for seeded chats to kick off the generation of an AI answer use_existing_user_message: bool = False - @root_validator - def check_search_doc_ids_or_retrieval_options(cls: BaseModel, values: dict) -> dict: - search_doc_ids, retrieval_options = values.get("search_doc_ids"), values.get( - "retrieval_options" - ) - - if search_doc_ids is None and retrieval_options is None: + @model_validator(mode="after") + def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest": + if self.search_doc_ids is None and self.retrieval_options is None: raise ValueError( "Either search_doc_ids or retrieval_options must be provided, but not both or neither." ) - - return values + return self class ChatMessageIdentifier(BaseModel): @@ -149,7 +139,7 @@ class ChatSessionDetails(BaseModel): persona_id: int time_created: str shared_status: ChatSessionSharedStatus - folder_id: int | None + folder_id: int | None = None current_alternate_model: str | None = None @@ -162,36 +152,36 @@ class SearchFeedbackRequest(BaseModel): document_id: str document_rank: int click: bool - search_feedback: SearchFeedbackType | None + search_feedback: SearchFeedbackType | None = None - @root_validator - def check_click_or_search_feedback(cls: BaseModel, values: dict) -> dict: - click, feedback = values.get("click"), values.get("search_feedback") + @model_validator(mode="after") + def check_click_or_search_feedback(self) -> "SearchFeedbackRequest": + click, feedback = self.click, self.search_feedback if click is False and feedback is None: raise ValueError("Empty feedback received.") - return values + return self class ChatMessageDetail(BaseModel): message_id: int - parent_message: int | None - latest_child_message: int | None + parent_message: int | None = None + latest_child_message: int | None = None message: str - rephrased_query: str | None - context_docs: RetrievalDocs | None + rephrased_query: str | None = None + context_docs: RetrievalDocs | None = None message_type: MessageType time_sent: datetime - alternate_assistant_id: str | None overridden_model: str | None + alternate_assistant_id: int | None = None # Dict mapping citation number to db_doc_id chat_session_id: int | None = None - citations: dict[int, int] | None + citations: dict[int, int] | None = None files: list[FileDescriptor] tool_calls: list[ToolCallFinalResult] - def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore - initial_dict = super().dict(*args, **kwargs) # type: ignore + def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore + initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore initial_dict["time_sent"] = self.time_sent.isoformat() return initial_dict @@ -226,7 +216,3 @@ class AdminSearchRequest(BaseModel): class AdminSearchResponse(BaseModel): documents: list[SearchDoc] - - -class DanswerAnswer(BaseModel): - answer: str | None diff --git a/backend/danswer/server/settings/api.py b/backend/danswer/server/settings/api.py index 8dfb07869..3330f6cc5 100644 --- a/backend/danswer/server/settings/api.py +++ b/backend/danswer/server/settings/api.py @@ -64,7 +64,7 @@ def fetch_settings( needs_reindexing = False return UserSettings( - **general_settings.dict(), + **general_settings.model_dump(), notifications=user_notifications, needs_reindexing=needs_reindexing ) diff --git a/backend/danswer/server/settings/store.py b/backend/danswer/server/settings/store.py index dcf31c46f..6f2872f40 100644 --- a/backend/danswer/server/settings/store.py +++ b/backend/danswer/server/settings/store.py @@ -12,10 +12,10 @@ def load_settings() -> Settings: settings = Settings(**cast(dict, dynamic_config_store.load(KV_SETTINGS_KEY))) except ConfigNotFoundError: settings = Settings() - dynamic_config_store.store(KV_SETTINGS_KEY, settings.dict()) + dynamic_config_store.store(KV_SETTINGS_KEY, settings.model_dump()) return settings def store_settings(settings: Settings) -> None: - get_dynamic_config_store().store(KV_SETTINGS_KEY, settings.dict()) + get_dynamic_config_store().store(KV_SETTINGS_KEY, settings.model_dump()) diff --git a/backend/danswer/tools/built_in_tools.py b/backend/danswer/tools/built_in_tools.py index 59159e717..99b2ae3bb 100644 --- a/backend/danswer/tools/built_in_tools.py +++ b/backend/danswer/tools/built_in_tools.py @@ -1,6 +1,6 @@ import os from typing import Type -from typing import TypedDict +from typing_extensions import TypedDict # noreorder from sqlalchemy import not_ from sqlalchemy import or_ diff --git a/backend/danswer/tools/images/image_generation_tool.py b/backend/danswer/tools/images/image_generation_tool.py index 1a525f6a2..fe839b7d6 100644 --- a/backend/danswer/tools/images/image_generation_tool.py +++ b/backend/danswer/tools/images/image_generation_tool.py @@ -254,6 +254,6 @@ class ImageGenerationTool(Tool): list[ImageGenerationResponse], args[0].response ) return [ - image_generation_response.dict() + image_generation_response.model_dump() for image_generation_response in image_generation_responses ] diff --git a/backend/danswer/tools/internet_search/internet_search_tool.py b/backend/danswer/tools/internet_search/internet_search_tool.py index 0f92deadd..2640afcdf 100644 --- a/backend/danswer/tools/internet_search/internet_search_tool.py +++ b/backend/danswer/tools/internet_search/internet_search_tool.py @@ -187,7 +187,7 @@ class InternetSearchTool(Tool): self, *args: ToolResponse ) -> str | list[str | dict[str, Any]]: search_response = cast(InternetSearchResponse, args[0].response) - return json.dumps(search_response.dict()) + return json.dumps(search_response.model_dump()) def _perform_search(self, query: str) -> InternetSearchResponse: response = self.client.get( @@ -230,4 +230,4 @@ class InternetSearchTool(Tool): def final_result(self, *args: ToolResponse) -> JSON_ro: search_response = cast(InternetSearchResponse, args[0].response) - return search_response.dict() + return search_response.model_dump() diff --git a/backend/danswer/tools/message.py b/backend/danswer/tools/message.py index 826f4a30b..b0259c29b 100644 --- a/backend/danswer/tools/message.py +++ b/backend/danswer/tools/message.py @@ -4,10 +4,12 @@ from typing import Any from langchain_core.messages.ai import AIMessage from langchain_core.messages.tool import ToolCall from langchain_core.messages.tool import ToolMessage -from pydantic import BaseModel +from pydantic.v1 import BaseModel as BaseModel__v1 from danswer.natural_language_processing.utils import BaseTokenizer +# Langchain has their own version of pydantic which is version 1 + def build_tool_message( tool_call: ToolCall, tool_content: str | list[str | dict[str, Any]] @@ -19,7 +21,7 @@ def build_tool_message( ) -class ToolCallSummary(BaseModel): +class ToolCallSummary(BaseModel__v1): tool_call_request: AIMessage tool_call_result: ToolMessage diff --git a/backend/danswer/tools/models.py b/backend/danswer/tools/models.py index 53940dcea..052e4293a 100644 --- a/backend/danswer/tools/models.py +++ b/backend/danswer/tools/models.py @@ -1,12 +1,12 @@ from typing import Any from pydantic import BaseModel -from pydantic import root_validator +from pydantic import model_validator class ToolResponse(BaseModel): id: str | None = None - response: Any + response: Any = None class ToolCallKickoff(BaseModel): @@ -19,12 +19,10 @@ class ToolRunnerResponse(BaseModel): tool_response: ToolResponse | None = None tool_message_content: str | list[str | dict[str, Any]] | None = None - @root_validator - def validate_tool_runner_response( - cls, values: dict[str, ToolResponse | str] - ) -> dict[str, ToolResponse | str]: + @model_validator(mode="after") + def validate_tool_runner_response(self) -> "ToolRunnerResponse": fields = ["tool_response", "tool_message_content", "tool_run_kickoff"] - provided = sum(1 for field in fields if values.get(field) is not None) + provided = sum(1 for field in fields if getattr(self, field) is not None) if provided != 1: raise ValueError( @@ -32,8 +30,10 @@ class ToolRunnerResponse(BaseModel): "or 'tool_run_kickoff' must be provided" ) - return values + return self class ToolCallFinalResult(ToolCallKickoff): - tool_result: Any # we would like to use JSON_ro, but can't due to its recursive nature + tool_result: Any = ( + None # we would like to use JSON_ro, but can't due to its recursive nature + ) diff --git a/backend/ee/danswer/server/enterprise_settings/store.py b/backend/ee/danswer/server/enterprise_settings/store.py index 37dd320d7..30b72d5d2 100644 --- a/backend/ee/danswer/server/enterprise_settings/store.py +++ b/backend/ee/danswer/server/enterprise_settings/store.py @@ -30,13 +30,13 @@ def load_settings() -> EnterpriseSettings: ) except ConfigNotFoundError: settings = EnterpriseSettings() - dynamic_config_store.store(KV_ENTERPRISE_SETTINGS_KEY, settings.dict()) + dynamic_config_store.store(KV_ENTERPRISE_SETTINGS_KEY, settings.model_dump()) return settings def store_settings(settings: EnterpriseSettings) -> None: - get_dynamic_config_store().store(KV_ENTERPRISE_SETTINGS_KEY, settings.dict()) + get_dynamic_config_store().store(KV_ENTERPRISE_SETTINGS_KEY, settings.model_dump()) _CUSTOM_ANALYTICS_SECRET_KEY = os.environ.get("CUSTOM_ANALYTICS_SECRET_KEY") diff --git a/backend/ee/danswer/server/query_and_chat/models.py b/backend/ee/danswer/server/query_and_chat/models.py index 8f1e3a9db..b0ce553eb 100644 --- a/backend/ee/danswer/server/query_and_chat/models.py +++ b/backend/ee/danswer/server/query_and_chat/models.py @@ -17,7 +17,7 @@ class StandardAnswerRequest(BaseModel): class StandardAnswerResponse(BaseModel): - standard_answers: list[StandardAnswer] = [] + standard_answers: list[StandardAnswer] = Field(default_factory=list) class DocumentSearchRequest(ChunkContext): diff --git a/backend/ee/danswer/server/query_history/api.py b/backend/ee/danswer/server/query_history/api.py index d8927b726..ed532a856 100644 --- a/backend/ee/danswer/server/query_history/api.py +++ b/backend/ee/danswer/server/query_history/api.py @@ -376,7 +376,7 @@ def get_query_history_as_csv( # Create an in-memory text stream stream = io.StringIO() writer = csv.DictWriter( - stream, fieldnames=list(QuestionAnswerPairSnapshot.__fields__.keys()) + stream, fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()) ) writer.writeheader() for row in question_answer_pairs: diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index fc36432df..9427335c4 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -29,20 +29,20 @@ langchain==0.1.17 langchain-community==0.0.36 langchain-core==0.1.50 langchain-text-splitters==0.0.1 -litellm==1.37.7 +litellm==1.43.18 llama-index==0.9.45 Mako==1.2.4 msal==1.26.0 nltk==3.8.1 Office365-REST-Python-Client==2.5.9 oauthlib==3.2.2 -openai==1.14.3 +openai==1.41.1 openpyxl==3.1.2 playwright==1.41.2 psutil==5.9.5 psycopg2-binary==2.9.9 pycryptodome==3.19.1 -pydantic==1.10.13 +pydantic==2.8.2 PyGithub==1.58.2 python-dateutil==2.8.2 python-gitlab==3.9.0 @@ -64,7 +64,7 @@ slack-sdk==3.20.2 SQLAlchemy[mypy]==2.0.15 starlette==0.36.3 supervisor==4.2.5 -tiktoken==0.4.0 +tiktoken==0.7.0 timeago==1.0.16 transformers==4.39.2 uvicorn==0.21.1 diff --git a/backend/requirements/model_server.txt b/backend/requirements/model_server.txt index da4de1c69..0fb0e74b6 100644 --- a/backend/requirements/model_server.txt +++ b/backend/requirements/model_server.txt @@ -3,8 +3,8 @@ einops==0.8.0 fastapi==0.109.2 google-cloud-aiplatform==1.58.0 numpy==1.26.4 -openai==1.14.3 -pydantic==1.10.13 +openai==1.41.1 +pydantic==2.8.2 retry==0.9.2 safetensors==0.4.2 sentence-transformers==2.6.1 diff --git a/backend/shared_configs/model_server_models.py b/backend/shared_configs/model_server_models.py index 21bd4c7d1..3014616c6 100644 --- a/backend/shared_configs/model_server_models.py +++ b/backend/shared_configs/model_server_models.py @@ -10,14 +10,17 @@ Embedding = list[float] class EmbedRequest(BaseModel): texts: list[str] # Can be none for cloud embedding model requests, error handling logic exists for other cases - model_name: str | None + model_name: str | None = None max_context_length: int normalize_embeddings: bool - api_key: str | None - provider_type: EmbeddingProvider | None + api_key: str | None = None + provider_type: EmbeddingProvider | None = None text_type: EmbedTextType - manual_query_prefix: str | None - manual_passage_prefix: str | None + manual_query_prefix: str | None = None + manual_passage_prefix: str | None = None + + # This disables the "model_" protected namespace for pydantic + model_config = {"protected_namespaces": ()} class EmbedResponse(BaseModel): @@ -28,8 +31,11 @@ class RerankRequest(BaseModel): query: str documents: list[str] model_name: str - provider_type: RerankerProvider | None - api_key: str | None + provider_type: RerankerProvider | None = None + api_key: str | None = None + + # This disables the "model_" protected namespace for pydantic + model_config = {"protected_namespaces": ()} class RerankResponse(BaseModel): diff --git a/backend/tests/integration/common_utils/document_sets.py b/backend/tests/integration/common_utils/document_sets.py index 9805107be..dc8986111 100644 --- a/backend/tests/integration/common_utils/document_sets.py +++ b/backend/tests/integration/common_utils/document_sets.py @@ -14,7 +14,7 @@ class DocumentSetClient: ) -> int: response = requests.post( f"{API_SERVER_URL}/manage/admin/document-set", - json=doc_set_creation_request.dict(), + json=doc_set_creation_request.model_dump(), ) response.raise_for_status() return cast(int, response.json()) diff --git a/backend/tests/integration/common_utils/llm.py b/backend/tests/integration/common_utils/llm.py index 92310480d..ba8b89d6b 100644 --- a/backend/tests/integration/common_utils/llm.py +++ b/backend/tests/integration/common_utils/llm.py @@ -31,7 +31,7 @@ class LLMProvider(BaseModel): custom_config=None, fast_default_model_name=None, is_public=True, - groups=None, + groups=[], display_model_names=None, model_names=None, ) diff --git a/backend/tests/integration/common_utils/user_groups.py b/backend/tests/integration/common_utils/user_groups.py index 35add97b3..0cd440664 100644 --- a/backend/tests/integration/common_utils/user_groups.py +++ b/backend/tests/integration/common_utils/user_groups.py @@ -12,7 +12,7 @@ class UserGroupClient: def create_user_group(user_group_creation_request: UserGroupCreate) -> int: response = requests.post( f"{API_SERVER_URL}/manage/admin/user-group", - json=user_group_creation_request.dict(), + json=user_group_creation_request.model_dump(), ) response.raise_for_status() return cast(int, response.json()["id"]) diff --git a/backend/tests/regression/answer_quality/api_utils.py b/backend/tests/regression/answer_quality/api_utils.py index 92b3b96a0..5a46032c6 100644 --- a/backend/tests/regression/answer_quality/api_utils.py +++ b/backend/tests/regression/answer_quality/api_utils.py @@ -57,7 +57,7 @@ def get_answer_from_query( "Content-Type": "application/json", } - body = new_message_request.dict() + body = new_message_request.model_dump() body["user"] = None try: response_json = requests.post(url, headers=headers, json=body).json() @@ -122,7 +122,7 @@ def create_cc_pair(env_name: str, connector_id: int, credential_id: int) -> None env_name, f"/manage/connector/{connector_id}/credential/{credential_id}" ) - body = {"name": "zip_folder_contents", "is_public": True} + body = {"name": "zip_folder_contents", "is_public": True, "groups": []} print("body:", body) response = requests.put(url, headers=GENERAL_HEADERS, json=body) if response.status_code == 200: @@ -167,7 +167,7 @@ def create_connector(env_name: str, file_paths: list[str]) -> int: indexing_start=None, ) - body = connector.dict() + body = connector.model_dump() response = requests.post(url, headers=GENERAL_HEADERS, json=body) if response.status_code == 200: return response.json()["id"]