from collections import defaultdict
from collections.abc import Callable
from uuid import UUID

from sqlalchemy.orm import Session

from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionKey
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import BASIC_KEY
from onyx.context.search.models import SearchRequest
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.gpu_utils import gpu_status_request
from onyx.utils.logger import setup_logger

logger = setup_logger()

BASIC_SQ_KEY = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])


class Answer:
    """Configures and runs the answer-generation graph (basic or agentic) for a
    single chat message, exposing the resulting packet stream along with
    convenience accessors for the final answer text and citations."""

    def __init__(
        self,
        prompt_builder: AnswerPromptBuilder,
        answer_style_config: AnswerStyleConfig,
        llm: LLM,
        fast_llm: LLM,
        force_use_tool: ForceUseTool,
        search_request: SearchRequest,
        chat_session_id: UUID,
        current_agent_message_id: int,
        db_session: Session,
        # newly passed-in files to include as part of this question
        # TODO THIS NEEDS TO BE HANDLED
        latest_query_files: list[InMemoryChatFile] | None = None,
        tools: list[Tool] | None = None,
        # NOTE: native tool-calling is currently only supported by OpenAI,
        # but that is the only provider we support anyway.
        # If set to True, never use the LLM's built-in tool-calling functionality.
        skip_explicit_tool_calling: bool = False,
        skip_gen_ai_answer_generation: bool = False,
        is_connected: Callable[[], bool] | None = None,
        use_agentic_search: bool = False,
    ) -> None:
        self.is_connected: Callable[[], bool] | None = is_connected

        self._processed_stream: list[AnswerPacket] | None = None
        self._is_cancelled = False

        search_tools = [tool for tool in (tools or []) if isinstance(tool, SearchTool)]
        search_tool: SearchTool | None = None

        if len(search_tools) > 1:
            # TODO: handle multiple search tools
            raise ValueError("Multiple search tools found")
        elif len(search_tools) == 1:
            search_tool = search_tools[0]

        using_tool_calling_llm = (
            explicit_tool_calling_supported(
                llm.config.model_provider, llm.config.model_name
            )
            and not skip_explicit_tool_calling
        )

        rerank_settings = search_request.rerank_settings
        using_cloud_reranking = (
            rerank_settings is not None
            and rerank_settings.rerank_provider_type is not None
        )
        allow_agent_reranking = gpu_status_request() or using_cloud_reranking

        # TODO: this is a hack to force the query to be used for the search tool;
        # this should be removed once we fully unify graph inputs (i.e.
        # remove SearchQuery entirely)
        if (
            force_use_tool.force_use
            and search_tool
            and force_use_tool.args
            and force_use_tool.tool_name == search_tool.name
            and QUERY_FIELD in force_use_tool.args
        ):
            search_request.query = force_use_tool.args[QUERY_FIELD]

        self.graph_inputs = GraphInputs(
            search_request=search_request,
            prompt_builder=prompt_builder,
            files=latest_query_files,
            structured_response_format=answer_style_config.structured_response_format,
        )
        self.graph_tooling = GraphTooling(
            primary_llm=llm,
            fast_llm=fast_llm,
            search_tool=search_tool,
            tools=tools or [],
            force_use_tool=force_use_tool,
            using_tool_calling_llm=using_tool_calling_llm,
        )
        self.graph_persistence = GraphPersistence(
            db_session=db_session,
            chat_session_id=chat_session_id,
            message_id=current_agent_message_id,
        )
        self.search_behavior_config = GraphSearchConfig(
            use_agentic_search=use_agentic_search,
            skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
            allow_refinement=True,
            allow_agent_reranking=allow_agent_reranking,
        )
        self.graph_config = GraphConfig(
            inputs=self.graph_inputs,
            tooling=self.graph_tooling,
            persistence=self.graph_persistence,
            behavior=self.search_behavior_config,
        )

    @property
    def processed_streamed_output(self) -> AnswerStream:
        if self._processed_stream is not None:
            yield from self._processed_stream
            return

        run_langgraph = (
            run_main_graph
            if self.graph_config.behavior.use_agentic_search
            else run_basic_graph
        )
        stream = run_langgraph(self.graph_config)

        processed_stream = []
        for packet in stream:
            if self.is_cancelled():
                packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
                yield packet
                break
            processed_stream.append(packet)
            yield packet

        self._processed_stream = processed_stream

    @property
    def llm_answer(self) -> str:
        answer = ""
        for packet in self.processed_streamed_output:
            # handle basic answer flow, plus level 0 agent answer flow,
            # since level 0 is the first answer the user sees and therefore the
            # child message of the user message in the db (so it is handled
            # like a basic flow answer)
            if (isinstance(packet, OnyxAnswerPiece) and packet.answer_piece) or (
                isinstance(packet, AgentAnswerPiece)
                and packet.answer_piece
                and packet.answer_type == "agent_level_answer"
                and packet.level == 0
            ):
                answer += packet.answer_piece

        return answer

    def llm_answer_by_level(self) -> dict[int, str]:
        answer_by_level: dict[int, str] = defaultdict(str)
        for packet in self.processed_streamed_output:
            if (
                isinstance(packet, AgentAnswerPiece)
                and packet.answer_piece
                and packet.answer_type == "agent_level_answer"
            ):
                assert packet.level is not None
                answer_by_level[packet.level] += packet.answer_piece
            elif isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
                answer_by_level[BASIC_KEY[0]] += packet.answer_piece
        return answer_by_level

    @property
    def citations(self) -> list[CitationInfo]:
        citations: list[CitationInfo] = []
        for packet in self.processed_streamed_output:
            if isinstance(packet, CitationInfo) and packet.level is None:
                citations.append(packet)

        return citations

    def citations_by_subquestion(self) -> dict[SubQuestionKey, list[CitationInfo]]:
        citations_by_subquestion: dict[
            SubQuestionKey, list[CitationInfo]
        ] = defaultdict(list)
        basic_subq_key = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
        for packet in self.processed_streamed_output:
            if isinstance(packet, CitationInfo):
                if packet.level_question_num is not None and packet.level is not None:
                    citations_by_subquestion[
                        SubQuestionKey(
                            level=packet.level, question_num=packet.level_question_num
                        )
                    ].append(packet)
                elif packet.level is None:
                    citations_by_subquestion[basic_subq_key].append(packet)
        return citations_by_subquestion

    def is_cancelled(self) -> bool:
        if self._is_cancelled:
            return True

        if self.is_connected is not None:
            if not self.is_connected():
                logger.debug("Answer stream has been cancelled")
                self._is_cancelled = True

        return self._is_cancelled
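
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of this module's behavior): a minimal
# example of how a caller in the chat pipeline might construct and consume an
# `Answer`. The collaborators (prompt builder, LLMs, tools, DB session, etc.)
# are assumed to be built elsewhere; `primary_llm`, `message_id`, and
# `emit_to_client` below are placeholder names, not guaranteed APIs.
#
#     answer = Answer(
#         prompt_builder=prompt_builder,
#         answer_style_config=answer_style_config,
#         llm=primary_llm,
#         fast_llm=fast_llm,
#         force_use_tool=force_use_tool,
#         search_request=search_request,
#         chat_session_id=chat_session_id,
#         current_agent_message_id=message_id,
#         db_session=db_session,
#         tools=tools,
#         is_connected=lambda: True,
#         use_agentic_search=False,
#     )
#     for packet in answer.processed_streamed_output:
#         emit_to_client(packet)  # hypothetical streaming callback
#     final_text = answer.llm_answer
#     citations = answer.citations
# ---------------------------------------------------------------------------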