danswer/backend/onyx/chat/models.py
2025-02-19 15:52:16 -08:00

400 lines
11 KiB
Python

from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Literal
from typing import TYPE_CHECKING
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.enums import SearchType
from onyx.context.search.models import RetrievalDocs
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import PromptOverride
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.custom.base_tool_types import ToolResultType
if TYPE_CHECKING:
from onyx.db.models import Prompt
class LlmDoc(BaseModel):
"""This contains the minimal set information for the LLM portion including citations"""
document_id: str
content: str
blurb: str
semantic_identifier: str
source_type: DocumentSource
metadata: dict[str, str | list[str]]
updated_at: datetime | None
link: str | None
source_links: dict[int, str] | None
match_highlights: list[str] | None
class SubQuestionIdentifier(BaseModel):
level: int | None = None
level_question_num: int | None = None
# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
rephrased_query: str | None = None
predicted_flow: QueryFlow | None
predicted_search: SearchType | None
applied_source_filters: list[DocumentSource] | None
applied_time_cutoff: datetime | None
recency_bias_multiplier: float
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
initial_dict["applied_time_cutoff"] = (
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
)
return initial_dict
class StreamStopReason(Enum):
CONTEXT_LENGTH = "context_length"
CANCELLED = "cancelled"
FINISHED = "finished"
class StreamType(Enum):
SUB_QUESTIONS = "sub_questions"
SUB_ANSWER = "sub_answer"
MAIN_ANSWER = "main_answer"
class StreamStopInfo(SubQuestionIdentifier):
stop_reason: StreamStopReason
stream_type: StreamType = StreamType.MAIN_ANSWER
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
data = super().model_dump(mode="json", *args, **kwargs) # type: ignore
data["stop_reason"] = self.stop_reason.name
return data
class LLMRelevanceFilterResponse(BaseModel):
llm_selected_doc_indices: list[int]
class FinalUsedContextDocsResponse(BaseModel):
final_context_docs: list[LlmDoc]
class RelevanceAnalysis(BaseModel):
relevant: bool
content: str | None = None
class SectionRelevancePiece(RelevanceAnalysis):
"""LLM analysis mapped to an Inference Section"""
document_id: str
chunk_id: int # ID of the center chunk for a given inference section
class DocumentRelevance(BaseModel):
"""Contains all relevance information for a given search"""
relevance_summaries: dict[str, RelevanceAnalysis]
class OnyxAnswerPiece(BaseModel):
# A small piece of a complete answer. Used for streaming back answers.
answer_piece: str | None # if None, specifies the end of an Answer
# An intermediate representation of citations, later translated into
# a mapping of the citation [n] number to SearchDoc
class CitationInfo(SubQuestionIdentifier):
citation_num: int
document_id: str
class AllCitations(BaseModel):
citations: list[CitationInfo]
# This is a mapping of the citation number to the document index within
# the result search doc set
class MessageSpecificCitations(BaseModel):
citation_map: dict[int, int]
class MessageResponseIDInfo(BaseModel):
user_message_id: int | None
reserved_assistant_message_id: int
class AgentMessageIDInfo(BaseModel):
level: int
message_id: int
class AgenticMessageResponseIDInfo(BaseModel):
agentic_message_ids: list[AgentMessageIDInfo]
class StreamingError(BaseModel):
error: str
stack_trace: str | None = None
class OnyxContext(BaseModel):
content: str
document_id: str
semantic_identifier: str
blurb: str
class OnyxContexts(BaseModel):
contexts: list[OnyxContext]
class OnyxAnswer(BaseModel):
answer: str | None
class ThreadMessage(BaseModel):
message: str
sender: str | None = None
role: MessageType = MessageType.USER
class ChatOnyxBotResponse(BaseModel):
answer: str | None = None
citations: list[CitationInfo] | None = None
docs: QADocsResponse | None = None
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
chat_message_id: int | None = None
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
class FileChatDisplay(BaseModel):
file_ids: list[str]
class CustomToolResponse(BaseModel):
response: ToolResultType
tool_name: str
class ToolConfig(BaseModel):
id: int
class PromptOverrideConfig(BaseModel):
name: str
description: str = ""
system_prompt: str
task_prompt: str = ""
include_citations: bool = True
datetime_aware: bool = True
class PersonaOverrideConfig(BaseModel):
name: str
description: str
search_type: SearchType = SearchType.SEMANTIC
num_chunks: float | None = None
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
prompt_ids: list[int] = Field(default_factory=list)
document_set_ids: list[int] = Field(default_factory=list)
tools: list[ToolConfig] = Field(default_factory=list)
tool_ids: list[int] = Field(default_factory=list)
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
AnswerQuestionPossibleReturn = (
OnyxAnswerPiece
| CitationInfo
| OnyxContexts
| FileChatDisplay
| CustomToolResponse
| StreamingError
| StreamStopInfo
)
AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
class LLMMetricsContainer(BaseModel):
prompt_tokens: int
response_tokens: int
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
class DocumentPruningConfig(BaseModel):
max_chunks: int | None = None
max_window_percentage: float | None = None
max_tokens: int | None = None
# different pruning behavior is expected when the
# user manually selects documents they want to chat with
# e.g. we don't want to truncate each document to be no more
# than one chunk long
is_manually_selected_docs: bool = False
# If user specifies to include additional context Chunks for each match, then different pruning
# is used. As many Sections as possible are included, and the last Section is truncated
# If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
# Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
use_sections: bool = True
# If using tools, then we need to consider the tool length
tool_num_tokens: int = 0
# If using a tool message to represent the docs, then we have to JSON serialize
# the document content, which adds to the token count.
using_tool_message: bool = False
class ContextualPruningConfig(DocumentPruningConfig):
num_chunk_multiple: int
@classmethod
def from_doc_pruning_config(
cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
) -> "ContextualPruningConfig":
return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())
class CitationConfig(BaseModel):
all_docs_useful: bool = False
class AnswerStyleConfig(BaseModel):
citation_config: CitationConfig
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
# right now, only used by the simple chat API
structured_response_format: dict | None = None
class PromptConfig(BaseModel):
"""Final representation of the Prompt configuration passed
into the `PromptBuilder` object."""
system_prompt: str
task_prompt: str
datetime_aware: bool
include_citations: bool
@classmethod
def from_model(
cls, model: "Prompt", prompt_override: PromptOverride | None = None
) -> "PromptConfig":
override_system_prompt = (
prompt_override.system_prompt if prompt_override else None
)
override_task_prompt = prompt_override.task_prompt if prompt_override else None
return cls(
system_prompt=override_system_prompt or model.system_prompt,
task_prompt=override_task_prompt or model.task_prompt,
datetime_aware=model.datetime_aware,
include_citations=model.include_citations,
)
model_config = ConfigDict(frozen=True)
class SubQueryPiece(SubQuestionIdentifier):
sub_query: str
query_id: int
class AgentAnswerPiece(SubQuestionIdentifier):
answer_piece: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class SubQuestionPiece(SubQuestionIdentifier):
sub_question: str
class ExtendedToolResponse(ToolResponse, SubQuestionIdentifier):
pass
class RefinedAnswerImprovement(BaseModel):
refined_answer_improvement: bool
AgentSearchPacket = (
SubQuestionPiece
| AgentAnswerPiece
| SubQueryPiece
| ExtendedToolResponse
| RefinedAnswerImprovement
)
AnswerPacket = (
AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse
)
ResponsePart = (
OnyxAnswerPiece
| CitationInfo
| ToolCallKickoff
| ToolResponse
| ToolCallFinalResult
| StreamStopInfo
| AgentSearchPacket
)
AnswerStream = Iterator[AnswerPacket]
class AnswerPostInfo(BaseModel):
ai_message_files: list[FileDescriptor]
qa_docs_response: QADocsResponse | None = None
reference_db_search_docs: list[DbSearchDoc] | None = None
dropped_indices: list[int] | None = None
tool_result: ToolCallFinalResult | None = None
message_specific_citations: MessageSpecificCitations | None = None
class Config:
arbitrary_types_allowed = True
class SubQuestionKey(BaseModel):
level: int
question_num: int
def __hash__(self) -> int:
return hash((self.level, self.question_num))
def __eq__(self, other: object) -> bool:
return isinstance(other, SubQuestionKey) and (
self.level,
self.question_num,
) == (other.level, other.question_num)