Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-03-27 10:13:05 +01:00)
1131 lines · 46 KiB · Python
import traceback
|
|
from collections import defaultdict
|
|
from collections.abc import Callable
|
|
from collections.abc import Iterator
|
|
from functools import partial
|
|
from typing import cast
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
|
|
from onyx.chat.answer import Answer
|
|
from onyx.chat.chat_utils import create_chat_chain
|
|
from onyx.chat.chat_utils import create_temporary_persona
|
|
from onyx.chat.models import AgenticMessageResponseIDInfo
|
|
from onyx.chat.models import AgentMessageIDInfo
|
|
from onyx.chat.models import AgentSearchPacket
|
|
from onyx.chat.models import AllCitations
|
|
from onyx.chat.models import AnswerPostInfo
|
|
from onyx.chat.models import AnswerStyleConfig
|
|
from onyx.chat.models import ChatOnyxBotResponse
|
|
from onyx.chat.models import CitationConfig
|
|
from onyx.chat.models import CitationInfo
|
|
from onyx.chat.models import CustomToolResponse
|
|
from onyx.chat.models import DocumentPruningConfig
|
|
from onyx.chat.models import ExtendedToolResponse
|
|
from onyx.chat.models import FileChatDisplay
|
|
from onyx.chat.models import FinalUsedContextDocsResponse
|
|
from onyx.chat.models import LLMRelevanceFilterResponse
|
|
from onyx.chat.models import MessageResponseIDInfo
|
|
from onyx.chat.models import MessageSpecificCitations
|
|
from onyx.chat.models import OnyxAnswerPiece
|
|
from onyx.chat.models import OnyxContexts
|
|
from onyx.chat.models import PromptConfig
|
|
from onyx.chat.models import QADocsResponse
|
|
from onyx.chat.models import RefinedAnswerImprovement
|
|
from onyx.chat.models import StreamingError
|
|
from onyx.chat.models import StreamStopInfo
|
|
from onyx.chat.models import StreamStopReason
|
|
from onyx.chat.models import SubQuestionKey
|
|
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
|
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
|
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
|
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
|
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
|
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
|
from onyx.configs.constants import AGENT_SEARCH_INITIAL_KEY
|
|
from onyx.configs.constants import BASIC_KEY
|
|
from onyx.configs.constants import MessageType
|
|
from onyx.configs.constants import MilestoneRecordType
|
|
from onyx.configs.constants import NO_AUTH_USER_ID
|
|
from onyx.context.search.enums import LLMEvaluationType
|
|
from onyx.context.search.enums import OptionalSearchSetting
|
|
from onyx.context.search.enums import QueryFlow
|
|
from onyx.context.search.enums import SearchType
|
|
from onyx.context.search.models import InferenceSection
|
|
from onyx.context.search.models import RetrievalDetails
|
|
from onyx.context.search.models import SearchRequest
|
|
from onyx.context.search.retrieval.search_runner import (
|
|
inference_sections_from_ids,
|
|
)
|
|
from onyx.context.search.utils import chunks_or_sections_to_search_docs
|
|
from onyx.context.search.utils import dedupe_documents
|
|
from onyx.context.search.utils import drop_llm_indices
|
|
from onyx.context.search.utils import relevant_sections_to_indices
|
|
from onyx.db.chat import attach_files_to_chat_message
|
|
from onyx.db.chat import create_db_search_doc
|
|
from onyx.db.chat import create_new_chat_message
|
|
from onyx.db.chat import get_chat_message
|
|
from onyx.db.chat import get_chat_session_by_id
|
|
from onyx.db.chat import get_db_search_doc_by_id
|
|
from onyx.db.chat import get_doc_query_identifiers_from_model
|
|
from onyx.db.chat import get_or_create_root_message
|
|
from onyx.db.chat import reserve_message_id
|
|
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
|
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
|
|
from onyx.db.engine import get_session_context_manager
|
|
from onyx.db.milestone import check_multi_assistant_milestone
|
|
from onyx.db.milestone import create_milestone_if_not_exists
|
|
from onyx.db.milestone import update_user_assistant_milestone
|
|
from onyx.db.models import SearchDoc as DbSearchDoc
|
|
from onyx.db.models import ToolCall
|
|
from onyx.db.models import User
|
|
from onyx.db.persona import get_persona_by_id
|
|
from onyx.db.search_settings import get_current_search_settings
|
|
from onyx.document_index.factory import get_default_document_index
|
|
from onyx.file_store.models import ChatFileType
|
|
from onyx.file_store.models import FileDescriptor
|
|
from onyx.file_store.utils import load_all_chat_files
|
|
from onyx.file_store.utils import save_files
|
|
from onyx.llm.exceptions import GenAIDisabledException
|
|
from onyx.llm.factory import get_llms_for_persona
|
|
from onyx.llm.factory import get_main_llm_from_tuple
|
|
from onyx.llm.interfaces import LLM
|
|
from onyx.llm.models import PreviousMessage
|
|
from onyx.llm.utils import litellm_exception_to_error_msg
|
|
from onyx.natural_language_processing.utils import get_tokenizer
|
|
from onyx.server.query_and_chat.models import ChatMessageDetail
|
|
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
|
from onyx.server.utils import get_json_line
|
|
from onyx.tools.force import ForceUseTool
|
|
from onyx.tools.models import ToolResponse
|
|
from onyx.tools.tool import Tool
|
|
from onyx.tools.tool_constructor import construct_tools
|
|
from onyx.tools.tool_constructor import CustomToolConfig
|
|
from onyx.tools.tool_constructor import ImageGenerationToolConfig
|
|
from onyx.tools.tool_constructor import InternetSearchToolConfig
|
|
from onyx.tools.tool_constructor import SearchToolConfig
|
|
from onyx.tools.tool_implementations.custom.custom_tool import (
|
|
CUSTOM_TOOL_RESPONSE_ID,
|
|
)
|
|
from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary
|
|
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
|
IMAGE_GENERATION_RESPONSE_ID,
|
|
)
|
|
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
|
ImageGenerationResponse,
|
|
)
|
|
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
|
INTERNET_SEARCH_RESPONSE_ID,
|
|
)
|
|
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
|
internet_search_response_to_search_docs,
|
|
)
|
|
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
|
InternetSearchResponse,
|
|
)
|
|
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
|
InternetSearchTool,
|
|
)
|
|
from onyx.tools.tool_implementations.search.search_tool import (
|
|
FINAL_CONTEXT_DOCUMENTS_ID,
|
|
)
|
|
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
|
|
from onyx.tools.tool_implementations.search.search_tool import (
|
|
SEARCH_RESPONSE_SUMMARY_ID,
|
|
)
|
|
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
|
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
|
from onyx.tools.tool_implementations.search.search_tool import (
|
|
SECTION_RELEVANCE_LIST_ID,
|
|
)
|
|
from onyx.tools.tool_runner import ToolCallFinalResult
|
|
from onyx.utils.logger import setup_logger
|
|
from onyx.utils.long_term_log import LongTermLogger
|
|
from onyx.utils.telemetry import mt_cloud_telemetry
|
|
from onyx.utils.timing import log_function_time
|
|
from onyx.utils.timing import log_generator_function_time
|
|
from shared_configs.contextvars import get_current_tenant_id
|
|
|
|
logger = setup_logger()
|
|
ERROR_TYPE_CANCELLED = "cancelled"
|
|
|
|
|
|
def _translate_citations(
|
|
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
|
|
) -> MessageSpecificCitations:
|
|
"""Always cites the first instance of the document_id, assumes the db_docs
|
|
are sorted in the order displayed in the UI"""
|
|
doc_id_to_saved_doc_id_map: dict[str, int] = {}
|
|
for db_doc in db_docs:
|
|
if db_doc.document_id not in doc_id_to_saved_doc_id_map:
|
|
doc_id_to_saved_doc_id_map[db_doc.document_id] = db_doc.id
|
|
|
|
citation_to_saved_doc_id_map: dict[int, int] = {}
|
|
for citation in citations_list:
|
|
if citation.citation_num not in citation_to_saved_doc_id_map:
|
|
citation_to_saved_doc_id_map[
|
|
citation.citation_num
|
|
] = doc_id_to_saved_doc_id_map[citation.document_id]
|
|
|
|
return MessageSpecificCitations(citation_map=citation_to_saved_doc_id_map)
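
# Illustrative sketch (not part of the original module): the two-step mapping
# _translate_citations performs, rebuilt with plain tuples instead of
# DbSearchDoc rows and CitationInfo models. The helper name and toy values
# below are assumptions used purely for illustration.
def _example_citation_mapping() -> dict[int, int]:
    # (document_id, saved db id) in UI display order; "a" appears twice, so
    # only its first saved id (11) should ever be cited
    db_docs = [("a", 11), ("b", 12), ("a", 13)]
    # (citation_num, document_id) as emitted during streaming
    citations = [(1, "a"), (2, "b"), (1, "a")]

    doc_id_to_saved_id: dict[str, int] = {}
    for document_id, saved_id in db_docs:
        # keep only the first saved id per document, mirroring the helper above
        doc_id_to_saved_id.setdefault(document_id, saved_id)

    citation_map: dict[int, int] = {}
    for citation_num, document_id in citations:
        citation_map.setdefault(citation_num, doc_id_to_saved_id[document_id])

    return citation_map  # -> {1: 11, 2: 12}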
|
|
|
|
|
|
def _handle_search_tool_response_summary(
    packet: ToolResponse,
    db_session: Session,
    selected_search_docs: list[DbSearchDoc] | None,
    dedupe_docs: bool = False,
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
    response_summary = cast(SearchResponseSummary, packet.response)

    is_extended = isinstance(packet, ExtendedToolResponse)
    dropped_inds = None
    if not selected_search_docs:
        top_docs = chunks_or_sections_to_search_docs(response_summary.top_sections)

        deduped_docs = top_docs
        if (
            dedupe_docs and not is_extended
        ):  # Extended tool responses are already deduped
            deduped_docs, dropped_inds = dedupe_documents(top_docs)

        reference_db_search_docs = [
            create_db_search_doc(server_search_doc=doc, db_session=db_session)
            for doc in deduped_docs
        ]
    else:
        reference_db_search_docs = selected_search_docs

    response_docs = [
        translate_db_search_doc_to_server_search_doc(db_search_doc)
        for db_search_doc in reference_db_search_docs
    ]

    level, question_num = None, None
    if isinstance(packet, ExtendedToolResponse):
        level, question_num = packet.level, packet.level_question_num
    return (
        QADocsResponse(
            rephrased_query=response_summary.rephrased_query,
            top_documents=response_docs,
            predicted_flow=response_summary.predicted_flow,
            predicted_search=response_summary.predicted_search,
            applied_source_filters=response_summary.final_filters.source_type,
            applied_time_cutoff=response_summary.final_filters.time_cutoff,
            recency_bias_multiplier=response_summary.recency_bias_multiplier,
            level=level,
            level_question_num=question_num,
        ),
        reference_db_search_docs,
        dropped_inds,
    )
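
# Illustrative sketch (not part of the original module): the shape of the dedupe
# step above. dedupe_documents is assumed here to keep the first occurrence of
# each document id and report the list positions it removed, which is why the
# handler also returns dropped_inds for later use alongside the LLM-selected
# indices. The helper below is a toy stand-in, not the real implementation.
def _example_dedupe_by_doc_id(doc_ids: list[str]) -> tuple[list[str], list[int]]:
    seen: set[str] = set()
    kept: list[str] = []
    dropped_inds: list[int] = []
    for idx, doc_id in enumerate(doc_ids):
        if doc_id in seen:
            dropped_inds.append(idx)
            continue
        seen.add(doc_id)
        kept.append(doc_id)
    return kept, dropped_inds


# _example_dedupe_by_doc_id(["a", "b", "a", "c"]) -> (["a", "b", "c"], [2])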
|
|
|
|
|
|
def _handle_internet_search_tool_response_summary(
|
|
packet: ToolResponse,
|
|
db_session: Session,
|
|
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
|
|
internet_search_response = cast(InternetSearchResponse, packet.response)
|
|
server_search_docs = internet_search_response_to_search_docs(
|
|
internet_search_response
|
|
)
|
|
|
|
reference_db_search_docs = [
|
|
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
|
for doc in server_search_docs
|
|
]
|
|
response_docs = [
|
|
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
|
for db_search_doc in reference_db_search_docs
|
|
]
|
|
return (
|
|
QADocsResponse(
|
|
rephrased_query=internet_search_response.revised_query,
|
|
top_documents=response_docs,
|
|
predicted_flow=QueryFlow.QUESTION_ANSWER,
|
|
predicted_search=SearchType.SEMANTIC,
|
|
applied_source_filters=[],
|
|
applied_time_cutoff=None,
|
|
recency_bias_multiplier=1.0,
|
|
),
|
|
reference_db_search_docs,
|
|
)
|
|
|
|
|
|
def _get_force_search_settings(
|
|
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
|
|
) -> ForceUseTool:
|
|
internet_search_available = any(
|
|
isinstance(tool, InternetSearchTool) for tool in tools
|
|
)
|
|
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
|
|
|
|
if not internet_search_available and not search_tool_available:
|
|
# Does not matter much which tool is set here as force is false and neither tool is available
|
|
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
|
|
|
|
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
|
|
# Currently, the internet search tool does not support query override
|
|
args = (
|
|
{"query": new_msg_req.query_override}
|
|
if new_msg_req.query_override and tool_name == SearchTool._NAME
|
|
else None
|
|
)
|
|
|
|
if new_msg_req.file_descriptors:
|
|
# If user has uploaded files they're using, don't run any of the search tools
|
|
return ForceUseTool(force_use=False, tool_name=tool_name)
|
|
|
|
should_force_search = any(
|
|
[
|
|
new_msg_req.retrieval_options
|
|
and new_msg_req.retrieval_options.run_search
|
|
== OptionalSearchSetting.ALWAYS,
|
|
new_msg_req.search_doc_ids,
|
|
new_msg_req.query_override is not None,
|
|
DISABLE_LLM_CHOOSE_SEARCH,
|
|
]
|
|
)
|
|
|
|
if should_force_search:
|
|
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
|
|
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
|
|
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
|
|
|
|
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
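
# Illustrative sketch (not part of the original module): the decision implemented
# by _get_force_search_settings, reduced to booleans. The parameter names are
# stand-ins for the request fields read above (retrieval_options.run_search,
# search_doc_ids, query_override, DISABLE_LLM_CHOOSE_SEARCH, file_descriptors).
def _example_should_force_search(
    run_search_always: bool,
    has_selected_docs: bool,
    has_query_override: bool,
    llm_choose_search_disabled: bool,
    has_uploaded_files: bool,
) -> bool:
    # uploaded files short-circuit everything: no search tool is forced
    if has_uploaded_files:
        return False
    return any(
        [
            run_search_always,
            has_selected_docs,
            has_query_override,
            llm_choose_search_disabled,
        ]
    )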
|
|
|
|
|
|
ChatPacket = (
|
|
StreamingError
|
|
| QADocsResponse
|
|
| OnyxContexts
|
|
| LLMRelevanceFilterResponse
|
|
| FinalUsedContextDocsResponse
|
|
| ChatMessageDetail
|
|
| OnyxAnswerPiece
|
|
| AllCitations
|
|
| CitationInfo
|
|
| FileChatDisplay
|
|
| CustomToolResponse
|
|
| MessageSpecificCitations
|
|
| MessageResponseIDInfo
|
|
| AgenticMessageResponseIDInfo
|
|
| StreamStopInfo
|
|
| AgentSearchPacket
|
|
)
|
|
ChatPacketStream = Iterator[ChatPacket]
|
|
|
|
|
|
def stream_chat_message_objects(
|
|
new_msg_req: CreateChatMessageRequest,
|
|
user: User | None,
|
|
db_session: Session,
|
|
# Needed to translate persona num_chunks to tokens to the LLM
|
|
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
|
    # For flow with search, don't include as many chunks as possible since we need to leave space
    # for the chat history; for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
    max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
    # if specified, uses the last user message and does not create a new user message based
    # on the `new_msg_req.message`. Currently, requires a state where the last message is a
    # user message.
|
|
litellm_additional_headers: dict[str, str] | None = None,
|
|
custom_tool_additional_headers: dict[str, str] | None = None,
|
|
is_connected: Callable[[], bool] | None = None,
|
|
enforce_chat_session_id_for_search_docs: bool = True,
|
|
bypass_acl: bool = False,
|
|
include_contexts: bool = False,
|
|
# a string which represents the history of a conversation. Used in cases like
|
|
# Slack threads where the conversation cannot be represented by a chain of User/Assistant
|
|
# messages.
|
|
# NOTE: is not stored in the database at all.
|
|
single_message_history: str | None = None,
|
|
) -> ChatPacketStream:
|
|
"""Streams in order:
|
|
1. [conditional] Retrieved documents if a search needs to be run
|
|
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
|
|
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
|
4. [always] Details on the final AI response message that is created
|
|
"""
|
|
tenant_id = get_current_tenant_id()
|
|
use_existing_user_message = new_msg_req.use_existing_user_message
|
|
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
|
|
|
|
# Currently surrounding context is not supported for chat
|
|
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
|
|
new_msg_req.chunks_above = 0
|
|
new_msg_req.chunks_below = 0
|
|
|
|
llm: LLM
|
|
|
|
try:
|
|
user_id = user.id if user is not None else None
|
|
|
|
chat_session = get_chat_session_by_id(
|
|
chat_session_id=new_msg_req.chat_session_id,
|
|
user_id=user_id,
|
|
db_session=db_session,
|
|
)
|
|
|
|
message_text = new_msg_req.message
|
|
chat_session_id = new_msg_req.chat_session_id
|
|
parent_id = new_msg_req.parent_message_id
|
|
reference_doc_ids = new_msg_req.search_doc_ids
|
|
retrieval_options = new_msg_req.retrieval_options
|
|
alternate_assistant_id = new_msg_req.alternate_assistant_id
|
|
|
|
# permanent "log" store, used primarily for debugging
|
|
long_term_logger = LongTermLogger(
|
|
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
|
|
)
|
|
|
|
if alternate_assistant_id is not None:
|
|
# Allows users to specify a temporary persona (assistant) in the chat session
|
|
# this takes highest priority since it's user specified
|
|
persona = get_persona_by_id(
|
|
alternate_assistant_id,
|
|
user=user,
|
|
db_session=db_session,
|
|
is_for_edit=False,
|
|
)
|
|
elif new_msg_req.persona_override_config:
|
|
# Certain endpoints allow users to specify arbitrary persona settings
|
|
# this should never conflict with the alternate_assistant_id
|
|
            persona = create_temporary_persona(
|
|
db_session=db_session,
|
|
persona_config=new_msg_req.persona_override_config,
|
|
user=user,
|
|
)
|
|
else:
|
|
persona = chat_session.persona
|
|
|
|
if not persona:
|
|
raise RuntimeError("No persona specified or found for chat session")
|
|
|
|
multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
|
|
user=user,
|
|
event_type=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
|
db_session=db_session,
|
|
)
|
|
|
|
update_user_assistant_milestone(
|
|
milestone=multi_assistant_milestone,
|
|
user_id=str(user.id) if user else NO_AUTH_USER_ID,
|
|
assistant_id=persona.id,
|
|
db_session=db_session,
|
|
)
|
|
|
|
_, just_hit_multi_assistant_milestone = check_multi_assistant_milestone(
|
|
milestone=multi_assistant_milestone,
|
|
db_session=db_session,
|
|
)
|
|
|
|
if just_hit_multi_assistant_milestone:
|
|
mt_cloud_telemetry(
|
|
distinct_id=tenant_id,
|
|
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
|
properties=None,
|
|
)
|
|
|
|
# If a prompt override is specified via the API, use that with highest priority
|
|
# but for saving it, we are just mapping it to an existing prompt
|
|
prompt_id = new_msg_req.prompt_id
|
|
if prompt_id is None and persona.prompts:
|
|
prompt_id = sorted(persona.prompts, key=lambda x: x.id)[-1].id
|
|
|
|
if reference_doc_ids is None and retrieval_options is None:
|
|
raise RuntimeError(
|
|
"Must specify a set of documents for chat or specify search options"
|
|
)
|
|
|
|
try:
|
|
llm, fast_llm = get_llms_for_persona(
|
|
persona=persona,
|
|
llm_override=new_msg_req.llm_override or chat_session.llm_override,
|
|
additional_headers=litellm_additional_headers,
|
|
long_term_logger=long_term_logger,
|
|
)
|
|
except GenAIDisabledException:
|
|
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
|
|
|
|
llm_provider = llm.config.model_provider
|
|
llm_model_name = llm.config.model_name
|
|
|
|
llm_tokenizer = get_tokenizer(
|
|
model_name=llm_model_name,
|
|
provider_type=llm_provider,
|
|
)
|
|
llm_tokenizer_encode_func = cast(
|
|
Callable[[str], list[int]], llm_tokenizer.encode
|
|
)
|
|
|
|
search_settings = get_current_search_settings(db_session)
|
|
document_index = get_default_document_index(search_settings, None)
|
|
|
|
# Every chat Session begins with an empty root message
|
|
root_message = get_or_create_root_message(
|
|
chat_session_id=chat_session_id, db_session=db_session
|
|
)
|
|
|
|
if parent_id is not None:
|
|
parent_message = get_chat_message(
|
|
chat_message_id=parent_id,
|
|
user_id=user_id,
|
|
db_session=db_session,
|
|
)
|
|
else:
|
|
parent_message = root_message
|
|
|
|
user_message = None
|
|
|
|
if new_msg_req.regenerate:
|
|
final_msg, history_msgs = create_chat_chain(
|
|
stop_at_message_id=parent_id,
|
|
chat_session_id=chat_session_id,
|
|
db_session=db_session,
|
|
)
|
|
|
|
elif not use_existing_user_message:
|
|
# Create new message at the right place in the tree and update the parent's child pointer
|
|
# Don't commit yet until we verify the chat message chain
|
|
user_message = create_new_chat_message(
|
|
chat_session_id=chat_session_id,
|
|
parent_message=parent_message,
|
|
prompt_id=prompt_id,
|
|
message=message_text,
|
|
token_count=len(llm_tokenizer_encode_func(message_text)),
|
|
message_type=MessageType.USER,
|
|
files=None, # Need to attach later for optimization to only load files once in parallel
|
|
db_session=db_session,
|
|
commit=False,
|
|
)
|
|
# re-create linear history of messages
|
|
final_msg, history_msgs = create_chat_chain(
|
|
chat_session_id=chat_session_id, db_session=db_session
|
|
)
|
|
if final_msg.id != user_message.id:
|
|
db_session.rollback()
|
|
raise RuntimeError(
|
|
"The new message was not on the mainline. "
|
|
"Be sure to update the chat pointers before calling this."
|
|
)
|
|
|
|
# NOTE: do not commit user message - it will be committed when the
|
|
# assistant message is successfully generated
|
|
else:
|
|
# re-create linear history of messages
|
|
final_msg, history_msgs = create_chat_chain(
|
|
chat_session_id=chat_session_id, db_session=db_session
|
|
)
|
|
if existing_assistant_message_id is None:
|
|
if final_msg.message_type != MessageType.USER:
|
|
raise RuntimeError(
|
|
"The last message was not a user message. Cannot call "
|
|
"`stream_chat_message_objects` with `is_regenerate=True` "
|
|
"when the last message is not a user message."
|
|
)
|
|
else:
|
|
if final_msg.id != existing_assistant_message_id:
|
|
raise RuntimeError(
|
|
"The last message was not the existing assistant message. "
|
|
f"Final message id: {final_msg.id}, "
|
|
f"existing assistant message id: {existing_assistant_message_id}"
|
|
)
|
|
|
|
# load all files needed for this chat chain in memory
|
|
files = load_all_chat_files(
|
|
history_msgs, new_msg_req.file_descriptors, db_session
|
|
)
|
|
req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
|
|
latest_query_files = [file for file in files if file.file_id in req_file_ids]
|
|
|
|
if user_message:
|
|
attach_files_to_chat_message(
|
|
chat_message=user_message,
|
|
files=[
|
|
new_file.to_file_descriptor() for new_file in latest_query_files
|
|
],
|
|
db_session=db_session,
|
|
commit=False,
|
|
)
|
|
|
|
selected_db_search_docs = None
|
|
selected_sections: list[InferenceSection] | None = None
|
|
if reference_doc_ids:
|
|
identifier_tuples = get_doc_query_identifiers_from_model(
|
|
search_doc_ids=reference_doc_ids,
|
|
chat_session=chat_session,
|
|
user_id=user_id,
|
|
db_session=db_session,
|
|
enforce_chat_session_id_for_search_docs=enforce_chat_session_id_for_search_docs,
|
|
)
|
|
|
|
# Generates full documents currently
|
|
# May extend to use sections instead in the future
|
|
selected_sections = inference_sections_from_ids(
|
|
doc_identifiers=identifier_tuples,
|
|
document_index=document_index,
|
|
)
|
|
document_pruning_config = DocumentPruningConfig(
|
|
is_manually_selected_docs=True
|
|
)
|
|
|
|
# In case the search doc is deleted, just don't include it
|
|
# though this should never happen
|
|
db_search_docs_or_none = [
|
|
get_db_search_doc_by_id(doc_id=doc_id, db_session=db_session)
|
|
for doc_id in reference_doc_ids
|
|
]
|
|
|
|
selected_db_search_docs = [
|
|
db_sd for db_sd in db_search_docs_or_none if db_sd
|
|
]
|
|
|
|
else:
|
|
document_pruning_config = DocumentPruningConfig(
|
|
max_chunks=int(
|
|
persona.num_chunks
|
|
if persona.num_chunks is not None
|
|
else default_num_chunks
|
|
),
|
|
max_window_percentage=max_document_percentage,
|
|
)
|
|
|
|
# we don't need to reserve a message id if we're using an existing assistant message
|
|
reserved_message_id = (
|
|
final_msg.id
|
|
if existing_assistant_message_id is not None
|
|
else reserve_message_id(
|
|
db_session=db_session,
|
|
chat_session_id=chat_session_id,
|
|
parent_message=user_message.id
|
|
if user_message is not None
|
|
else parent_message.id,
|
|
message_type=MessageType.ASSISTANT,
|
|
)
|
|
)
|
|
yield MessageResponseIDInfo(
|
|
user_message_id=user_message.id if user_message else None,
|
|
reserved_assistant_message_id=reserved_message_id,
|
|
)
|
|
|
|
overridden_model = (
|
|
new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
|
|
)
|
|
|
|
# Cannot determine these without the LLM step or breaking out early
|
|
partial_response = partial(
|
|
create_new_chat_message,
|
|
chat_session_id=chat_session_id,
|
|
# if we're using an existing assistant message, then this will just be an
|
|
# update operation, in which case the parent should be the parent of
|
|
# the latest. If we're creating a new assistant message, then the parent
|
|
# should be the latest message (latest user message)
|
|
parent_message=(
|
|
final_msg if existing_assistant_message_id is None else parent_message
|
|
),
|
|
prompt_id=prompt_id,
|
|
overridden_model=overridden_model,
|
|
# message=,
|
|
# rephrased_query=,
|
|
# token_count=,
|
|
message_type=MessageType.ASSISTANT,
|
|
alternate_assistant_id=new_msg_req.alternate_assistant_id,
|
|
# error=,
|
|
# reference_docs=,
|
|
db_session=db_session,
|
|
commit=False,
|
|
reserved_message_id=reserved_message_id,
|
|
is_agentic=new_msg_req.use_agentic_search,
|
|
)
|
|
|
|
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
|
|
if new_msg_req.persona_override_config:
|
|
prompt_config = PromptConfig(
|
|
system_prompt=new_msg_req.persona_override_config.prompts[
|
|
0
|
|
].system_prompt,
|
|
task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
|
|
datetime_aware=new_msg_req.persona_override_config.prompts[
|
|
0
|
|
].datetime_aware,
|
|
include_citations=new_msg_req.persona_override_config.prompts[
|
|
0
|
|
].include_citations,
|
|
)
|
|
elif prompt_override:
|
|
if not final_msg.prompt:
|
|
raise ValueError(
|
|
"Prompt override cannot be applied, no base prompt found."
|
|
)
|
|
prompt_config = PromptConfig.from_model(
|
|
final_msg.prompt,
|
|
prompt_override=prompt_override,
|
|
)
|
|
elif final_msg.prompt:
|
|
prompt_config = PromptConfig.from_model(final_msg.prompt)
|
|
else:
|
|
prompt_config = PromptConfig.from_model(persona.prompts[0])
|
|
|
|
answer_style_config = AnswerStyleConfig(
|
|
citation_config=CitationConfig(
|
|
all_docs_useful=selected_db_search_docs is not None
|
|
),
|
|
document_pruning_config=document_pruning_config,
|
|
structured_response_format=new_msg_req.structured_response_format,
|
|
)
|
|
|
|
tool_dict = construct_tools(
|
|
persona=persona,
|
|
prompt_config=prompt_config,
|
|
db_session=db_session,
|
|
user=user,
|
|
llm=llm,
|
|
fast_llm=fast_llm,
|
|
search_tool_config=SearchToolConfig(
|
|
answer_style_config=answer_style_config,
|
|
document_pruning_config=document_pruning_config,
|
|
retrieval_options=retrieval_options or RetrievalDetails(),
|
|
rerank_settings=new_msg_req.rerank_settings,
|
|
selected_sections=selected_sections,
|
|
chunks_above=new_msg_req.chunks_above,
|
|
chunks_below=new_msg_req.chunks_below,
|
|
full_doc=new_msg_req.full_doc,
|
|
latest_query_files=latest_query_files,
|
|
bypass_acl=bypass_acl,
|
|
),
|
|
internet_search_tool_config=InternetSearchToolConfig(
|
|
answer_style_config=answer_style_config,
|
|
),
|
|
image_generation_tool_config=ImageGenerationToolConfig(
|
|
additional_headers=litellm_additional_headers,
|
|
),
|
|
custom_tool_config=CustomToolConfig(
|
|
chat_session_id=chat_session_id,
|
|
message_id=user_message.id if user_message else None,
|
|
additional_headers=custom_tool_additional_headers,
|
|
),
|
|
)
|
|
|
|
tools: list[Tool] = []
|
|
for tool_list in tool_dict.values():
|
|
tools.extend(tool_list)
|
|
|
|
# TODO: unify message history with single message history
|
|
message_history = [
|
|
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
|
|
]
|
|
|
|
search_request = SearchRequest(
|
|
query=final_msg.message,
|
|
evaluation_type=(
|
|
LLMEvaluationType.BASIC
|
|
if persona.llm_relevance_filter
|
|
else LLMEvaluationType.SKIP
|
|
),
|
|
human_selected_filters=(
|
|
retrieval_options.filters if retrieval_options else None
|
|
),
|
|
persona=persona,
|
|
offset=(retrieval_options.offset if retrieval_options else None),
|
|
limit=retrieval_options.limit if retrieval_options else None,
|
|
rerank_settings=new_msg_req.rerank_settings,
|
|
chunks_above=new_msg_req.chunks_above,
|
|
chunks_below=new_msg_req.chunks_below,
|
|
full_doc=new_msg_req.full_doc,
|
|
enable_auto_detect_filters=(
|
|
retrieval_options.enable_auto_detect_filters
|
|
if retrieval_options
|
|
else None
|
|
),
|
|
)
|
|
|
|
force_use_tool = _get_force_search_settings(new_msg_req, tools)
|
|
prompt_builder = AnswerPromptBuilder(
|
|
user_message=default_build_user_message(
|
|
user_query=final_msg.message,
|
|
prompt_config=prompt_config,
|
|
files=latest_query_files,
|
|
single_message_history=single_message_history,
|
|
),
|
|
system_message=default_build_system_message(prompt_config, llm.config),
|
|
message_history=message_history,
|
|
llm_config=llm.config,
|
|
raw_user_query=final_msg.message,
|
|
raw_user_uploaded_files=latest_query_files or [],
|
|
single_message_history=single_message_history,
|
|
)
|
|
|
|
# LLM prompt building, response capturing, etc.
|
|
|
|
answer = Answer(
|
|
prompt_builder=prompt_builder,
|
|
is_connected=is_connected,
|
|
latest_query_files=latest_query_files,
|
|
answer_style_config=answer_style_config,
|
|
llm=(
|
|
llm
|
|
or get_main_llm_from_tuple(
|
|
get_llms_for_persona(
|
|
persona=persona,
|
|
llm_override=(
|
|
new_msg_req.llm_override or chat_session.llm_override
|
|
),
|
|
additional_headers=litellm_additional_headers,
|
|
)
|
|
)
|
|
),
|
|
fast_llm=fast_llm,
|
|
force_use_tool=force_use_tool,
|
|
search_request=search_request,
|
|
chat_session_id=chat_session_id,
|
|
current_agent_message_id=reserved_message_id,
|
|
tools=tools,
|
|
db_session=db_session,
|
|
use_agentic_search=new_msg_req.use_agentic_search,
|
|
)
|
|
|
|
# reference_db_search_docs = None
|
|
# qa_docs_response = None
|
|
# # any files to associate with the AI message e.g. dall-e generated images
|
|
# ai_message_files = []
|
|
# dropped_indices = None
|
|
# tool_result = None
|
|
|
|
# TODO: different channels for stored info when it's coming from the agent flow
|
|
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
|
|
lambda: AnswerPostInfo(ai_message_files=[])
|
|
)
|
|
refined_answer_improvement = True
|
|
for packet in answer.processed_streamed_output:
|
|
if isinstance(packet, ToolResponse):
|
|
level, level_question_num = (
|
|
(packet.level, packet.level_question_num)
|
|
if isinstance(packet, ExtendedToolResponse)
|
|
else BASIC_KEY
|
|
)
|
|
assert level is not None
|
|
assert level_question_num is not None
|
|
info = info_by_subq[
|
|
SubQuestionKey(level=level, question_num=level_question_num)
|
|
]
|
|
# TODO: don't need to dedupe here when we do it in agent flow
|
|
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
|
(
|
|
info.qa_docs_response,
|
|
info.reference_db_search_docs,
|
|
info.dropped_indices,
|
|
) = _handle_search_tool_response_summary(
|
|
packet=packet,
|
|
db_session=db_session,
|
|
selected_search_docs=selected_db_search_docs,
|
|
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
|
dedupe_docs=(
|
|
retrieval_options.dedupe_docs
|
|
if retrieval_options
|
|
else False
|
|
),
|
|
)
|
|
yield info.qa_docs_response
|
|
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
|
relevance_sections = packet.response
|
|
|
|
if info.reference_db_search_docs is None:
|
|
logger.warning(
|
|
"No reference docs found for relevance filtering"
|
|
)
|
|
continue
|
|
|
|
llm_indices = relevant_sections_to_indices(
|
|
relevance_sections=relevance_sections,
|
|
items=[
|
|
translate_db_search_doc_to_server_search_doc(doc)
|
|
for doc in info.reference_db_search_docs
|
|
],
|
|
)
|
|
|
|
if info.dropped_indices:
|
|
llm_indices = drop_llm_indices(
|
|
llm_indices=llm_indices,
|
|
search_docs=info.reference_db_search_docs,
|
|
dropped_indices=info.dropped_indices,
|
|
)
|
|
|
|
yield LLMRelevanceFilterResponse(
|
|
llm_selected_doc_indices=llm_indices
|
|
)
|
|
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
|
yield FinalUsedContextDocsResponse(
|
|
final_context_docs=packet.response
|
|
)
|
|
|
|
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
|
img_generation_response = cast(
|
|
list[ImageGenerationResponse], packet.response
|
|
)
|
|
|
|
file_ids = save_files(
|
|
urls=[img.url for img in img_generation_response if img.url],
|
|
base64_files=[
|
|
img.image_data
|
|
for img in img_generation_response
|
|
if img.image_data
|
|
],
|
|
)
|
|
info.ai_message_files.extend(
|
|
[
|
|
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
|
for file_id in file_ids
|
|
]
|
|
)
|
|
yield FileChatDisplay(
|
|
file_ids=[str(file_id) for file_id in file_ids]
|
|
)
|
|
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
|
(
|
|
info.qa_docs_response,
|
|
info.reference_db_search_docs,
|
|
) = _handle_internet_search_tool_response_summary(
|
|
packet=packet,
|
|
db_session=db_session,
|
|
)
|
|
yield info.qa_docs_response
|
|
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
|
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
|
|
|
if (
|
|
custom_tool_response.response_type == "image"
|
|
or custom_tool_response.response_type == "csv"
|
|
):
|
|
file_ids = custom_tool_response.tool_result.file_ids
|
|
info.ai_message_files.extend(
|
|
[
|
|
FileDescriptor(
|
|
id=str(file_id),
|
|
type=(
|
|
ChatFileType.IMAGE
|
|
if custom_tool_response.response_type == "image"
|
|
else ChatFileType.CSV
|
|
),
|
|
)
|
|
for file_id in file_ids
|
|
]
|
|
)
|
|
yield FileChatDisplay(
|
|
file_ids=[str(file_id) for file_id in file_ids]
|
|
)
|
|
else:
|
|
yield CustomToolResponse(
|
|
response=custom_tool_response.tool_result,
|
|
tool_name=custom_tool_response.tool_name,
|
|
)
|
|
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
|
|
yield cast(OnyxContexts, packet.response)
|
|
|
|
elif isinstance(packet, StreamStopInfo):
|
|
if packet.stop_reason == StreamStopReason.FINISHED:
|
|
yield packet
|
|
elif isinstance(packet, RefinedAnswerImprovement):
|
|
refined_answer_improvement = packet.refined_answer_improvement
|
|
yield packet
|
|
else:
|
|
if isinstance(packet, ToolCallFinalResult):
|
|
level, level_question_num = (
|
|
(packet.level, packet.level_question_num)
|
|
if packet.level is not None
|
|
and packet.level_question_num is not None
|
|
else BASIC_KEY
|
|
)
|
|
info = info_by_subq[
|
|
SubQuestionKey(level=level, question_num=level_question_num)
|
|
]
|
|
info.tool_result = packet
|
|
yield cast(ChatPacket, packet)
|
|
logger.debug("Reached end of stream")
|
|
except ValueError as e:
|
|
logger.exception("Failed to process chat message.")
|
|
|
|
error_msg = str(e)
|
|
yield StreamingError(error=error_msg)
|
|
db_session.rollback()
|
|
return
|
|
|
|
except Exception as e:
|
|
logger.exception(f"Failed to process chat message due to {e}")
|
|
error_msg = str(e)
|
|
stack_trace = traceback.format_exc()
|
|
|
|
if isinstance(e, ToolCallException):
|
|
yield StreamingError(error=error_msg, stack_trace=stack_trace)
|
|
        else:
            # Default to the raw error message; replace it with a client-friendly
            # one when an LLM config is available to interpret the exception
            client_error_msg = error_msg
            if llm:
                client_error_msg = litellm_exception_to_error_msg(e, llm)
                if llm.config.api_key and len(llm.config.api_key) > 2:
                    error_msg = error_msg.replace(
                        llm.config.api_key, "[REDACTED_API_KEY]"
                    )
                    stack_trace = stack_trace.replace(
                        llm.config.api_key, "[REDACTED_API_KEY]"
                    )

            yield StreamingError(error=client_error_msg, stack_trace=stack_trace)
|
|
|
|
db_session.rollback()
|
|
return
|
|
|
|
# Post-LLM answer processing
|
|
try:
|
|
tool_name_to_tool_id: dict[str, int] = {}
|
|
for tool_id, tool_list in tool_dict.items():
|
|
for tool in tool_list:
|
|
tool_name_to_tool_id[tool.name] = tool_id
|
|
|
|
subq_citations = answer.citations_by_subquestion()
|
|
for subq_key in subq_citations:
|
|
info = info_by_subq[subq_key]
|
|
logger.debug("Post-LLM answer processing")
|
|
if info.reference_db_search_docs:
|
|
info.message_specific_citations = _translate_citations(
|
|
citations_list=subq_citations[subq_key],
|
|
db_docs=info.reference_db_search_docs,
|
|
)
|
|
|
|
# TODO: AllCitations should contain subq info?
|
|
if not answer.is_cancelled():
|
|
yield AllCitations(citations=subq_citations[subq_key])
|
|
|
|
# Saving Gen AI answer and responding with message info
|
|
|
|
basic_key = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
|
|
info = (
|
|
info_by_subq[basic_key]
|
|
if basic_key in info_by_subq
|
|
else info_by_subq[
|
|
SubQuestionKey(
|
|
level=AGENT_SEARCH_INITIAL_KEY[0],
|
|
question_num=AGENT_SEARCH_INITIAL_KEY[1],
|
|
)
|
|
]
|
|
)
|
|
gen_ai_response_message = partial_response(
|
|
message=answer.llm_answer,
|
|
rephrased_query=(
|
|
info.qa_docs_response.rephrased_query if info.qa_docs_response else None
|
|
),
|
|
reference_docs=info.reference_db_search_docs,
|
|
files=info.ai_message_files,
|
|
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
|
citations=(
|
|
info.message_specific_citations.citation_map
|
|
if info.message_specific_citations
|
|
else None
|
|
),
|
|
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
|
|
tool_call=(
|
|
ToolCall(
|
|
tool_id=tool_name_to_tool_id[info.tool_result.tool_name],
|
|
tool_name=info.tool_result.tool_name,
|
|
tool_arguments=info.tool_result.tool_args,
|
|
tool_result=info.tool_result.tool_result,
|
|
)
|
|
if info.tool_result
|
|
else None
|
|
),
|
|
)
|
|
|
|
# add answers for levels >= 1, where each level has the previous as its parent. Use
|
|
# the answer_by_level method in answer.py to get the answers for each level
|
|
next_level = 1
|
|
prev_message = gen_ai_response_message
|
|
agent_answers = answer.llm_answer_by_level()
|
|
agentic_message_ids = []
|
|
while next_level in agent_answers:
|
|
next_answer = agent_answers[next_level]
|
|
info = info_by_subq[
|
|
SubQuestionKey(
|
|
level=next_level, question_num=AGENT_SEARCH_INITIAL_KEY[1]
|
|
)
|
|
]
|
|
next_answer_message = create_new_chat_message(
|
|
chat_session_id=chat_session_id,
|
|
parent_message=prev_message,
|
|
message=next_answer,
|
|
prompt_id=None,
|
|
token_count=len(llm_tokenizer_encode_func(next_answer)),
|
|
message_type=MessageType.ASSISTANT,
|
|
db_session=db_session,
|
|
files=info.ai_message_files,
|
|
reference_docs=info.reference_db_search_docs,
|
|
citations=info.message_specific_citations.citation_map
|
|
if info.message_specific_citations
|
|
else None,
|
|
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
|
|
refined_answer_improvement=refined_answer_improvement,
|
|
is_agentic=True,
|
|
)
|
|
agentic_message_ids.append(
|
|
AgentMessageIDInfo(level=next_level, message_id=next_answer_message.id)
|
|
)
|
|
next_level += 1
|
|
prev_message = next_answer_message
|
|
|
|
logger.debug("Committing messages")
|
|
db_session.commit() # actually save user / assistant message
|
|
|
|
yield AgenticMessageResponseIDInfo(agentic_message_ids=agentic_message_ids)
|
|
|
|
yield translate_db_message_to_chat_message_detail(gen_ai_response_message)
|
|
except Exception as e:
|
|
error_msg = str(e)
|
|
logger.exception(error_msg)
|
|
|
|
        # Frontend will erase whatever answer was streamed so far and show this instead
|
|
yield StreamingError(error="Failed to parse LLM output")
|
|
|
|
|
|
@log_generator_function_time()
|
|
def stream_chat_message(
|
|
new_msg_req: CreateChatMessageRequest,
|
|
user: User | None,
|
|
litellm_additional_headers: dict[str, str] | None = None,
|
|
custom_tool_additional_headers: dict[str, str] | None = None,
|
|
is_connected: Callable[[], bool] | None = None,
|
|
) -> Iterator[str]:
|
|
with get_session_context_manager() as db_session:
|
|
objects = stream_chat_message_objects(
|
|
new_msg_req=new_msg_req,
|
|
user=user,
|
|
db_session=db_session,
|
|
litellm_additional_headers=litellm_additional_headers,
|
|
custom_tool_additional_headers=custom_tool_additional_headers,
|
|
is_connected=is_connected,
|
|
)
|
|
for obj in objects:
|
|
yield get_json_line(obj.model_dump())
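
# Illustrative sketch (not part of the original module): how a caller might
# consume the stream produced above. It assumes each yielded string from
# get_json_line is a single JSON document on its own line; the helper name and
# return shape are illustrative only.
def _example_consume_stream(lines: Iterator[str]) -> list[dict]:
    import json

    packets: list[dict] = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        packets.append(json.loads(line))
    return packets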
|
|
|
|
|
|
@log_function_time()
|
|
def gather_stream_for_slack(
|
|
packets: ChatPacketStream,
|
|
) -> ChatOnyxBotResponse:
|
|
response = ChatOnyxBotResponse()
|
|
|
|
answer = ""
|
|
for packet in packets:
|
|
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
|
|
answer += packet.answer_piece
|
|
elif isinstance(packet, QADocsResponse):
|
|
response.docs = packet
|
|
elif isinstance(packet, StreamingError):
|
|
response.error_msg = packet.error
|
|
elif isinstance(packet, ChatMessageDetail):
|
|
response.chat_message_id = packet.message_id
|
|
elif isinstance(packet, LLMRelevanceFilterResponse):
|
|
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
|
elif isinstance(packet, AllCitations):
|
|
response.citations = packet.citations
|
|
|
|
if answer:
|
|
response.answer = answer
|
|
|
|
return response
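
# Illustrative sketch (not part of the original module): feeding a tiny,
# hand-built packet stream through gather_stream_for_slack. It assumes
# OnyxAnswerPiece can be constructed from `answer_piece` alone (the only field
# this module reads from it); everything else about the models is untouched.
def _example_gather_for_slack() -> ChatOnyxBotResponse:
    packets: list[ChatPacket] = [
        OnyxAnswerPiece(answer_piece="Hello"),
        OnyxAnswerPiece(answer_piece=", world"),
    ]
    # answer pieces are concatenated into response.answer
    return gather_stream_for_slack(iter(packets))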
|