mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-09 04:18:32 +02:00
Litellm bump (#2195)
* ran bump-pydantic * replace root_validator with model_validator * mostly working. some alternate assistant error. changed root_validator and typing_extensions * working generation chat. changed type * replacing .dict with .model_dump * argument needed to bring model_dump up to parity with dict() * fix a fewremaining issues -- working with llama and gpt * updating requirements file * more requirement updates * more requirement updates * fix to make search work * return type fix: * half way tpyes change * fixes for mypy and pydantic: * endpoint fix * fix pydantic protected namespaces * it works! * removed unecessary None initializations * better logging * changed default values to empty lists * mypy fixes * fixed array defaulting --------- Co-authored-by: hagen-danswer <hagen@danswer.ai>
This commit is contained in:
parent
657d2050a5
commit
50c17438d5
@ -13,7 +13,7 @@ from danswer.server.manage.models import UserPreferences
|
||||
def set_no_auth_user_preferences(
|
||||
store: DynamicConfigStore, preferences: UserPreferences
|
||||
) -> None:
|
||||
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())
|
||||
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.model_dump())
|
||||
|
||||
|
||||
def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
|
||||
|
@ -35,11 +35,12 @@ class QADocsResponse(RetrievalDocs):
|
||||
applied_time_cutoff: datetime | None
|
||||
recency_bias_multiplier: float
|
||||
|
||||
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().dict(*args, **kwargs) # type: ignore
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
initial_dict["applied_time_cutoff"] = (
|
||||
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
|
||||
)
|
||||
|
||||
return initial_dict
|
||||
|
||||
|
||||
|
@ -813,4 +813,4 @@ def stream_chat_message(
|
||||
is_connected=is_connected,
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.dict())
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import TypedDict
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -125,7 +125,7 @@ def update_gmail_credential_access_tokens(
|
||||
) -> OAuthCredentials | None:
|
||||
app_credentials = get_google_app_gmail_cred()
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
app_credentials.dict(),
|
||||
app_credentials.model_dump(),
|
||||
scopes=SCOPES,
|
||||
redirect_uri=_build_frontend_gmail_redirect(),
|
||||
)
|
||||
|
@ -106,7 +106,7 @@ def update_credential_access_tokens(
|
||||
) -> OAuthCredentials | None:
|
||||
app_credentials = get_google_app_cred()
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
app_credentials.dict(),
|
||||
app_credentials.model_dump(),
|
||||
scopes=SCOPES,
|
||||
redirect_uri=_build_frontend_google_drive_redirect(),
|
||||
)
|
||||
|
@ -3,6 +3,7 @@ from typing import List
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
@ -18,11 +19,11 @@ class Message(BaseModel):
|
||||
sender_realm_str: str
|
||||
subject: str
|
||||
topic_links: Optional[List[Any]] = None
|
||||
last_edit_timestamp: Optional[int] = None
|
||||
edit_history: Any
|
||||
last_edit_timestamp: Optional[int]
|
||||
edit_history: Any = None
|
||||
reactions: List[Any]
|
||||
submessages: List[Any]
|
||||
flags: List[str] = []
|
||||
flags: List[str] = Field(default_factory=list)
|
||||
display_recipient: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
stream_id: int
|
||||
@ -39,4 +40,4 @@ class GetMessagesResponse(BaseModel):
|
||||
found_newest: Optional[bool] = None
|
||||
history_limited: Optional[bool] = None
|
||||
anchor: Optional[str] = None
|
||||
messages: List[Message] = []
|
||||
messages: List[Message] = Field(default_factory=list)
|
||||
|
@ -26,8 +26,11 @@ from danswer.db.models import User__UserGroup
|
||||
from danswer.db.models import UserRole
|
||||
from danswer.server.features.document_set.models import DocumentSetCreationRequest
|
||||
from danswer.server.features.document_set.models import DocumentSetUpdateRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _add_user_filters(
|
||||
stmt: Select, user: User | None, get_editable: bool = True
|
||||
@ -233,9 +236,9 @@ def insert_document_set(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
except:
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
raise
|
||||
logger.error(f"Error creating document set: {e}")
|
||||
|
||||
return new_document_set_row, ds_cc_pairs
|
||||
|
||||
|
@ -46,10 +46,10 @@ def upsert_cloud_embedding_provider(
|
||||
.first()
|
||||
)
|
||||
if existing_provider:
|
||||
for key, value in provider.dict().items():
|
||||
for key, value in provider.model_dump().items():
|
||||
setattr(existing_provider, key, value)
|
||||
else:
|
||||
new_provider = CloudEmbeddingProviderModel(**provider.dict())
|
||||
new_provider = CloudEmbeddingProviderModel(**provider.model_dump())
|
||||
db_session.add(new_provider)
|
||||
existing_provider = new_provider
|
||||
db_session.commit()
|
||||
|
@ -5,7 +5,7 @@ from typing import Any
|
||||
from typing import Literal
|
||||
from typing import NotRequired
|
||||
from typing import Optional
|
||||
from typing import TypedDict
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID
|
||||
|
@ -1,7 +1,7 @@
|
||||
import base64
|
||||
from enum import Enum
|
||||
from typing import NotRequired
|
||||
from typing import TypedDict
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -156,7 +156,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
title_embed_dict[title] = title_embedding
|
||||
|
||||
new_embedded_chunk = IndexChunk(
|
||||
**chunk.dict(),
|
||||
**chunk.model_dump(),
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=chunk_embeddings[0],
|
||||
mini_chunk_embeddings=chunk_embeddings[1:],
|
||||
|
@ -3,6 +3,7 @@ from functools import partial
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_documents
|
||||
@ -40,9 +41,7 @@ logger = setup_logger()
|
||||
class DocumentBatchPrepareContext(BaseModel):
|
||||
updatable_docs: list[Document]
|
||||
id_to_db_doc_map: dict[str, DBDocument]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class IndexingPipelineProtocol(Protocol):
|
||||
|
@ -1,6 +1,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.connectors.models import Document
|
||||
@ -24,9 +25,8 @@ class BaseChunk(BaseModel):
|
||||
chunk_id: int
|
||||
blurb: str # The first sentence(s) of the first Section of the chunk
|
||||
content: str
|
||||
source_links: dict[
|
||||
int, str
|
||||
] | None # Holds the link and the offsets into the raw Chunk text
|
||||
# Holds the link and the offsets into the raw Chunk text
|
||||
source_links: dict[int, str] | None
|
||||
section_continuation: bool # True if this Chunk's start is not at the start of a Section
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ class DocAwareChunk(BaseChunk):
|
||||
|
||||
mini_chunk_texts: list[str] | None
|
||||
|
||||
large_chunk_reference_ids: list[int] = []
|
||||
large_chunk_reference_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
def to_short_descriptor(self) -> str:
|
||||
"""Used when logging the identity of a chunk"""
|
||||
@ -85,7 +85,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
document_sets: set[str],
|
||||
boost: int,
|
||||
) -> "DocMetadataAwareIndexChunk":
|
||||
index_chunk_data = index_chunk.dict()
|
||||
index_chunk_data = index_chunk.model_dump()
|
||||
return cls(
|
||||
**index_chunk_data,
|
||||
access=access,
|
||||
@ -102,6 +102,9 @@ class EmbeddingModelDetail(BaseModel):
|
||||
provider_type: EmbeddingProvider | None = None
|
||||
api_key: str | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
@classmethod
|
||||
def from_db_model(
|
||||
cls,
|
||||
@ -123,6 +126,9 @@ class IndexingSetting(EmbeddingModelDetail):
|
||||
index_name: str | None
|
||||
multipass_indexing: bool
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
@classmethod
|
||||
def from_db_model(cls, search_settings: "SearchSettings") -> "IndexingSetting":
|
||||
return cls(
|
||||
|
@ -1,6 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain.schema.messages import AIMessage
|
||||
@ -8,8 +7,9 @@ from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import root_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.chat.models import AnswerQuestionStreamReturn
|
||||
from danswer.configs.constants import MessageType
|
||||
@ -117,22 +117,19 @@ class AnswerStyleConfig(BaseModel):
|
||||
default_factory=DocumentPruningConfig
|
||||
)
|
||||
|
||||
@root_validator
|
||||
def check_quotes_and_citation(cls, values: dict[str, Any]) -> dict[str, Any]:
|
||||
citation_config = values.get("citation_config")
|
||||
quotes_config = values.get("quotes_config")
|
||||
|
||||
if citation_config is None and quotes_config is None:
|
||||
@model_validator(mode="after")
|
||||
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
|
||||
if self.citation_config is None and self.quotes_config is None:
|
||||
raise ValueError(
|
||||
"One of `citation_config` or `quotes_config` must be provided"
|
||||
)
|
||||
|
||||
if citation_config is not None and quotes_config is not None:
|
||||
if self.citation_config is not None and self.quotes_config is not None:
|
||||
raise ValueError(
|
||||
"Only one of `citation_config` or `quotes_config` must be provided"
|
||||
)
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
@ -160,6 +157,4 @@ class PromptConfig(BaseModel):
|
||||
include_citations=model.include_citations,
|
||||
)
|
||||
|
||||
# needed so that this can be passed into lru_cache funcs
|
||||
class Config:
|
||||
frozen = True
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
@ -62,13 +62,13 @@ def _convert_litellm_message_to_langchain_message(
|
||||
litellm_message: litellm.Message,
|
||||
) -> BaseMessage:
|
||||
# Extracting the basic attributes from the litellm message
|
||||
content = litellm_message.content
|
||||
content = litellm_message.content or ""
|
||||
role = litellm_message.role
|
||||
|
||||
# Handling function calls and tool calls if present
|
||||
tool_calls = (
|
||||
cast(
|
||||
list[litellm.utils.ChatCompletionMessageToolCall],
|
||||
list[litellm.ChatCompletionMessageToolCall],
|
||||
litellm_message.tool_calls,
|
||||
)
|
||||
if hasattr(litellm_message, "tool_calls")
|
||||
@ -87,7 +87,7 @@ def _convert_litellm_message_to_langchain_message(
|
||||
"args": json.loads(tool_call.function.arguments),
|
||||
"id": tool_call.id,
|
||||
}
|
||||
for tool_call in tool_calls
|
||||
for tool_call in (tool_calls if tool_calls else [])
|
||||
],
|
||||
)
|
||||
elif role == "system":
|
||||
@ -296,9 +296,11 @@ class DefaultMultiLLM(LLM):
|
||||
response = cast(
|
||||
litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
|
||||
)
|
||||
return _convert_litellm_message_to_langchain_message(
|
||||
response.choices[0].message
|
||||
)
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "message"):
|
||||
return _convert_litellm_message_to_langchain_message(choice.message)
|
||||
else:
|
||||
raise ValueError("Unexpected response choice type")
|
||||
|
||||
def _stream_implementation(
|
||||
self,
|
||||
@ -314,7 +316,10 @@ class DefaultMultiLLM(LLM):
|
||||
return
|
||||
|
||||
output = None
|
||||
response = self._completion(prompt, tools, tool_choice, True)
|
||||
response = cast(
|
||||
litellm.CustomStreamWrapper,
|
||||
self._completion(prompt, tools, tool_choice, True),
|
||||
)
|
||||
try:
|
||||
for part in response:
|
||||
if len(part["choices"]) == 0:
|
||||
|
@ -21,9 +21,12 @@ class LLMConfig(BaseModel):
|
||||
model_provider: str
|
||||
model_name: str
|
||||
temperature: float
|
||||
api_key: str | None
|
||||
api_base: str | None
|
||||
api_version: str | None
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
|
||||
def log_prompt(prompt: LanguageModelInput) -> None:
|
||||
|
@ -11,6 +11,9 @@ class LLMOverride(BaseModel):
|
||||
model_version: str | None = None
|
||||
temperature: float | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
|
||||
class PromptOverride(BaseModel):
|
||||
system_prompt: str | None = None
|
||||
|
@ -110,7 +110,7 @@ class EmbeddingModel:
|
||||
def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
|
||||
def _make_request() -> EmbedResponse:
|
||||
response = requests.post(
|
||||
self.embed_server_endpoint, json=embed_request.dict()
|
||||
self.embed_server_endpoint, json=embed_request.model_dump()
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
@ -255,7 +255,7 @@ class RerankingModel:
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
self.rerank_server_endpoint, json=rerank_request.dict()
|
||||
self.rerank_server_endpoint, json=rerank_request.model_dump()
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@ -288,7 +288,7 @@ class QueryAnalysisModel:
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
self.intent_server_endpoint, json=intent_request.dict()
|
||||
self.intent_server_endpoint, json=intent_request.model_dump()
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
@ -326,7 +326,7 @@ def stream_search_answer(
|
||||
db_session=session,
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.dict())
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
|
||||
def get_search_answer(
|
||||
|
@ -1,8 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import root_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerContexts
|
||||
@ -21,7 +19,7 @@ class QueryRephrase(BaseModel):
|
||||
|
||||
class ThreadMessage(BaseModel):
|
||||
message: str
|
||||
sender: str | None
|
||||
sender: str | None = None
|
||||
role: MessageType = MessageType.USER
|
||||
|
||||
|
||||
@ -45,21 +43,16 @@ class DirectQARequest(ChunkContext):
|
||||
# If True, skips generative an AI response to the search query
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
@root_validator
|
||||
def check_chain_of_thought_and_prompt_id(
|
||||
cls, values: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
chain_of_thought = values.get("chain_of_thought")
|
||||
prompt_id = values.get("prompt_id")
|
||||
|
||||
if chain_of_thought and prompt_id is not None:
|
||||
@model_validator(mode="after")
|
||||
def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest":
|
||||
if self.chain_of_thought and self.prompt_id is not None:
|
||||
raise ValueError(
|
||||
"If chain_of_thought is True, prompt_id must be None"
|
||||
"The chain of thought prompt is only for question "
|
||||
"answering and does not accept customizing."
|
||||
)
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
|
||||
class OneShotQAResponse(BaseModel):
|
||||
|
@ -2,7 +2,9 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import validator
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
|
||||
from danswer.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@ -112,7 +114,8 @@ class ChunkContext(BaseModel):
|
||||
chunks_below: int | None = None
|
||||
full_doc: bool = False
|
||||
|
||||
@validator("chunks_above", "chunks_below", pre=True, each_item=False)
|
||||
@field_validator("chunks_above", "chunks_below")
|
||||
@classmethod
|
||||
def check_non_negative(cls, value: int, field: Any) -> int:
|
||||
if value is not None and value < 0:
|
||||
raise ValueError(f"{field.name} must be non-negative")
|
||||
@ -137,9 +140,7 @@ class SearchRequest(ChunkContext):
|
||||
hybrid_alpha: float | None = None
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class SearchQuery(ChunkContext):
|
||||
@ -163,9 +164,7 @@ class SearchQuery(ChunkContext):
|
||||
|
||||
num_hits: int = NUM_RETURNED_HITS
|
||||
offset: int = 0
|
||||
|
||||
class Config:
|
||||
frozen = True
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
class RetrievalDetails(ChunkContext):
|
||||
@ -209,7 +208,7 @@ class InferenceChunk(BaseChunk):
|
||||
updated_at: datetime | None
|
||||
primary_owners: list[str] | None = None
|
||||
secondary_owners: list[str] | None = None
|
||||
large_chunk_reference_ids: list[int] = []
|
||||
large_chunk_reference_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
@property
|
||||
def unique_id(self) -> str:
|
||||
@ -268,7 +267,7 @@ class InferenceChunkUncleaned(InferenceChunk):
|
||||
# Assumes the cleaning has already been applied and just needs to translate to the right type
|
||||
inference_chunk_data = {
|
||||
k: v
|
||||
for k, v in self.dict().items()
|
||||
for k, v in self.model_dump().items()
|
||||
if k
|
||||
not in ["metadata_suffix"] # May be other fields to throw out in the future
|
||||
}
|
||||
@ -288,7 +287,7 @@ class SearchDoc(BaseModel):
|
||||
document_id: str
|
||||
chunk_ind: int
|
||||
semantic_identifier: str
|
||||
link: str | None
|
||||
link: str | None = None
|
||||
blurb: str
|
||||
source_type: DocumentSource
|
||||
boost: int
|
||||
@ -297,7 +296,7 @@ class SearchDoc(BaseModel):
|
||||
# be `True` when doing an admin search
|
||||
hidden: bool
|
||||
metadata: dict[str, str | list[str]]
|
||||
score: float | None
|
||||
score: float | None = None
|
||||
is_relevant: bool | None = None
|
||||
relevance_explanation: str | None = None
|
||||
# Matched sections in the doc. Uses Vespa syntax e.g. <hi>TEXT</hi>
|
||||
@ -305,13 +304,13 @@ class SearchDoc(BaseModel):
|
||||
# ["<hi>the</hi> <hi>answer</hi> is 42", "the answer is <hi>42</hi>""]
|
||||
match_highlights: list[str]
|
||||
# when the doc was last updated
|
||||
updated_at: datetime | None
|
||||
primary_owners: list[str] | None
|
||||
secondary_owners: list[str] | None
|
||||
updated_at: datetime | None = None
|
||||
primary_owners: list[str] | None = None
|
||||
secondary_owners: list[str] | None = None
|
||||
is_internet: bool = False
|
||||
|
||||
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().dict(*args, **kwargs) # type: ignore
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(*args, **kwargs) # type: ignore
|
||||
initial_dict["updated_at"] = (
|
||||
self.updated_at.isoformat() if self.updated_at else None
|
||||
)
|
||||
@ -329,7 +328,7 @@ class SavedSearchDoc(SearchDoc):
|
||||
"""IMPORTANT: careful using this and not providing a db_doc_id If db_doc_id is not
|
||||
provided, it won't be able to actually fetch the saved doc and info later on. So only skip
|
||||
providing this if the SavedSearchDoc will not be used in the future"""
|
||||
search_doc_data = search_doc.dict()
|
||||
search_doc_data = search_doc.model_dump()
|
||||
search_doc_data["score"] = search_doc_data.get("score") or 0.0
|
||||
return cls(**search_doc_data, db_doc_id=db_doc_id)
|
||||
|
||||
|
@ -74,7 +74,7 @@ def stream_query_answerability(
|
||||
QueryValidationResponse(
|
||||
reasoning="Query Answerability Evaluation feature is turned off",
|
||||
answerable=True,
|
||||
).dict()
|
||||
).model_dump()
|
||||
)
|
||||
return
|
||||
|
||||
@ -85,7 +85,7 @@ def stream_query_answerability(
|
||||
QueryValidationResponse(
|
||||
reasoning="Generative AI is turned off - skipping check",
|
||||
answerable=True,
|
||||
).dict()
|
||||
).model_dump()
|
||||
)
|
||||
return
|
||||
messages = get_query_validation_messages(user_query)
|
||||
@ -107,7 +107,7 @@ def stream_query_answerability(
|
||||
remaining = model_output[reason_ind + len(THOUGHT_PAT.upper()) :]
|
||||
if remaining:
|
||||
yield get_json_line(
|
||||
DanswerAnswerPiece(answer_piece=remaining).dict()
|
||||
DanswerAnswerPiece(answer_piece=remaining).model_dump()
|
||||
)
|
||||
continue
|
||||
|
||||
@ -116,7 +116,7 @@ def stream_query_answerability(
|
||||
if hold_answerable == ANSWERABLE_PAT.upper()[: len(hold_answerable)]:
|
||||
continue
|
||||
yield get_json_line(
|
||||
DanswerAnswerPiece(answer_piece=hold_answerable).dict()
|
||||
DanswerAnswerPiece(answer_piece=hold_answerable).model_dump()
|
||||
)
|
||||
hold_answerable = ""
|
||||
|
||||
@ -124,11 +124,13 @@ def stream_query_answerability(
|
||||
answerable = extract_answerability_bool(model_output)
|
||||
|
||||
yield get_json_line(
|
||||
QueryValidationResponse(reasoning=reasoning, answerable=answerable).dict()
|
||||
QueryValidationResponse(
|
||||
reasoning=reasoning, answerable=answerable
|
||||
).model_dump()
|
||||
)
|
||||
except Exception as e:
|
||||
# exception is logged in the answer_question method, no need to re-log
|
||||
error = StreamingError(error=str(e))
|
||||
yield get_json_line(error.dict())
|
||||
yield get_json_line(error.model_dump())
|
||||
logger.exception("Failed to validate Query")
|
||||
return
|
||||
|
@ -5,7 +5,7 @@ from danswer.connectors.models import DocumentBase
|
||||
|
||||
class IngestionDocument(BaseModel):
|
||||
document: DocumentBase
|
||||
cc_pair_id: int | None
|
||||
cc_pair_id: int | None = None
|
||||
|
||||
|
||||
class IngestionResult(BaseModel):
|
||||
@ -16,4 +16,4 @@ class IngestionResult(BaseModel):
|
||||
class DocMinimalInfo(BaseModel):
|
||||
document_id: str
|
||||
semantic_id: str
|
||||
link: str | None
|
||||
link: str | None = None
|
||||
|
@ -170,7 +170,7 @@ def associate_credential_to_connector(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
cc_pair_name=metadata.name,
|
||||
is_public=metadata.is_public,
|
||||
is_public=metadata.is_public or True,
|
||||
groups=metadata.groups,
|
||||
)
|
||||
|
||||
|
@ -3,6 +3,7 @@ from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@ -40,14 +41,15 @@ class ConnectorBase(BaseModel):
|
||||
source: DocumentSource
|
||||
input_type: InputType
|
||||
connector_specific_config: dict[str, Any]
|
||||
refresh_freq: int | None # In seconds, None for one time index with no refresh
|
||||
prune_freq: int | None
|
||||
indexing_start: datetime | None
|
||||
# In seconds, None for one time index with no refresh
|
||||
refresh_freq: int | None = None
|
||||
prune_freq: int | None = None
|
||||
indexing_start: datetime | None = None
|
||||
|
||||
|
||||
class ConnectorUpdateRequest(ConnectorBase):
|
||||
is_public: bool | None = None
|
||||
groups: list[int] | None = None
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ConnectorSnapshot(ConnectorBase):
|
||||
@ -93,7 +95,7 @@ class CredentialBase(BaseModel):
|
||||
source: DocumentSource
|
||||
name: str | None = None
|
||||
curator_public: bool = False
|
||||
groups: list[int] = []
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
|
||||
|
||||
class CredentialSnapshot(CredentialBase):
|
||||
@ -254,21 +256,21 @@ class ConnectorCredentialPairIdentifier(BaseModel):
|
||||
|
||||
|
||||
class ConnectorCredentialPairMetadata(BaseModel):
|
||||
name: str | None
|
||||
is_public: bool
|
||||
groups: list[int] | None
|
||||
name: str | None = None
|
||||
is_public: bool | None = None
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ConnectorCredentialPairDescriptor(BaseModel):
|
||||
id: int
|
||||
name: str | None
|
||||
name: str | None = None
|
||||
connector: ConnectorSnapshot
|
||||
credential: CredentialSnapshot
|
||||
|
||||
|
||||
class RunConnectorRequest(BaseModel):
|
||||
connector_id: int
|
||||
credential_ids: list[int] | None
|
||||
credential_ids: list[int] | None = None
|
||||
from_beginning: bool = False
|
||||
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from danswer.db.models import DocumentSet as DocumentSetDBModel
|
||||
from danswer.server.documents.models import ConnectorCredentialPairDescriptor
|
||||
@ -14,8 +15,8 @@ class DocumentSetCreationRequest(BaseModel):
|
||||
cc_pair_ids: list[int]
|
||||
is_public: bool
|
||||
# For Private Document Sets, who should be able to access these
|
||||
users: list[UUID] | None = None
|
||||
groups: list[int] | None = None
|
||||
users: list[UUID] = Field(default_factory=list)
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DocumentSetUpdateRequest(BaseModel):
|
||||
|
@ -19,7 +19,7 @@ class FolderCreationRequest(BaseModel):
|
||||
|
||||
|
||||
class FolderUpdateRequest(BaseModel):
|
||||
folder_name: str | None
|
||||
folder_name: str | None = None
|
||||
|
||||
|
||||
class FolderChatSessionRequest(BaseModel):
|
||||
|
@ -1,6 +1,7 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import StarterMessage
|
||||
@ -31,8 +32,8 @@ class CreatePersonaRequest(BaseModel):
|
||||
llm_model_version_override: str | None = None
|
||||
starter_messages: list[StarterMessage] | None = None
|
||||
# For Private Personas, who should be able to access these
|
||||
users: list[UUID] | None = None
|
||||
groups: list[int] | None = None
|
||||
users: list[UUID] = Field(default_factory=list)
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
icon_color: str | None = None
|
||||
icon_shape: int | None = None
|
||||
uploaded_image_id: str | None = None # New field for uploaded image
|
||||
|
@ -26,14 +26,14 @@ admin_router = APIRouter(prefix="/admin/tool")
|
||||
|
||||
class CustomToolCreate(BaseModel):
|
||||
name: str
|
||||
description: str | None
|
||||
description: str | None = None
|
||||
definition: dict[str, Any]
|
||||
|
||||
|
||||
class CustomToolUpdate(BaseModel):
|
||||
name: str | None
|
||||
description: str | None
|
||||
definition: dict[str, Any] | None
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
definition: dict[str, Any] | None = None
|
||||
|
||||
|
||||
def _validate_tool_definition(definition: dict[str, Any]) -> None:
|
||||
|
@ -1,6 +1,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from danswer.llm.llm_provider_options import fetch_models_for_provider
|
||||
|
||||
@ -56,26 +57,26 @@ class LLMProviderDescriptor(BaseModel):
|
||||
class LLMProvider(BaseModel):
|
||||
name: str
|
||||
provider: str
|
||||
api_key: str | None
|
||||
api_base: str | None
|
||||
api_version: str | None
|
||||
custom_config: dict[str, str] | None
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
custom_config: dict[str, str] | None = None
|
||||
default_model_name: str
|
||||
fast_default_model_name: str | None
|
||||
fast_default_model_name: str | None = None
|
||||
is_public: bool = True
|
||||
groups: list[int] | None = None
|
||||
display_model_names: list[str] | None
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
display_model_names: list[str] | None = None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
# should only be used for a "custom" provider
|
||||
# for default providers, the built-in model names are used
|
||||
model_names: list[str] | None
|
||||
model_names: list[str] | None = None
|
||||
|
||||
|
||||
class FullLLMProvider(LLMProvider):
|
||||
id: int
|
||||
is_default_provider: bool | None
|
||||
is_default_provider: bool | None = None
|
||||
model_names: list[str]
|
||||
|
||||
@classmethod
|
||||
|
@ -1,10 +1,11 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import root_validator
|
||||
from pydantic import validator
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
@ -39,8 +40,8 @@ class AuthTypeResponse(BaseModel):
|
||||
|
||||
|
||||
class UserPreferences(BaseModel):
|
||||
chosen_assistants: list[int] | None
|
||||
default_model: str | None
|
||||
chosen_assistants: list[int] | None = None
|
||||
default_model: str | None = None
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
@ -158,7 +159,8 @@ class StandardAnswerCreationRequest(BaseModel):
|
||||
answer: str
|
||||
categories: list[int]
|
||||
|
||||
@validator("categories", pre=True)
|
||||
@field_validator("categories", mode="before")
|
||||
@classmethod
|
||||
def validate_categories(cls, value: list[int]) -> list[int]:
|
||||
if len(value) < 1:
|
||||
raise ValueError(
|
||||
@ -170,9 +172,7 @@ class StandardAnswerCreationRequest(BaseModel):
|
||||
class SlackBotTokens(BaseModel):
|
||||
bot_token: str
|
||||
app_token: str
|
||||
|
||||
class Config:
|
||||
frozen = True
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
class SlackBotConfigCreationRequest(BaseModel):
|
||||
@ -180,23 +180,24 @@ class SlackBotConfigCreationRequest(BaseModel):
|
||||
# in the future, `document_sets` will probably be replaced
|
||||
# by an optional `PersonaSnapshot` object. Keeping it like this
|
||||
# for now for simplicity / speed of development
|
||||
document_sets: list[int] | None
|
||||
document_sets: list[int] | None = None
|
||||
persona_id: (
|
||||
int | None
|
||||
) # NOTE: only one of `document_sets` / `persona_id` should be set
|
||||
) = None # NOTE: only one of `document_sets` / `persona_id` should be set
|
||||
channel_names: list[str]
|
||||
respond_tag_only: bool = False
|
||||
respond_to_bots: bool = False
|
||||
enable_auto_filters: bool = False
|
||||
# If no team members, assume respond in the channel to everyone
|
||||
respond_member_group_list: list[str] = []
|
||||
answer_filters: list[AllowedAnswerFilters] = []
|
||||
respond_member_group_list: list[str] = Field(default_factory=list)
|
||||
answer_filters: list[AllowedAnswerFilters] = Field(default_factory=list)
|
||||
# list of user emails
|
||||
follow_up_tags: list[str] | None = None
|
||||
response_type: SlackBotResponseType
|
||||
standard_answer_categories: list[int] = []
|
||||
standard_answer_categories: list[int] = Field(default_factory=list)
|
||||
|
||||
@validator("answer_filters", pre=True)
|
||||
@field_validator("answer_filters", mode="before")
|
||||
@classmethod
|
||||
def validate_filters(cls, value: list[str]) -> list[str]:
|
||||
if any(test not in VALID_SLACK_FILTERS for test in value):
|
||||
raise ValueError(
|
||||
@ -204,14 +205,12 @@ class SlackBotConfigCreationRequest(BaseModel):
|
||||
)
|
||||
return value
|
||||
|
||||
@root_validator
|
||||
def validate_document_sets_and_persona_id(
|
||||
cls, values: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
if values.get("document_sets") and values.get("persona_id"):
|
||||
@model_validator(mode="after")
|
||||
def validate_document_sets_and_persona_id(self) -> "SlackBotConfigCreationRequest":
|
||||
if self.document_sets and self.persona_id:
|
||||
raise ValueError("Only one of `document_sets` / `persona_id` should be set")
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
|
||||
class SlackBotConfig(BaseModel):
|
||||
|
@ -323,7 +323,7 @@ def verify_user_logged_in(
|
||||
|
||||
|
||||
class ChosenDefaultModelRequest(BaseModel):
|
||||
default_model: str | None
|
||||
default_model: str | None = None
|
||||
|
||||
|
||||
@router.patch("/user/default-model")
|
||||
|
@ -4,7 +4,6 @@ from typing import TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserStatus
|
||||
@ -13,7 +12,7 @@ from danswer.auth.schemas import UserStatus
|
||||
DataT = TypeVar("DataT")
|
||||
|
||||
|
||||
class StatusResponse(GenericModel, Generic[DataT]):
|
||||
class StatusResponse(BaseModel, Generic[DataT]):
|
||||
success: bool
|
||||
message: Optional[str] = None
|
||||
data: Optional[DataT] = None
|
||||
|
@ -2,7 +2,7 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import root_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.chat.models import RetrievalDocs
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@ -54,16 +54,11 @@ class ChatFeedbackRequest(BaseModel):
|
||||
feedback_text: str | None = None
|
||||
predefined_feedback: str | None = None
|
||||
|
||||
@root_validator
|
||||
def check_is_positive_or_feedback_text(cls: BaseModel, values: dict) -> dict:
|
||||
is_positive, feedback_text = values.get("is_positive"), values.get(
|
||||
"feedback_text"
|
||||
)
|
||||
|
||||
if is_positive is None and feedback_text is None:
|
||||
@model_validator(mode="after")
|
||||
def check_is_positive_or_feedback_text(self) -> "ChatFeedbackRequest":
|
||||
if self.is_positive is None and self.feedback_text is None:
|
||||
raise ValueError("Empty feedback received.")
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
|
||||
"""
|
||||
@ -112,18 +107,13 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
# used for seeded chats to kick off the generation of an AI answer
|
||||
use_existing_user_message: bool = False
|
||||
|
||||
@root_validator
|
||||
def check_search_doc_ids_or_retrieval_options(cls: BaseModel, values: dict) -> dict:
|
||||
search_doc_ids, retrieval_options = values.get("search_doc_ids"), values.get(
|
||||
"retrieval_options"
|
||||
)
|
||||
|
||||
if search_doc_ids is None and retrieval_options is None:
|
||||
@model_validator(mode="after")
|
||||
def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
|
||||
if self.search_doc_ids is None and self.retrieval_options is None:
|
||||
raise ValueError(
|
||||
"Either search_doc_ids or retrieval_options must be provided, but not both or neither."
|
||||
)
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
|
||||
class ChatMessageIdentifier(BaseModel):
|
||||
@ -149,7 +139,7 @@ class ChatSessionDetails(BaseModel):
|
||||
persona_id: int
|
||||
time_created: str
|
||||
shared_status: ChatSessionSharedStatus
|
||||
folder_id: int | None
|
||||
folder_id: int | None = None
|
||||
current_alternate_model: str | None = None
|
||||
|
||||
|
||||
@ -162,36 +152,36 @@ class SearchFeedbackRequest(BaseModel):
|
||||
document_id: str
|
||||
document_rank: int
|
||||
click: bool
|
||||
search_feedback: SearchFeedbackType | None
|
||||
search_feedback: SearchFeedbackType | None = None
|
||||
|
||||
@root_validator
|
||||
def check_click_or_search_feedback(cls: BaseModel, values: dict) -> dict:
|
||||
click, feedback = values.get("click"), values.get("search_feedback")
|
||||
@model_validator(mode="after")
|
||||
def check_click_or_search_feedback(self) -> "SearchFeedbackRequest":
|
||||
click, feedback = self.click, self.search_feedback
|
||||
|
||||
if click is False and feedback is None:
|
||||
raise ValueError("Empty feedback received.")
|
||||
return values
|
||||
return self
|
||||
|
||||
|
||||
class ChatMessageDetail(BaseModel):
|
||||
message_id: int
|
||||
parent_message: int | None
|
||||
latest_child_message: int | None
|
||||
parent_message: int | None = None
|
||||
latest_child_message: int | None = None
|
||||
message: str
|
||||
rephrased_query: str | None
|
||||
context_docs: RetrievalDocs | None
|
||||
rephrased_query: str | None = None
|
||||
context_docs: RetrievalDocs | None = None
|
||||
message_type: MessageType
|
||||
time_sent: datetime
|
||||
alternate_assistant_id: str | None
|
||||
overridden_model: str | None
|
||||
alternate_assistant_id: int | None = None
|
||||
# Dict mapping citation number to db_doc_id
|
||||
chat_session_id: int | None = None
|
||||
citations: dict[int, int] | None
|
||||
citations: dict[int, int] | None = None
|
||||
files: list[FileDescriptor]
|
||||
tool_calls: list[ToolCallFinalResult]
|
||||
|
||||
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().dict(*args, **kwargs) # type: ignore
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
initial_dict["time_sent"] = self.time_sent.isoformat()
|
||||
return initial_dict
|
||||
|
||||
@ -226,7 +216,3 @@ class AdminSearchRequest(BaseModel):
|
||||
|
||||
class AdminSearchResponse(BaseModel):
|
||||
documents: list[SearchDoc]
|
||||
|
||||
|
||||
class DanswerAnswer(BaseModel):
|
||||
answer: str | None
|
||||
|
@ -64,7 +64,7 @@ def fetch_settings(
|
||||
needs_reindexing = False
|
||||
|
||||
return UserSettings(
|
||||
**general_settings.dict(),
|
||||
**general_settings.model_dump(),
|
||||
notifications=user_notifications,
|
||||
needs_reindexing=needs_reindexing
|
||||
)
|
||||
|
@ -12,10 +12,10 @@ def load_settings() -> Settings:
|
||||
settings = Settings(**cast(dict, dynamic_config_store.load(KV_SETTINGS_KEY)))
|
||||
except ConfigNotFoundError:
|
||||
settings = Settings()
|
||||
dynamic_config_store.store(KV_SETTINGS_KEY, settings.dict())
|
||||
dynamic_config_store.store(KV_SETTINGS_KEY, settings.model_dump())
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
def store_settings(settings: Settings) -> None:
|
||||
get_dynamic_config_store().store(KV_SETTINGS_KEY, settings.dict())
|
||||
get_dynamic_config_store().store(KV_SETTINGS_KEY, settings.model_dump())
|
||||
|
@ -1,6 +1,6 @@
|
||||
import os
|
||||
from typing import Type
|
||||
from typing import TypedDict
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
|
||||
from sqlalchemy import not_
|
||||
from sqlalchemy import or_
|
||||
|
@ -254,6 +254,6 @@ class ImageGenerationTool(Tool):
|
||||
list[ImageGenerationResponse], args[0].response
|
||||
)
|
||||
return [
|
||||
image_generation_response.dict()
|
||||
image_generation_response.model_dump()
|
||||
for image_generation_response in image_generation_responses
|
||||
]
|
||||
|
@ -187,7 +187,7 @@ class InternetSearchTool(Tool):
|
||||
self, *args: ToolResponse
|
||||
) -> str | list[str | dict[str, Any]]:
|
||||
search_response = cast(InternetSearchResponse, args[0].response)
|
||||
return json.dumps(search_response.dict())
|
||||
return json.dumps(search_response.model_dump())
|
||||
|
||||
def _perform_search(self, query: str) -> InternetSearchResponse:
|
||||
response = self.client.get(
|
||||
@ -230,4 +230,4 @@ class InternetSearchTool(Tool):
|
||||
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
search_response = cast(InternetSearchResponse, args[0].response)
|
||||
return search_response.dict()
|
||||
return search_response.model_dump()
|
||||
|
@ -4,10 +4,12 @@ from typing import Any
|
||||
from langchain_core.messages.ai import AIMessage
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.messages.tool import ToolMessage
|
||||
from pydantic import BaseModel
|
||||
from pydantic.v1 import BaseModel as BaseModel__v1
|
||||
|
||||
from danswer.natural_language_processing.utils import BaseTokenizer
|
||||
|
||||
# Langchain has their own version of pydantic which is version 1
|
||||
|
||||
|
||||
def build_tool_message(
|
||||
tool_call: ToolCall, tool_content: str | list[str | dict[str, Any]]
|
||||
@ -19,7 +21,7 @@ def build_tool_message(
|
||||
)
|
||||
|
||||
|
||||
class ToolCallSummary(BaseModel):
|
||||
class ToolCallSummary(BaseModel__v1):
|
||||
tool_call_request: AIMessage
|
||||
tool_call_result: ToolMessage
|
||||
|
||||
|
@ -1,12 +1,12 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import root_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
id: str | None = None
|
||||
response: Any
|
||||
response: Any = None
|
||||
|
||||
|
||||
class ToolCallKickoff(BaseModel):
|
||||
@ -19,12 +19,10 @@ class ToolRunnerResponse(BaseModel):
|
||||
tool_response: ToolResponse | None = None
|
||||
tool_message_content: str | list[str | dict[str, Any]] | None = None
|
||||
|
||||
@root_validator
|
||||
def validate_tool_runner_response(
|
||||
cls, values: dict[str, ToolResponse | str]
|
||||
) -> dict[str, ToolResponse | str]:
|
||||
@model_validator(mode="after")
|
||||
def validate_tool_runner_response(self) -> "ToolRunnerResponse":
|
||||
fields = ["tool_response", "tool_message_content", "tool_run_kickoff"]
|
||||
provided = sum(1 for field in fields if values.get(field) is not None)
|
||||
provided = sum(1 for field in fields if getattr(self, field) is not None)
|
||||
|
||||
if provided != 1:
|
||||
raise ValueError(
|
||||
@ -32,8 +30,10 @@ class ToolRunnerResponse(BaseModel):
|
||||
"or 'tool_run_kickoff' must be provided"
|
||||
)
|
||||
|
||||
return values
|
||||
return self
|
||||
|
||||
|
||||
class ToolCallFinalResult(ToolCallKickoff):
|
||||
tool_result: Any # we would like to use JSON_ro, but can't due to its recursive nature
|
||||
tool_result: Any = (
|
||||
None # we would like to use JSON_ro, but can't due to its recursive nature
|
||||
)
|
||||
|
@ -30,13 +30,13 @@ def load_settings() -> EnterpriseSettings:
|
||||
)
|
||||
except ConfigNotFoundError:
|
||||
settings = EnterpriseSettings()
|
||||
dynamic_config_store.store(KV_ENTERPRISE_SETTINGS_KEY, settings.dict())
|
||||
dynamic_config_store.store(KV_ENTERPRISE_SETTINGS_KEY, settings.model_dump())
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
def store_settings(settings: EnterpriseSettings) -> None:
|
||||
get_dynamic_config_store().store(KV_ENTERPRISE_SETTINGS_KEY, settings.dict())
|
||||
get_dynamic_config_store().store(KV_ENTERPRISE_SETTINGS_KEY, settings.model_dump())
|
||||
|
||||
|
||||
_CUSTOM_ANALYTICS_SECRET_KEY = os.environ.get("CUSTOM_ANALYTICS_SECRET_KEY")
|
||||
|
@ -17,7 +17,7 @@ class StandardAnswerRequest(BaseModel):
|
||||
|
||||
|
||||
class StandardAnswerResponse(BaseModel):
|
||||
standard_answers: list[StandardAnswer] = []
|
||||
standard_answers: list[StandardAnswer] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DocumentSearchRequest(ChunkContext):
|
||||
|
@ -376,7 +376,7 @@ def get_query_history_as_csv(
|
||||
# Create an in-memory text stream
|
||||
stream = io.StringIO()
|
||||
writer = csv.DictWriter(
|
||||
stream, fieldnames=list(QuestionAnswerPairSnapshot.__fields__.keys())
|
||||
stream, fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys())
|
||||
)
|
||||
writer.writeheader()
|
||||
for row in question_answer_pairs:
|
||||
|
@ -29,20 +29,20 @@ langchain==0.1.17
|
||||
langchain-community==0.0.36
|
||||
langchain-core==0.1.50
|
||||
langchain-text-splitters==0.0.1
|
||||
litellm==1.37.7
|
||||
litellm==1.43.18
|
||||
llama-index==0.9.45
|
||||
Mako==1.2.4
|
||||
msal==1.26.0
|
||||
nltk==3.8.1
|
||||
Office365-REST-Python-Client==2.5.9
|
||||
oauthlib==3.2.2
|
||||
openai==1.14.3
|
||||
openai==1.41.1
|
||||
openpyxl==3.1.2
|
||||
playwright==1.41.2
|
||||
psutil==5.9.5
|
||||
psycopg2-binary==2.9.9
|
||||
pycryptodome==3.19.1
|
||||
pydantic==1.10.13
|
||||
pydantic==2.8.2
|
||||
PyGithub==1.58.2
|
||||
python-dateutil==2.8.2
|
||||
python-gitlab==3.9.0
|
||||
@ -64,7 +64,7 @@ slack-sdk==3.20.2
|
||||
SQLAlchemy[mypy]==2.0.15
|
||||
starlette==0.36.3
|
||||
supervisor==4.2.5
|
||||
tiktoken==0.4.0
|
||||
tiktoken==0.7.0
|
||||
timeago==1.0.16
|
||||
transformers==4.39.2
|
||||
uvicorn==0.21.1
|
||||
|
@ -3,8 +3,8 @@ einops==0.8.0
|
||||
fastapi==0.109.2
|
||||
google-cloud-aiplatform==1.58.0
|
||||
numpy==1.26.4
|
||||
openai==1.14.3
|
||||
pydantic==1.10.13
|
||||
openai==1.41.1
|
||||
pydantic==2.8.2
|
||||
retry==0.9.2
|
||||
safetensors==0.4.2
|
||||
sentence-transformers==2.6.1
|
||||
|
@ -10,14 +10,17 @@ Embedding = list[float]
|
||||
class EmbedRequest(BaseModel):
|
||||
texts: list[str]
|
||||
# Can be none for cloud embedding model requests, error handling logic exists for other cases
|
||||
model_name: str | None
|
||||
model_name: str | None = None
|
||||
max_context_length: int
|
||||
normalize_embeddings: bool
|
||||
api_key: str | None
|
||||
provider_type: EmbeddingProvider | None
|
||||
api_key: str | None = None
|
||||
provider_type: EmbeddingProvider | None = None
|
||||
text_type: EmbedTextType
|
||||
manual_query_prefix: str | None
|
||||
manual_passage_prefix: str | None
|
||||
manual_query_prefix: str | None = None
|
||||
manual_passage_prefix: str | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
|
||||
class EmbedResponse(BaseModel):
|
||||
@ -28,8 +31,11 @@ class RerankRequest(BaseModel):
|
||||
query: str
|
||||
documents: list[str]
|
||||
model_name: str
|
||||
provider_type: RerankerProvider | None
|
||||
api_key: str | None
|
||||
provider_type: RerankerProvider | None = None
|
||||
api_key: str | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
|
||||
class RerankResponse(BaseModel):
|
||||
|
@ -14,7 +14,7 @@ class DocumentSetClient:
|
||||
) -> int:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/document-set",
|
||||
json=doc_set_creation_request.dict(),
|
||||
json=doc_set_creation_request.model_dump(),
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cast(int, response.json())
|
||||
|
@ -31,7 +31,7 @@ class LLMProvider(BaseModel):
|
||||
custom_config=None,
|
||||
fast_default_model_name=None,
|
||||
is_public=True,
|
||||
groups=None,
|
||||
groups=[],
|
||||
display_model_names=None,
|
||||
model_names=None,
|
||||
)
|
||||
|
@ -12,7 +12,7 @@ class UserGroupClient:
|
||||
def create_user_group(user_group_creation_request: UserGroupCreate) -> int:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group",
|
||||
json=user_group_creation_request.dict(),
|
||||
json=user_group_creation_request.model_dump(),
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cast(int, response.json()["id"])
|
||||
|
@ -57,7 +57,7 @@ def get_answer_from_query(
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
body = new_message_request.dict()
|
||||
body = new_message_request.model_dump()
|
||||
body["user"] = None
|
||||
try:
|
||||
response_json = requests.post(url, headers=headers, json=body).json()
|
||||
@ -122,7 +122,7 @@ def create_cc_pair(env_name: str, connector_id: int, credential_id: int) -> None
|
||||
env_name, f"/manage/connector/{connector_id}/credential/{credential_id}"
|
||||
)
|
||||
|
||||
body = {"name": "zip_folder_contents", "is_public": True}
|
||||
body = {"name": "zip_folder_contents", "is_public": True, "groups": []}
|
||||
print("body:", body)
|
||||
response = requests.put(url, headers=GENERAL_HEADERS, json=body)
|
||||
if response.status_code == 200:
|
||||
@ -167,7 +167,7 @@ def create_connector(env_name: str, file_paths: list[str]) -> int:
|
||||
indexing_start=None,
|
||||
)
|
||||
|
||||
body = connector.dict()
|
||||
body = connector.model_dump()
|
||||
response = requests.post(url, headers=GENERAL_HEADERS, json=body)
|
||||
if response.status_code == 200:
|
||||
return response.json()["id"]
|
||||
|
Loading…
x
Reference in New Issue
Block a user