import traceback from collections import defaultdict from collections.abc import Callable from collections.abc import Iterator from functools import partial from typing import cast from sqlalchemy.orm import Session from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException from onyx.chat.answer import Answer from onyx.chat.chat_utils import create_chat_chain from onyx.chat.chat_utils import create_temporary_persona from onyx.chat.models import AgenticMessageResponseIDInfo from onyx.chat.models import AgentMessageIDInfo from onyx.chat.models import AgentSearchPacket from onyx.chat.models import AllCitations from onyx.chat.models import AnswerPostInfo from onyx.chat.models import AnswerStyleConfig from onyx.chat.models import ChatOnyxBotResponse from onyx.chat.models import CitationConfig from onyx.chat.models import CitationInfo from onyx.chat.models import CustomToolResponse from onyx.chat.models import DocumentPruningConfig from onyx.chat.models import ExtendedToolResponse from onyx.chat.models import FileChatDisplay from onyx.chat.models import FinalUsedContextDocsResponse from onyx.chat.models import LLMRelevanceFilterResponse from onyx.chat.models import MessageResponseIDInfo from onyx.chat.models import MessageSpecificCitations from onyx.chat.models import OnyxAnswerPiece from onyx.chat.models import OnyxContexts from onyx.chat.models import PromptConfig from onyx.chat.models import QADocsResponse from onyx.chat.models import RefinedAnswerImprovement from onyx.chat.models import StreamingError from onyx.chat.models import StreamStopInfo from onyx.chat.models import StreamStopReason from onyx.chat.models import SubQuestionKey from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message from onyx.configs.chat_configs import 
CHAT_TARGET_CHUNK_PERCENTAGE from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT from onyx.configs.constants import AGENT_SEARCH_INITIAL_KEY from onyx.configs.constants import BASIC_KEY from onyx.configs.constants import MessageType from onyx.configs.constants import MilestoneRecordType from onyx.configs.constants import NO_AUTH_USER_ID from onyx.context.search.enums import LLMEvaluationType from onyx.context.search.enums import OptionalSearchSetting from onyx.context.search.enums import QueryFlow from onyx.context.search.enums import SearchType from onyx.context.search.models import InferenceSection from onyx.context.search.models import RetrievalDetails from onyx.context.search.models import SearchRequest from onyx.context.search.retrieval.search_runner import ( inference_sections_from_ids, ) from onyx.context.search.utils import chunks_or_sections_to_search_docs from onyx.context.search.utils import dedupe_documents from onyx.context.search.utils import drop_llm_indices from onyx.context.search.utils import relevant_sections_to_indices from onyx.db.chat import attach_files_to_chat_message from onyx.db.chat import create_db_search_doc from onyx.db.chat import create_new_chat_message from onyx.db.chat import get_chat_message from onyx.db.chat import get_chat_session_by_id from onyx.db.chat import get_db_search_doc_by_id from onyx.db.chat import get_doc_query_identifiers_from_model from onyx.db.chat import get_or_create_root_message from onyx.db.chat import reserve_message_id from onyx.db.chat import translate_db_message_to_chat_message_detail from onyx.db.chat import translate_db_search_doc_to_server_search_doc from onyx.db.engine import get_session_context_manager from onyx.db.milestone import check_multi_assistant_milestone from onyx.db.milestone import create_milestone_if_not_exists from onyx.db.milestone import update_user_assistant_milestone from onyx.db.models import SearchDoc as 
DbSearchDoc from onyx.db.models import ToolCall from onyx.db.models import User from onyx.db.persona import get_persona_by_id from onyx.db.search_settings import get_current_search_settings from onyx.document_index.factory import get_default_document_index from onyx.file_store.models import ChatFileType from onyx.file_store.models import FileDescriptor from onyx.file_store.utils import load_all_chat_files from onyx.file_store.utils import save_files from onyx.llm.exceptions import GenAIDisabledException from onyx.llm.factory import get_llms_for_persona from onyx.llm.factory import get_main_llm_from_tuple from onyx.llm.interfaces import LLM from onyx.llm.models import PreviousMessage from onyx.llm.utils import litellm_exception_to_error_msg from onyx.natural_language_processing.utils import get_tokenizer from onyx.server.query_and_chat.models import ChatMessageDetail from onyx.server.query_and_chat.models import CreateChatMessageRequest from onyx.server.utils import get_json_line from onyx.tools.force import ForceUseTool from onyx.tools.models import ToolResponse from onyx.tools.tool import Tool from onyx.tools.tool_constructor import construct_tools from onyx.tools.tool_constructor import CustomToolConfig from onyx.tools.tool_constructor import ImageGenerationToolConfig from onyx.tools.tool_constructor import InternetSearchToolConfig from onyx.tools.tool_constructor import SearchToolConfig from onyx.tools.tool_implementations.custom.custom_tool import ( CUSTOM_TOOL_RESPONSE_ID, ) from onyx.tools.tool_implementations.custom.custom_tool import CustomToolCallSummary from onyx.tools.tool_implementations.images.image_generation_tool import ( IMAGE_GENERATION_RESPONSE_ID, ) from onyx.tools.tool_implementations.images.image_generation_tool import ( ImageGenerationResponse, ) from onyx.tools.tool_implementations.internet_search.internet_search_tool import ( INTERNET_SEARCH_RESPONSE_ID, ) from onyx.tools.tool_implementations.internet_search.internet_search_tool import ( 
logger = setup_logger()

ERROR_TYPE_CANCELLED = "cancelled"


def _translate_citations(
    citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
) -> MessageSpecificCitations:
    """Map each citation number to the id of the saved search doc it points at.

    When several saved docs share a document_id, the first one in `db_docs`
    wins, so `db_docs` is assumed to be ordered as displayed in the UI.
    When a citation number appears more than once, the first occurrence wins.

    Raises:
        KeyError: if a citation references a document_id that is not present
            in `db_docs`.
    """
    # document_id -> primary key of the first saved doc carrying that id
    saved_id_by_document_id: dict[str, int] = {}
    for search_doc in db_docs:
        if search_doc.document_id in saved_id_by_document_id:
            continue
        saved_id_by_document_id[search_doc.document_id] = search_doc.id

    # citation number -> saved doc primary key (first occurrence only)
    saved_id_by_citation_num: dict[int, int] = {}
    for cited in citations_list:
        if cited.citation_num in saved_id_by_citation_num:
            continue
        saved_id = saved_id_by_document_id[cited.document_id]
        saved_id_by_citation_num[cited.citation_num] = saved_id

    return MessageSpecificCitations(citation_map=saved_id_by_citation_num)
def _handle_search_tool_response_summary(
    packet: ToolResponse,
    db_session: Session,
    selected_search_docs: list[DbSearchDoc] | None,
    dedupe_docs: bool = False,
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
    """Turn a search-summary tool packet into a streamable docs response.

    Returns a tuple of:
      * the QADocsResponse to stream,
      * the DB search docs backing it (the caller-selected docs when any were
        given, otherwise newly created rows for the retrieved sections),
      * the indices removed by deduplication, or None when no dedupe ran.
    """
    summary = cast(SearchResponseSummary, packet.response)
    packet_is_extended = isinstance(packet, ExtendedToolResponse)

    dropped_inds = None
    if selected_search_docs:
        reference_db_search_docs = selected_search_docs
    else:
        docs_to_save = chunks_or_sections_to_search_docs(summary.top_sections)
        # Extended tool responses are already deduped
        if dedupe_docs and not packet_is_extended:
            docs_to_save, dropped_inds = dedupe_documents(docs_to_save)

        reference_db_search_docs = [
            create_db_search_doc(server_search_doc=search_doc, db_session=db_session)
            for search_doc in docs_to_save
        ]

    response_docs = [
        translate_db_search_doc_to_server_search_doc(saved_doc)
        for saved_doc in reference_db_search_docs
    ]

    # Only extended packets carry sub-question coordinates
    if isinstance(packet, ExtendedToolResponse):
        level = packet.level
        question_num = packet.level_question_num
    else:
        level = None
        question_num = None

    final_filters = summary.final_filters
    qa_docs_response = QADocsResponse(
        rephrased_query=summary.rephrased_query,
        top_documents=response_docs,
        predicted_flow=summary.predicted_flow,
        predicted_search=summary.predicted_search,
        applied_source_filters=final_filters.source_type,
        applied_time_cutoff=final_filters.time_cutoff,
        recency_bias_multiplier=summary.recency_bias_multiplier,
        level=level,
        level_question_num=question_num,
    )
    return qa_docs_response, reference_db_search_docs, dropped_inds


def _handle_internet_search_tool_response_summary(
    packet: ToolResponse,
    db_session: Session,
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
    """Turn an internet-search tool packet into a streamable docs response.

    Every result is saved as a DB search doc. The response is reported with
    fixed flow/search-type values, no filters and a recency multiplier of 1.0.
    """
    search_result = cast(InternetSearchResponse, packet.response)

    reference_db_search_docs = [
        create_db_search_doc(server_search_doc=search_doc, db_session=db_session)
        for search_doc in internet_search_response_to_search_docs(search_result)
    ]
    response_docs = [
        translate_db_search_doc_to_server_search_doc(saved_doc)
        for saved_doc in reference_db_search_docs
    ]

    qa_docs_response = QADocsResponse(
        rephrased_query=search_result.revised_query,
        top_documents=response_docs,
        predicted_flow=QueryFlow.QUESTION_ANSWER,
        predicted_search=SearchType.SEMANTIC,
        applied_source_filters=[],
        applied_time_cutoff=None,
        recency_bias_multiplier=1.0,
    )
    return qa_docs_response, reference_db_search_docs


def _get_force_search_settings(
    new_msg_req: CreateChatMessageRequest, tools: list[Tool]
) -> ForceUseTool:
    """Decide whether a search tool must be forced for this request.

    The document search tool is preferred over the internet search tool when
    both are available. Uploaded files disable forcing entirely.
    """
    has_search_tool = False
    has_internet_search_tool = False
    for candidate in tools:
        if isinstance(candidate, SearchTool):
            has_search_tool = True
        if isinstance(candidate, InternetSearchTool):
            has_internet_search_tool = True

    if not (has_internet_search_tool or has_search_tool):
        # Does not matter much which tool is set here as force is false and neither tool is available
        return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)

    if has_search_tool:
        tool_name = SearchTool._NAME
    else:
        tool_name = InternetSearchTool._NAME

    if new_msg_req.file_descriptors:
        # If user has uploaded files they're using, don't run any of the search tools
        return ForceUseTool(force_use=False, tool_name=tool_name)

    # Currently, the internet search tool does not support query override
    args = None
    if new_msg_req.query_override and tool_name == SearchTool._NAME:
        args = {"query": new_msg_req.query_override}

    retrieval_options = new_msg_req.retrieval_options
    search_always_requested = bool(
        retrieval_options
        and retrieval_options.run_search == OptionalSearchSetting.ALWAYS
    )
    should_force_search = (
        search_always_requested
        or bool(new_msg_req.search_doc_ids)
        or new_msg_req.query_override is not None
        or bool(DISABLE_LLM_CHOOSE_SEARCH)
    )

    if not should_force_search:
        return ForceUseTool(force_use=False, tool_name=tool_name, args=args)

    if new_msg_req.search_doc_ids:
        # If we are using selected docs, just put something here so the Tool
        # doesn't need to build its own args via an LLM call
        args = {"query": new_msg_req.message}
    return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
ChatPacket = (
    StreamingError
    | QADocsResponse
    | OnyxContexts
    | LLMRelevanceFilterResponse
    | FinalUsedContextDocsResponse
    | ChatMessageDetail
    | OnyxAnswerPiece
    | AllCitations
    | CitationInfo
    | FileChatDisplay
    | CustomToolResponse
    | MessageSpecificCitations
    | MessageResponseIDInfo
    | AgenticMessageResponseIDInfo
    | StreamStopInfo
    | AgentSearchPacket
)
ChatPacketStream = Iterator[ChatPacket]


def stream_chat_message_objects(
    new_msg_req: CreateChatMessageRequest,
    user: User | None,
    db_session: Session,
    # Needed to translate persona num_chunks to tokens to the LLM
    default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
    # For flow with search, don't include as many chunks as possible since we need to leave space
    # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
    max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
    litellm_additional_headers: dict[str, str] | None = None,
    custom_tool_additional_headers: dict[str, str] | None = None,
    is_connected: Callable[[], bool] | None = None,
    enforce_chat_session_id_for_search_docs: bool = True,
    bypass_acl: bool = False,
    include_contexts: bool = False,
    # a string which represents the history of a conversation. Used in cases like
    # Slack threads where the conversation cannot be represented by a chain of User/Assistant
    # messages.
    # NOTE: is not stored in the database at all.
    single_message_history: str | None = None,
) -> ChatPacketStream:
    """Run one chat turn and stream the resulting packets.

    Streams in order:
    1. [conditional] Retrieved documents if a search needs to be run
    2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
    3. [always] A set of streamed LLM tokens or an error anywhere along the line if
    something fails
    4. [always] Details on the final AI response message that is created

    A MessageResponseIDInfo carrying the user message id (if one was created)
    and the reserved assistant message id is yielded before the answer stream.

    Failures are reported to the consumer as StreamingError packets instead of
    being raised. If setup or streaming fails, the session is rolled back and
    the generator ends. Messages are committed only after the answer has been
    fully generated.

    Args:
        new_msg_req: The incoming request. NOTE: `chunks_above` and
            `chunks_below` on this object are overwritten with 0. When
            `use_existing_user_message` is set (and `regenerate` is not), no
            new user message is created and the last message of the chain is
            used instead.
        user: The requesting user, or None when there is no authenticated user.
        db_session: Session used for every read and write in this call.
        default_num_chunks: Chunk budget used when the persona has no
            `num_chunks` of its own.
        max_document_percentage: Forwarded as the pruning config's
            `max_window_percentage`.
        litellm_additional_headers: Forwarded to LLM construction and to the
            image generation tool config.
        custom_tool_additional_headers: Forwarded to the custom tool config.
        is_connected: Forwarded to `Answer`.
        enforce_chat_session_id_for_search_docs: Forwarded to
            `get_doc_query_identifiers_from_model` when resolving
            user-selected docs.
        bypass_acl: Forwarded to the search tool config.
        include_contexts: When True, search doc content packets are yielded
            as `OnyxContexts`.
        single_message_history: Conversation history as one string, forwarded
            to the prompt builder.

    Yields:
        ChatPacket objects as described above.
    """
    tenant_id = get_current_tenant_id()
    use_existing_user_message = new_msg_req.use_existing_user_message
    existing_assistant_message_id = new_msg_req.existing_assistant_message_id

    # Currently surrounding context is not supported for chat
    # Chat is already token heavy and harder for the model to process plus it would roll history over much faster
    new_msg_req.chunks_above = 0
    new_msg_req.chunks_below = 0

    # Must be bound before the `try`: the generic exception handler below reads
    # `llm`, and many failures (missing persona, disabled LLM, ...) happen before
    # the LLM is constructed. A bare `llm: LLM` annotation does not bind the name.
    llm: LLM | None = None

    try:
        user_id = user.id if user is not None else None

        chat_session = get_chat_session_by_id(
            chat_session_id=new_msg_req.chat_session_id,
            user_id=user_id,
            db_session=db_session,
        )

        message_text = new_msg_req.message
        chat_session_id = new_msg_req.chat_session_id
        parent_id = new_msg_req.parent_message_id
        reference_doc_ids = new_msg_req.search_doc_ids
        retrieval_options = new_msg_req.retrieval_options
        alternate_assistant_id = new_msg_req.alternate_assistant_id

        # permanent "log" store, used primarily for debugging
        long_term_logger = LongTermLogger(
            metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
        )

        if alternate_assistant_id is not None:
            # Allows users to specify a temporary persona (assistant) in the chat session
            # this takes highest priority since it's user specified
            persona = get_persona_by_id(
                alternate_assistant_id,
                user=user,
                db_session=db_session,
                is_for_edit=False,
            )
        elif new_msg_req.persona_override_config:
            # Certain endpoints allow users to specify arbitrary persona settings
            # this should never conflict with the alternate_assistant_id
            persona = create_temporary_persona(
                db_session=db_session,
                persona_config=new_msg_req.persona_override_config,
                user=user,
            )
        else:
            persona = chat_session.persona

        if not persona:
            raise RuntimeError("No persona specified or found for chat session")

        multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
            user=user,
            event_type=MilestoneRecordType.MULTIPLE_ASSISTANTS,
            db_session=db_session,
        )

        update_user_assistant_milestone(
            milestone=multi_assistant_milestone,
            user_id=str(user.id) if user else NO_AUTH_USER_ID,
            assistant_id=persona.id,
            db_session=db_session,
        )

        _, just_hit_multi_assistant_milestone = check_multi_assistant_milestone(
            milestone=multi_assistant_milestone,
            db_session=db_session,
        )

        if just_hit_multi_assistant_milestone:
            mt_cloud_telemetry(
                distinct_id=tenant_id,
                event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
                properties=None,
            )

        # If a prompt override is specified via the API, use that with highest priority
        # but for saving it, we are just mapping it to an existing prompt
        prompt_id = new_msg_req.prompt_id
        if prompt_id is None and persona.prompts:
            prompt_id = sorted(persona.prompts, key=lambda x: x.id)[-1].id

        if reference_doc_ids is None and retrieval_options is None:
            raise RuntimeError(
                "Must specify a set of documents for chat or specify search options"
            )

        try:
            llm, fast_llm = get_llms_for_persona(
                persona=persona,
                llm_override=new_msg_req.llm_override or chat_session.llm_override,
                additional_headers=litellm_additional_headers,
                long_term_logger=long_term_logger,
            )
        except GenAIDisabledException:
            raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")

        llm_provider = llm.config.model_provider
        llm_model_name = llm.config.model_name

        llm_tokenizer = get_tokenizer(
            model_name=llm_model_name,
            provider_type=llm_provider,
        )
        llm_tokenizer_encode_func = cast(
            Callable[[str], list[int]], llm_tokenizer.encode
        )

        search_settings = get_current_search_settings(db_session)
        document_index = get_default_document_index(search_settings, None)

        # Every chat Session begins with an empty root message
        root_message = get_or_create_root_message(
            chat_session_id=chat_session_id, db_session=db_session
        )

        if parent_id is not None:
            parent_message = get_chat_message(
                chat_message_id=parent_id,
                user_id=user_id,
                db_session=db_session,
            )
        else:
            parent_message = root_message

        user_message = None

        if new_msg_req.regenerate:
            final_msg, history_msgs = create_chat_chain(
                stop_at_message_id=parent_id,
                chat_session_id=chat_session_id,
                db_session=db_session,
            )

        elif not use_existing_user_message:
            # Create new message at the right place in the tree and update the parent's child pointer
            # Don't commit yet until we verify the chat message chain
            user_message = create_new_chat_message(
                chat_session_id=chat_session_id,
                parent_message=parent_message,
                prompt_id=prompt_id,
                message=message_text,
                token_count=len(llm_tokenizer_encode_func(message_text)),
                message_type=MessageType.USER,
                files=None,  # Need to attach later for optimization to only load files once in parallel
                db_session=db_session,
                commit=False,
            )
            # re-create linear history of messages
            final_msg, history_msgs = create_chat_chain(
                chat_session_id=chat_session_id, db_session=db_session
            )
            if final_msg.id != user_message.id:
                db_session.rollback()
                raise RuntimeError(
                    "The new message was not on the mainline. "
                    "Be sure to update the chat pointers before calling this."
                )

            # NOTE: do not commit user message - it will be committed when the
            # assistant message is successfully generated
        else:
            # re-create linear history of messages
            final_msg, history_msgs = create_chat_chain(
                chat_session_id=chat_session_id, db_session=db_session
            )
            if existing_assistant_message_id is None:
                if final_msg.message_type != MessageType.USER:
                    raise RuntimeError(
                        "The last message was not a user message. Cannot call "
                        "`stream_chat_message_objects` with "
                        "`use_existing_user_message=True` "
                        "when the last message is not a user message."
                    )
            else:
                if final_msg.id != existing_assistant_message_id:
                    raise RuntimeError(
                        "The last message was not the existing assistant message. "
                        f"Final message id: {final_msg.id}, "
                        f"existing assistant message id: {existing_assistant_message_id}"
                    )

        # load all files needed for this chat chain in memory
        files = load_all_chat_files(
            history_msgs, new_msg_req.file_descriptors, db_session
        )
        req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
        latest_query_files = [file for file in files if file.file_id in req_file_ids]

        if user_message:
            attach_files_to_chat_message(
                chat_message=user_message,
                files=[
                    new_file.to_file_descriptor() for new_file in latest_query_files
                ],
                db_session=db_session,
                commit=False,
            )

        selected_db_search_docs = None
        selected_sections: list[InferenceSection] | None = None
        if reference_doc_ids:
            identifier_tuples = get_doc_query_identifiers_from_model(
                search_doc_ids=reference_doc_ids,
                chat_session=chat_session,
                user_id=user_id,
                db_session=db_session,
                enforce_chat_session_id_for_search_docs=enforce_chat_session_id_for_search_docs,
            )

            # Generates full documents currently
            # May extend to use sections instead in the future
            selected_sections = inference_sections_from_ids(
                doc_identifiers=identifier_tuples,
                document_index=document_index,
            )
            document_pruning_config = DocumentPruningConfig(
                is_manually_selected_docs=True
            )

            # In case the search doc is deleted, just don't include it
            # though this should never happen
            db_search_docs_or_none = [
                get_db_search_doc_by_id(doc_id=doc_id, db_session=db_session)
                for doc_id in reference_doc_ids
            ]

            selected_db_search_docs = [
                db_sd for db_sd in db_search_docs_or_none if db_sd
            ]

        else:
            document_pruning_config = DocumentPruningConfig(
                max_chunks=int(
                    persona.num_chunks
                    if persona.num_chunks is not None
                    else default_num_chunks
                ),
                max_window_percentage=max_document_percentage,
            )

        # we don't need to reserve a message id if we're using an existing assistant message
        reserved_message_id = (
            final_msg.id
            if existing_assistant_message_id is not None
            else reserve_message_id(
                db_session=db_session,
                chat_session_id=chat_session_id,
                parent_message=(
                    user_message.id if user_message is not None else parent_message.id
                ),
                message_type=MessageType.ASSISTANT,
            )
        )
        yield MessageResponseIDInfo(
            user_message_id=user_message.id if user_message else None,
            reserved_assistant_message_id=reserved_message_id,
        )

        overridden_model = (
            new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
        )

        # Cannot determine these without the LLM step or breaking out early.
        # message, rephrased_query, token_count, error and reference_docs are
        # supplied when the partial is finally called after streaming.
        partial_response = partial(
            create_new_chat_message,
            chat_session_id=chat_session_id,
            # if we're using an existing assistant message, then this will just be an
            # update operation, in which case the parent should be the parent of
            # the latest. If we're creating a new assistant message, then the parent
            # should be the latest message (latest user message)
            parent_message=(
                final_msg if existing_assistant_message_id is None else parent_message
            ),
            prompt_id=prompt_id,
            overridden_model=overridden_model,
            message_type=MessageType.ASSISTANT,
            alternate_assistant_id=new_msg_req.alternate_assistant_id,
            db_session=db_session,
            commit=False,
            reserved_message_id=reserved_message_id,
            is_agentic=new_msg_req.use_agentic_search,
        )

        prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
        if new_msg_req.persona_override_config:
            # Only the first prompt of the override config is used
            override_prompt = new_msg_req.persona_override_config.prompts[0]
            prompt_config = PromptConfig(
                system_prompt=override_prompt.system_prompt,
                task_prompt=override_prompt.task_prompt,
                datetime_aware=override_prompt.datetime_aware,
                include_citations=override_prompt.include_citations,
            )
        elif prompt_override:
            if not final_msg.prompt:
                raise ValueError(
                    "Prompt override cannot be applied, no base prompt found."
                )
            prompt_config = PromptConfig.from_model(
                final_msg.prompt,
                prompt_override=prompt_override,
            )
        elif final_msg.prompt:
            prompt_config = PromptConfig.from_model(final_msg.prompt)
        else:
            prompt_config = PromptConfig.from_model(persona.prompts[0])

        answer_style_config = AnswerStyleConfig(
            citation_config=CitationConfig(
                all_docs_useful=selected_db_search_docs is not None
            ),
            document_pruning_config=document_pruning_config,
            structured_response_format=new_msg_req.structured_response_format,
        )

        tool_dict = construct_tools(
            persona=persona,
            prompt_config=prompt_config,
            db_session=db_session,
            user=user,
            llm=llm,
            fast_llm=fast_llm,
            search_tool_config=SearchToolConfig(
                answer_style_config=answer_style_config,
                document_pruning_config=document_pruning_config,
                retrieval_options=retrieval_options or RetrievalDetails(),
                rerank_settings=new_msg_req.rerank_settings,
                selected_sections=selected_sections,
                chunks_above=new_msg_req.chunks_above,
                chunks_below=new_msg_req.chunks_below,
                full_doc=new_msg_req.full_doc,
                latest_query_files=latest_query_files,
                bypass_acl=bypass_acl,
            ),
            internet_search_tool_config=InternetSearchToolConfig(
                answer_style_config=answer_style_config,
            ),
            image_generation_tool_config=ImageGenerationToolConfig(
                additional_headers=litellm_additional_headers,
            ),
            custom_tool_config=CustomToolConfig(
                chat_session_id=chat_session_id,
                message_id=user_message.id if user_message else None,
                additional_headers=custom_tool_additional_headers,
            ),
        )

        tools: list[Tool] = []
        for tool_list in tool_dict.values():
            tools.extend(tool_list)

        # TODO: unify message history with single message history
        message_history = [
            PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
        ]

        search_request = SearchRequest(
            query=final_msg.message,
            evaluation_type=(
                LLMEvaluationType.BASIC
                if persona.llm_relevance_filter
                else LLMEvaluationType.SKIP
            ),
            human_selected_filters=(
                retrieval_options.filters if retrieval_options else None
            ),
            persona=persona,
            offset=(retrieval_options.offset if retrieval_options else None),
            limit=retrieval_options.limit if retrieval_options else None,
            rerank_settings=new_msg_req.rerank_settings,
            chunks_above=new_msg_req.chunks_above,
            chunks_below=new_msg_req.chunks_below,
            full_doc=new_msg_req.full_doc,
            enable_auto_detect_filters=(
                retrieval_options.enable_auto_detect_filters
                if retrieval_options
                else None
            ),
        )

        force_use_tool = _get_force_search_settings(new_msg_req, tools)
        prompt_builder = AnswerPromptBuilder(
            user_message=default_build_user_message(
                user_query=final_msg.message,
                prompt_config=prompt_config,
                files=latest_query_files,
                single_message_history=single_message_history,
            ),
            system_message=default_build_system_message(prompt_config, llm.config),
            message_history=message_history,
            llm_config=llm.config,
            raw_user_query=final_msg.message,
            raw_user_uploaded_files=latest_query_files or [],
            single_message_history=single_message_history,
        )

        # LLM prompt building, response capturing, etc.
        answer = Answer(
            prompt_builder=prompt_builder,
            is_connected=is_connected,
            latest_query_files=latest_query_files,
            answer_style_config=answer_style_config,
            llm=(
                llm
                or get_main_llm_from_tuple(
                    get_llms_for_persona(
                        persona=persona,
                        llm_override=(
                            new_msg_req.llm_override or chat_session.llm_override
                        ),
                        additional_headers=litellm_additional_headers,
                    )
                )
            ),
            fast_llm=fast_llm,
            force_use_tool=force_use_tool,
            search_request=search_request,
            chat_session_id=chat_session_id,
            current_agent_message_id=reserved_message_id,
            tools=tools,
            db_session=db_session,
            use_agentic_search=new_msg_req.use_agentic_search,
        )

        # Per-(level, question) bookkeeping collected while streaming and used
        # afterwards to persist the assistant message(s).
        # TODO: different channels for stored info when it's coming from the agent flow
        info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
            lambda: AnswerPostInfo(ai_message_files=[])
        )
        refined_answer_improvement = True
        for packet in answer.processed_streamed_output:
            if isinstance(packet, ToolResponse):
                # Non-extended tool responses are filed under the basic key
                level, level_question_num = (
                    (packet.level, packet.level_question_num)
                    if isinstance(packet, ExtendedToolResponse)
                    else BASIC_KEY
                )
                assert level is not None
                assert level_question_num is not None
                info = info_by_subq[
                    SubQuestionKey(level=level, question_num=level_question_num)
                ]
                # TODO: don't need to dedupe here when we do it in agent flow
                if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
                    (
                        info.qa_docs_response,
                        info.reference_db_search_docs,
                        info.dropped_indices,
                    ) = _handle_search_tool_response_summary(
                        packet=packet,
                        db_session=db_session,
                        selected_search_docs=selected_db_search_docs,
                        # Deduping happens at the last step to avoid harming quality by dropping content early on
                        dedupe_docs=(
                            retrieval_options.dedupe_docs
                            if retrieval_options
                            else False
                        ),
                    )
                    yield info.qa_docs_response
                elif packet.id == SECTION_RELEVANCE_LIST_ID:
                    relevance_sections = packet.response

                    if info.reference_db_search_docs is None:
                        logger.warning(
                            "No reference docs found for relevance filtering"
                        )
                        continue

                    llm_indices = relevant_sections_to_indices(
                        relevance_sections=relevance_sections,
                        items=[
                            translate_db_search_doc_to_server_search_doc(doc)
                            for doc in info.reference_db_search_docs
                        ],
                    )

                    if info.dropped_indices:
                        llm_indices = drop_llm_indices(
                            llm_indices=llm_indices,
                            search_docs=info.reference_db_search_docs,
                            dropped_indices=info.dropped_indices,
                        )

                    yield LLMRelevanceFilterResponse(
                        llm_selected_doc_indices=llm_indices
                    )
                elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
                    yield FinalUsedContextDocsResponse(
                        final_context_docs=packet.response
                    )

                elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
                    img_generation_response = cast(
                        list[ImageGenerationResponse], packet.response
                    )

                    file_ids = save_files(
                        urls=[img.url for img in img_generation_response if img.url],
                        base64_files=[
                            img.image_data
                            for img in img_generation_response
                            if img.image_data
                        ],
                    )
                    info.ai_message_files.extend(
                        [
                            FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
                            for file_id in file_ids
                        ]
                    )
                    yield FileChatDisplay(
                        file_ids=[str(file_id) for file_id in file_ids]
                    )
                elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
                    (
                        info.qa_docs_response,
                        info.reference_db_search_docs,
                    ) = _handle_internet_search_tool_response_summary(
                        packet=packet,
                        db_session=db_session,
                    )
                    yield info.qa_docs_response
                elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
                    custom_tool_response = cast(CustomToolCallSummary, packet.response)

                    if (
                        custom_tool_response.response_type == "image"
                        or custom_tool_response.response_type == "csv"
                    ):
                        file_ids = custom_tool_response.tool_result.file_ids
                        info.ai_message_files.extend(
                            [
                                FileDescriptor(
                                    id=str(file_id),
                                    type=(
                                        ChatFileType.IMAGE
                                        if custom_tool_response.response_type == "image"
                                        else ChatFileType.CSV
                                    ),
                                )
                                for file_id in file_ids
                            ]
                        )
                        yield FileChatDisplay(
                            file_ids=[str(file_id) for file_id in file_ids]
                        )
                    else:
                        yield CustomToolResponse(
                            response=custom_tool_response.tool_result,
                            tool_name=custom_tool_response.tool_name,
                        )
                elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
                    yield cast(OnyxContexts, packet.response)

            elif isinstance(packet, StreamStopInfo):
                # Only a FINISHED stop is forwarded to the consumer
                if packet.stop_reason == StreamStopReason.FINISHED:
                    yield packet
            elif isinstance(packet, RefinedAnswerImprovement):
                refined_answer_improvement = packet.refined_answer_improvement
                yield packet
            else:
                if isinstance(packet, ToolCallFinalResult):
                    level, level_question_num = (
                        (packet.level, packet.level_question_num)
                        if packet.level is not None
                        and packet.level_question_num is not None
                        else BASIC_KEY
                    )
                    info = info_by_subq[
                        SubQuestionKey(level=level, question_num=level_question_num)
                    ]
                    info.tool_result = packet
                yield cast(ChatPacket, packet)
        logger.debug("Reached end of stream")
    except ValueError as e:
        logger.exception("Failed to process chat message.")

        error_msg = str(e)
        yield StreamingError(error=error_msg)
        db_session.rollback()
        return

    except Exception as e:
        logger.exception(f"Failed to process chat message due to {e}")
        error_msg = str(e)
        stack_trace = traceback.format_exc()

        if isinstance(e, ToolCallException):
            yield StreamingError(error=error_msg, stack_trace=stack_trace)
        else:
            if llm:
                client_error_msg = litellm_exception_to_error_msg(e, llm)
                api_key = llm.config.api_key
                if api_key and len(api_key) > 2:
                    # Redact what is actually sent to the client
                    client_error_msg = client_error_msg.replace(
                        api_key, "[REDACTED_API_KEY]"
                    )
                    stack_trace = stack_trace.replace(api_key, "[REDACTED_API_KEY]")
            else:
                # The failure happened before an LLM was constructed, so there
                # is no LLM-specific message mapping or API key to redact.
                client_error_msg = error_msg

            yield StreamingError(error=client_error_msg, stack_trace=stack_trace)

        db_session.rollback()
        return

    # Post-LLM answer processing
    try:
        tool_name_to_tool_id: dict[str, int] = {}
        for tool_id, tool_list in tool_dict.items():
            for tool in tool_list:
                tool_name_to_tool_id[tool.name] = tool_id

        subq_citations = answer.citations_by_subquestion()
        for subq_key in subq_citations:
            info = info_by_subq[subq_key]
            logger.debug("Post-LLM answer processing")
            if info.reference_db_search_docs:
                info.message_specific_citations = _translate_citations(
                    citations_list=subq_citations[subq_key],
                    db_docs=info.reference_db_search_docs,
                )

            # TODO: AllCitations should contain subq info?
            if not answer.is_cancelled():
                yield AllCitations(citations=subq_citations[subq_key])

        # Saving Gen AI answer and responding with message info

        # Prefer info collected under the basic key; otherwise fall back to the
        # initial agent-search key.
        basic_key = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
        info = (
            info_by_subq[basic_key]
            if basic_key in info_by_subq
            else info_by_subq[
                SubQuestionKey(
                    level=AGENT_SEARCH_INITIAL_KEY[0],
                    question_num=AGENT_SEARCH_INITIAL_KEY[1],
                )
            ]
        )
        gen_ai_response_message = partial_response(
            message=answer.llm_answer,
            rephrased_query=(
                info.qa_docs_response.rephrased_query if info.qa_docs_response else None
            ),
            reference_docs=info.reference_db_search_docs,
            files=info.ai_message_files,
            token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
            citations=(
                info.message_specific_citations.citation_map
                if info.message_specific_citations
                else None
            ),
            error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
            tool_call=(
                ToolCall(
                    tool_id=tool_name_to_tool_id[info.tool_result.tool_name],
                    tool_name=info.tool_result.tool_name,
                    tool_arguments=info.tool_result.tool_args,
                    tool_result=info.tool_result.tool_result,
                )
                if info.tool_result
                else None
            ),
        )

        # add answers for levels >= 1, where each level has the previous as its parent. Use
        # the answer_by_level method in answer.py to get the answers for each level
        next_level = 1
        prev_message = gen_ai_response_message
        agent_answers = answer.llm_answer_by_level()
        agentic_message_ids = []
        while next_level in agent_answers:
            next_answer = agent_answers[next_level]
            info = info_by_subq[
                SubQuestionKey(
                    level=next_level, question_num=AGENT_SEARCH_INITIAL_KEY[1]
                )
            ]
            next_answer_message = create_new_chat_message(
                chat_session_id=chat_session_id,
                parent_message=prev_message,
                message=next_answer,
                prompt_id=None,
                token_count=len(llm_tokenizer_encode_func(next_answer)),
                message_type=MessageType.ASSISTANT,
                db_session=db_session,
                files=info.ai_message_files,
                reference_docs=info.reference_db_search_docs,
                citations=(
                    info.message_specific_citations.citation_map
                    if info.message_specific_citations
                    else None
                ),
                error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
                refined_answer_improvement=refined_answer_improvement,
                is_agentic=True,
            )
            agentic_message_ids.append(
                AgentMessageIDInfo(level=next_level, message_id=next_answer_message.id)
            )
            next_level += 1
            prev_message = next_answer_message

        logger.debug("Committing messages")
        db_session.commit()  # actually save user / assistant message

        yield AgenticMessageResponseIDInfo(agentic_message_ids=agentic_message_ids)

        yield translate_db_message_to_chat_message_detail(gen_ai_response_message)
    except Exception as e:
        error_msg = str(e)
        logger.exception(error_msg)

        # Frontend will erase whatever answer and show this instead
        yield StreamingError(error="Failed to parse LLM output")
litellm_additional_headers=litellm_additional_headers, custom_tool_additional_headers=custom_tool_additional_headers, is_connected=is_connected, ) for obj in objects: yield get_json_line(obj.model_dump()) @log_function_time() def gather_stream_for_slack( packets: ChatPacketStream, ) -> ChatOnyxBotResponse: response = ChatOnyxBotResponse() answer = "" for packet in packets: if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece: answer += packet.answer_piece elif isinstance(packet, QADocsResponse): response.docs = packet elif isinstance(packet, StreamingError): response.error_msg = packet.error elif isinstance(packet, ChatMessageDetail): response.chat_message_id = packet.message_id elif isinstance(packet, LLMRelevanceFilterResponse): response.llm_selected_doc_indices = packet.llm_selected_doc_indices elif isinstance(packet, AllCitations): response.citations = packet.citations if answer: response.answer = answer return response