from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Literal
from typing import TYPE_CHECKING

from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field

from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.enums import SearchType
from onyx.context.search.models import RetrievalDocs
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import PromptOverride
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.custom.base_tool_types import ToolResultType

if TYPE_CHECKING:
    from onyx.db.models import Prompt


class LlmDoc(BaseModel):
    """This contains the minimal set of information for the LLM portion, including citations"""

    document_id: str
    content: str
    blurb: str
    semantic_identifier: str
    source_type: DocumentSource
    metadata: dict[str, str | list[str]]
    updated_at: datetime | None
    link: str | None
    source_links: dict[int, str] | None
    match_highlights: list[str] | None


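# Mixin providing the (level, level_question_num) pair that ties streamed packets
# to a specific agent sub-question; subclassed by several of the models below.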
class SubQuestionIdentifier(BaseModel):
    level: int | None = None
    level_question_num: int | None = None


# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
    rephrased_query: str | None = None
    predicted_flow: QueryFlow | None
    predicted_search: SearchType | None
    applied_source_filters: list[DocumentSource] | None
    applied_time_cutoff: datetime | None
    recency_bias_multiplier: float

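    # Ensure applied_time_cutoff serializes as an ISO-8601 string (JSON-safe).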
    def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
        initial_dict = super().model_dump(mode="json", *args, **kwargs)  # type: ignore
        initial_dict["applied_time_cutoff"] = (
            self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
        )

        return initial_dict


class StreamStopReason(Enum):
    CONTEXT_LENGTH = "context_length"
    CANCELLED = "cancelled"
    FINISHED = "finished"


class StreamType(Enum):
    SUB_QUESTIONS = "sub_questions"
    SUB_ANSWER = "sub_answer"
    MAIN_ANSWER = "main_answer"


class StreamStopInfo(SubQuestionIdentifier):
    stop_reason: StreamStopReason

    stream_type: StreamType = StreamType.MAIN_ANSWER

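    # Expose the stop reason by its enum name (e.g. "CONTEXT_LENGTH"), not its value.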
    def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
        data = super().model_dump(mode="json", *args, **kwargs)  # type: ignore
        data["stop_reason"] = self.stop_reason.name
        return data


class LLMRelevanceFilterResponse(BaseModel):
    llm_selected_doc_indices: list[int]


class FinalUsedContextDocsResponse(BaseModel):
    final_context_docs: list[LlmDoc]


class RelevanceAnalysis(BaseModel):
    relevant: bool
    content: str | None = None


class SectionRelevancePiece(RelevanceAnalysis):
    """LLM analysis mapped to an Inference Section"""

    document_id: str
    chunk_id: int  # ID of the center chunk for a given inference section


class DocumentRelevance(BaseModel):
    """Contains all relevance information for a given search"""

    relevance_summaries: dict[str, RelevanceAnalysis]


class OnyxAnswerPiece(BaseModel):
    # A small piece of a complete answer. Used for streaming back answers.
    answer_piece: str | None  # if None, specifies the end of an Answer


# An intermediate representation of citations, later translated into
# a mapping of the citation [n] number to SearchDoc
class CitationInfo(SubQuestionIdentifier):
    citation_num: int
    document_id: str


class AllCitations(BaseModel):
    citations: list[CitationInfo]


# This is a mapping of the citation number to the document index within
# the result search doc set
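# e.g. a citation_map of {1: 0, 2: 3} would point citation [1] at the first search
# doc and [2] at the fourth (illustrative values)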
class MessageSpecificCitations(BaseModel):
    citation_map: dict[int, int]


class MessageResponseIDInfo(BaseModel):
    user_message_id: int | None
    reserved_assistant_message_id: int


class AgentMessageIDInfo(BaseModel):
    level: int
    message_id: int


class AgenticMessageResponseIDInfo(BaseModel):
    agentic_message_ids: list[AgentMessageIDInfo]


class StreamingError(BaseModel):
    error: str
    stack_trace: str | None = None


class OnyxContext(BaseModel):
    content: str
    document_id: str
    semantic_identifier: str
    blurb: str


class OnyxContexts(BaseModel):
    contexts: list[OnyxContext]


class OnyxAnswer(BaseModel):
    answer: str | None


class ThreadMessage(BaseModel):
    message: str
    sender: str | None = None
    role: MessageType = MessageType.USER


class ChatOnyxBotResponse(BaseModel):
    answer: str | None = None
    citations: list[CitationInfo] | None = None
    docs: QADocsResponse | None = None
    llm_selected_doc_indices: list[int] | None = None
    error_msg: str | None = None
    chat_message_id: int | None = None
    answer_valid: bool = True  # Reflexion result, default True if Reflexion not run


class FileChatDisplay(BaseModel):
    file_ids: list[str]


class CustomToolResponse(BaseModel):
    response: ToolResultType
    tool_name: str


class ToolConfig(BaseModel):
    id: int


class PromptOverrideConfig(BaseModel):
    name: str
    description: str = ""
    system_prompt: str
    task_prompt: str = ""
    include_citations: bool = True
    datetime_aware: bool = True


class PersonaOverrideConfig(BaseModel):
    name: str
    description: str
    search_type: SearchType = SearchType.SEMANTIC
    num_chunks: float | None = None
    llm_relevance_filter: bool = False
    llm_filter_extraction: bool = False
    recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
    llm_model_provider_override: str | None = None
    llm_model_version_override: str | None = None

    prompts: list[PromptOverrideConfig] = Field(default_factory=list)
    prompt_ids: list[int] = Field(default_factory=list)

    document_set_ids: list[int] = Field(default_factory=list)
    tools: list[ToolConfig] = Field(default_factory=list)
    tool_ids: list[int] = Field(default_factory=list)
    custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)


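# Union of every packet type that can be yielded while streaming back an answer
# (wrapped in an Iterator by AnswerQuestionStreamReturn below).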
AnswerQuestionPossibleReturn = (
    OnyxAnswerPiece
    | CitationInfo
    | OnyxContexts
    | FileChatDisplay
    | CustomToolResponse
    | StreamingError
    | StreamStopInfo
)


AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]


class LLMMetricsContainer(BaseModel):
    prompt_tokens: int
    response_tokens: int


StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]


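# Illustrative sketch only (the function name is an example, not part of the real API):
# a minimal StreamProcessor that wraps raw LLM tokens into OnyxAnswerPiece packets and
# marks the end of the answer with a None piece, per the note on OnyxAnswerPiece above.
def _example_token_stream_processor(
    tokens: Iterator[str],
) -> AnswerQuestionStreamReturn:
    for token in tokens:
        yield OnyxAnswerPiece(answer_piece=token)
    yield OnyxAnswerPiece(answer_piece=None)  # None signals the end of the answer

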
class DocumentPruningConfig(BaseModel):
    max_chunks: int | None = None
    max_window_percentage: float | None = None
    max_tokens: int | None = None
    # Different pruning behavior is expected when the user manually selects
    # documents they want to chat with, e.g. we don't want to truncate each
    # document down to a single chunk.
    is_manually_selected_docs: bool = False
    # If the user specifies including additional context Chunks for each match, then
    # different pruning is used: as many Sections as possible are included, and only
    # the last Section is truncated.
    # If this is False, all Sections are truncated if they are longer than the expected
    # Chunk size. Sections are often expected to be longer than the maximum Chunk size,
    # but Chunks should not be.
    use_sections: bool = True
    # If using tools, then we need to consider the tool length
    tool_num_tokens: int = 0
    # If using a tool message to represent the docs, then we have to JSON serialize
    # the document content, which adds to the token count.
    using_tool_message: bool = False


class ContextualPruningConfig(DocumentPruningConfig):
    num_chunk_multiple: int

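    # Copy every pruning field from an existing DocumentPruningConfig and attach
    # the chunk multiple.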
    @classmethod
    def from_doc_pruning_config(
        cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
    ) -> "ContextualPruningConfig":
        return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())


class CitationConfig(BaseModel):
    all_docs_useful: bool = False


class AnswerStyleConfig(BaseModel):
    citation_config: CitationConfig
    document_pruning_config: DocumentPruningConfig = Field(
        default_factory=DocumentPruningConfig
    )
    # forces the LLM to return a structured response, see
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    # right now, only used by the simple chat API
    structured_response_format: dict | None = None


class PromptConfig(BaseModel):
    """Final representation of the Prompt configuration passed
    into the `PromptBuilder` object."""

    system_prompt: str
    task_prompt: str
    datetime_aware: bool
    include_citations: bool

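    # Apply per-request prompt overrides when provided; otherwise fall back to the
    # values stored on the Prompt model.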
    @classmethod
    def from_model(
        cls, model: "Prompt", prompt_override: PromptOverride | None = None
    ) -> "PromptConfig":
        override_system_prompt = (
            prompt_override.system_prompt if prompt_override else None
        )
        override_task_prompt = prompt_override.task_prompt if prompt_override else None

        return cls(
            system_prompt=override_system_prompt or model.system_prompt,
            task_prompt=override_task_prompt or model.task_prompt,
            datetime_aware=model.datetime_aware,
            include_citations=model.include_citations,
        )

    model_config = ConfigDict(frozen=True)


class SubQueryPiece(SubQuestionIdentifier):
    sub_query: str
    query_id: int


class AgentAnswerPiece(SubQuestionIdentifier):
    answer_piece: str
    answer_type: Literal["agent_sub_answer", "agent_level_answer"]


class SubQuestionPiece(SubQuestionIdentifier):
    sub_question: str


class ExtendedToolResponse(ToolResponse, SubQuestionIdentifier):
    pass


class RefinedAnswerImprovement(BaseModel):
    refined_answer_improvement: bool


AgentSearchPacket = (
    SubQuestionPiece
    | AgentAnswerPiece
    | SubQueryPiece
    | ExtendedToolResponse
    | RefinedAnswerImprovement
)

AnswerPacket = (
    AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse
)


ResponsePart = (
    OnyxAnswerPiece
    | CitationInfo
    | ToolCallKickoff
    | ToolResponse
    | ToolCallFinalResult
    | StreamStopInfo
    | AgentSearchPacket
)

AnswerStream = Iterator[AnswerPacket]


class AnswerPostInfo(BaseModel):
    ai_message_files: list[FileDescriptor]
    qa_docs_response: QADocsResponse | None = None
    reference_db_search_docs: list[DbSearchDoc] | None = None
    dropped_indices: list[int] | None = None
    tool_result: ToolCallFinalResult | None = None
    message_specific_citations: MessageSpecificCitations | None = None

    class Config:
        arbitrary_types_allowed = True


class SubQuestionKey(BaseModel):
    level: int
    question_num: int

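    # Hashable and comparable by (level, question_num), e.g. usable as a dict key.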
    def __hash__(self) -> int:
        return hash((self.level, self.question_num))

    def __eq__(self, other: object) -> bool:
        return isinstance(other, SubQuestionKey) and (
            self.level,
            self.question_num,
        ) == (other.level, other.question_num)