2024-11-23 13:42:54 -08:00

232 lines
7.2 KiB
Python

from datetime import datetime
from typing import Any
from uuid import UUID
from pydantic import BaseModel
from pydantic import model_validator
from danswer.chat.models import RetrievalDocs
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
from danswer.context.search.models import BaseFilters
from danswer.context.search.models import ChunkContext
from danswer.context.search.models import RetrievalDetails
from danswer.context.search.models import SearchDoc
from danswer.context.search.models import Tag
from danswer.db.enums import ChatSessionSharedStatus
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.tools.models import ToolCallFinalResult
class SourceTag(Tag):
    """A `Tag` extended with the `DocumentSource` it is associated with."""

    source: DocumentSource
class TagResponse(BaseModel):
    """Response model carrying a list of source-qualified tags."""

    tags: list[SourceTag]
class SimpleQueryRequest(BaseModel):
    """Request model carrying only a raw query string."""

    query: str
class UpdateChatSessionThreadRequest(BaseModel):
    """Request to change the alternate model of an existing chat session.

    NOTE(review): the new value presumably becomes the session's
    `current_alternate_model` (see `ChatSessionDetails`) — confirm in handler.
    """

    # Required: the chat session being updated (this model has no persona field)
    chat_session_id: UUID
    # Name of the alternate model to use for this session
    new_alternate_model: str
class ChatSessionCreationRequest(BaseModel):
    """Request to create a new chat session."""

    # If not specified, use Danswer default persona
    persona_id: int = 0
    # Optional free-text description for the new session
    description: str | None = None
class CreateChatSessionID(BaseModel):
    """Carries the ID of a chat session.

    NOTE(review): by its name, presumably returned after session creation.
    """

    chat_session_id: UUID
class ChatFeedbackRequest(BaseModel):
    """Feedback submitted for a single chat message.

    Validation requires `is_positive` or `feedback_text` (or both) to be set;
    a request carrying only `predefined_feedback` is rejected.
    """

    chat_message_id: int
    is_positive: bool | None = None
    feedback_text: str | None = None
    predefined_feedback: str | None = None

    @model_validator(mode="after")
    def check_is_positive_or_feedback_text(self) -> "ChatFeedbackRequest":
        """Reject feedback that has neither a rating nor free text.

        Raises:
            ValueError: if both `is_positive` and `feedback_text` are None.
        """
        has_rating = self.is_positive is not None
        has_text = self.feedback_text is not None
        if has_rating or has_text:
            return self
        raise ValueError("Empty feedback received.")
"""
Currently, the different branches are generated by changing the search query.

                 [Empty Root Message]  This allows the first message to be branched as well
              /           |           \
[First Message] [First Message Edit 1] [First Message Edit 2]
       |                  |
[Second Message]  [Second Message of Edit 1 Branch]
"""
class CreateChatMessageRequest(ChunkContext):
    """Request to add a new message to a chat session.

    Before creating messages, be sure to create a chat_session and get an id.

    At least one of `search_doc_ids` or `retrieval_options` must be provided.
    Supplying both is allowed; when `search_doc_ids` is given, the retrieval
    options are unused.
    """

    chat_session_id: UUID
    # This is the primary-key (unique identifier) for the previous message of the tree
    parent_message_id: int | None
    # New message contents
    message: str
    # Files that we should attach to this message
    file_descriptors: list[FileDescriptor]

    # If no prompt provided, uses the largest prompt of the chat session
    # but really this should be explicitly specified, only in the simplified APIs is this inferred
    # Use prompt_id 0 to use the system default prompt which is Answer-Question
    prompt_id: int | None
    # If search_doc_ids provided, then retrieval options are unused
    search_doc_ids: list[int] | None
    retrieval_options: RetrievalDetails | None
    # allows the caller to specify the exact search query they want to use
    # will disable Query Rewording if specified
    query_override: str | None = None

    # enables additional handling to ensure that we regenerate with a given user message ID
    regenerate: bool | None = None

    # allows the caller to override the Persona / Prompt
    # these do not persist in the chat thread details
    llm_override: LLMOverride | None = None
    prompt_override: PromptOverride | None = None

    # allow user to specify an alternate assistant
    alternate_assistant_id: int | None = None

    # used for seeded chats to kick off the generation of an AI answer
    use_existing_user_message: bool = False

    # used for "OpenAI Assistants API"
    existing_assistant_message_id: int | None = None

    # forces the LLM to return a structured response, see
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    structured_response_format: dict | None = None

    @model_validator(mode="after")
    def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
        """Require at least one way of selecting documents.

        Only the "neither" case is rejected; providing both is valid, so the
        error message describes exactly what is enforced.

        Raises:
            ValueError: if both `search_doc_ids` and `retrieval_options` are None.
        """
        if self.search_doc_ids is None and self.retrieval_options is None:
            raise ValueError(
                "Either search_doc_ids or retrieval_options must be provided."
            )
        return self

    def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """Dump the model with `chat_session_id` rendered as a string.

        All arguments are forwarded to `BaseModel.model_dump`. If the caller's
        `include`/`exclude` arguments leave `chat_session_id` out of the
        result, the dict is returned without it instead of raising KeyError.
        """
        data = super().model_dump(*args, **kwargs)
        if "chat_session_id" in data:
            data["chat_session_id"] = str(data["chat_session_id"])
        return data
class ChatMessageIdentifier(BaseModel):
    """Identifies a single chat message by its integer ID."""

    message_id: int
class ChatRenameRequest(BaseModel):
    """Request to rename a chat session."""

    chat_session_id: UUID
    # Optional new name. NOTE(review): None presumably asks the server to
    # generate a name (cf. RenameChatSessionResponse) — confirm in the handler.
    name: str | None = None
class ChatSessionUpdateRequest(BaseModel):
    """Request to change a chat session's sharing status."""

    sharing_status: ChatSessionSharedStatus
class RenameChatSessionResponse(BaseModel):
    """Response to a rename request, carrying the session's resulting name."""

    new_name: str  # This is only really useful if the name is generated
class ChatSessionDetails(BaseModel):
    """Summary of one chat session, as listed in `ChatSessionsResponse`."""

    id: UUID
    name: str
    persona_id: int | None = None
    # Declared as str here (ChatSessionDetailResponse.time_created is a
    # datetime). NOTE(review): presumably an ISO-8601 string — confirm.
    time_created: str
    shared_status: ChatSessionSharedStatus
    folder_id: int | None = None
    current_alternate_model: str | None = None
class ChatSessionsResponse(BaseModel):
    """Response carrying a list of chat session summaries."""

    sessions: list[ChatSessionDetails]
class SearchFeedbackRequest(BaseModel):
    """Feedback on a single document (by ID and rank) for a given message.

    Validation requires either a click (`click` is True) or an explicit
    `search_feedback` value.
    """

    message_id: int
    document_id: str
    document_rank: int
    click: bool
    search_feedback: SearchFeedbackType | None = None

    @model_validator(mode="after")
    def check_click_or_search_feedback(self) -> "SearchFeedbackRequest":
        """Reject feedback that has neither a click nor a feedback type.

        Raises:
            ValueError: if `click` is False and `search_feedback` is None.
        """
        if self.click is not False or self.search_feedback is not None:
            return self
        raise ValueError("Empty feedback received.")
class ChatMessageDetail(BaseModel):
    """Detailed view of a single chat message within a chat session."""

    message_id: int
    parent_message: int | None = None
    latest_child_message: int | None = None
    message: str
    rephrased_query: str | None = None
    context_docs: RetrievalDocs | None = None
    message_type: MessageType
    time_sent: datetime
    overridden_model: str | None
    alternate_assistant_id: int | None = None
    chat_session_id: UUID | None = None
    # Dict mapping citation number to db_doc_id
    citations: dict[int, int] | None = None
    files: list[FileDescriptor]
    tool_call: ToolCallFinalResult | None

    def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """Dump the model, defaulting to JSON mode, with `time_sent` as ISO text.

        All arguments are forwarded to `BaseModel.model_dump`. `mode` defaults
        to "json" but may now be overridden by the caller; passing it used to
        collide with the hard-coded keyword and raise TypeError.

        `time_sent` is always rendered with `datetime.isoformat()`. NOTE(review):
        presumably this keeps the stdlib format, since Pydantic's JSON
        serializer may format timezone offsets differently. It is only
        rewritten when present, so `include`/`exclude` are respected.
        """
        kwargs.setdefault("mode", "json")
        initial_dict = super().model_dump(*args, **kwargs)
        if "time_sent" in initial_dict:
            initial_dict["time_sent"] = self.time_sent.isoformat()
        return initial_dict
class SearchSessionDetailResponse(BaseModel):
    """Detail of a search session: its description, documents and messages."""

    search_session_id: UUID
    description: str
    documents: list[SearchDoc]
    messages: list[ChatMessageDetail]
class ChatSessionDetailResponse(BaseModel):
    """Detail of a chat session, including all of its messages."""

    chat_session_id: UUID
    description: str
    persona_id: int | None = None
    # Required but nullable (no default), unlike persona_id above
    persona_name: str | None
    messages: list[ChatMessageDetail]
    time_created: datetime
    shared_status: ChatSessionSharedStatus
    # Required but nullable (no default)
    current_alternate_model: str | None
class QueryValidationResponse(BaseModel):
    """Verdict on whether a query is answerable, with the reasoning behind it."""

    reasoning: str
    answerable: bool
class AdminSearchRequest(BaseModel):
    """Admin search request: a query plus the filters to apply to it."""

    query: str
    filters: BaseFilters
class AdminSearchResponse(BaseModel):
    """Documents returned by an admin search."""

    documents: list[SearchDoc]