From 6c7f8eaefb36c01d9fba1d8cca6fab4364936b7f Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Wed, 29 Jan 2025 14:28:46 -0800 Subject: [PATCH] first pass at dead code deletion --- .../main__graph/nodes/agent_logging.py | 2 +- backend/onyx/agents/agent_search/models.py | 15 +- .../orchestration/nodes/tool_call.py | 1 - backend/onyx/agents/agent_search/run_graph.py | 52 +++-- .../agent_search/shared_graph_utils/utils.py | 2 +- backend/onyx/chat/answer.py | 177 ++---------------- .../answer_response_handler.py | 103 ---------- .../stream_processing/citation_processing.py | 159 ---------------- backend/onyx/configs/constants.py | 1 + backend/tests/unit/onyx/chat/conftest.py | 45 +---- backend/tests/unit/onyx/chat/test_answer.py | 1 + .../tests/unit/onyx/chat/test_skip_gen_ai.py | 4 +- 12 files changed, 68 insertions(+), 494 deletions(-) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main__graph/nodes/agent_logging.py b/backend/onyx/agents/agent_search/deep_search_a/main__graph/nodes/agent_logging.py index a76a84e95..8ec12ea9d 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main__graph/nodes/agent_logging.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main__graph/nodes/agent_logging.py @@ -80,7 +80,7 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput: agent_metrics=combined_agent_metrics, ) - if agent_a_config.use_persistence: + if agent_a_config.use_agentic_persistence: # Persist the sub-answer in the database db_session = agent_a_config.db_session chat_session_id = agent_a_config.chat_session_id diff --git a/backend/onyx/agents/agent_search/models.py b/backend/onyx/agents/agent_search/models.py index f437fb4f2..517d66d74 100644 --- a/backend/onyx/agents/agent_search/models.py +++ b/backend/onyx/agents/agent_search/models.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from uuid import UUID from pydantic import BaseModel @@ -15,8 +14,7 @@ from onyx.tools.tool import Tool from onyx.tools.tool_implementations.search.search_tool import SearchTool -@dataclass -class AgentSearchConfig: +class AgentSearchConfig(BaseModel): """ Configuration for the Agent Search feature. 
""" @@ -47,10 +45,10 @@ class AgentSearchConfig: # The message ID of the user message that triggered the Pro Search message_id: int | None = None - # Whether to persistence data for the Pro Search (turned off for testing) - use_persistence: bool = True + # Whether to persistence data for Agentic Search (turned off for testing) + use_agentic_persistence: bool = True - # The database session for the Pro Search + # The database session for Agentic Search db_session: Session | None = None # Whether to perform initial search to inform decomposition @@ -75,7 +73,7 @@ class AgentSearchConfig: @model_validator(mode="after") def validate_db_session(self) -> "AgentSearchConfig": - if self.use_persistence and self.db_session is None: + if self.use_agentic_persistence and self.db_session is None: raise ValueError( "db_session must be provided for pro search when using persistence" ) @@ -87,6 +85,9 @@ class AgentSearchConfig: raise ValueError("search_tool must be provided for agentic search") return self + class Config: + arbitrary_types_allowed = True + class AgentDocumentCitations(BaseModel): document_id: str diff --git a/backend/onyx/agents/agent_search/orchestration/nodes/tool_call.py b/backend/onyx/agents/agent_search/orchestration/nodes/tool_call.py index bf238ca13..5b35f0faa 100644 --- a/backend/onyx/agents/agent_search/orchestration/nodes/tool_call.py +++ b/backend/onyx/agents/agent_search/orchestration/nodes/tool_call.py @@ -38,7 +38,6 @@ def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate tool_runner = ToolRunner(tool, tool_args) tool_kickoff = tool_runner.kickoff() - # TODO: custom events for yields emit_packet(tool_kickoff) tool_responses = [] diff --git a/backend/onyx/agents/agent_search/run_graph.py b/backend/onyx/agents/agent_search/run_graph.py index 4b1632693..b78c9f4f0 100644 --- a/backend/onyx/agents/agent_search/run_graph.py +++ b/backend/onyx/agents/agent_search/run_graph.py @@ -155,26 +155,19 @@ def run_graph( # TODO: call this once on startup, TBD where and if it should be gated based # on dev mode or not -def load_compiled_graph(graph_name: str) -> CompiledStateGraph: - main_graph_builder = ( - main_graph_builder_a if graph_name == "a" else main_graph_builder_a - ) +def load_compiled_graph() -> CompiledStateGraph: global _COMPILED_GRAPH if _COMPILED_GRAPH is None: - graph = main_graph_builder() + graph = main_graph_builder_a() _COMPILED_GRAPH = graph.compile() return _COMPILED_GRAPH def run_main_graph( config: AgentSearchConfig, - graph_name: str = "a", ) -> AnswerStream: - compiled_graph = load_compiled_graph(graph_name) - if graph_name == "a": - input = MainInput_a(base_question=config.search_request.query, log_messages=[]) - else: - input = MainInput_a(base_question=config.search_request.query, log_messages=[]) + compiled_graph = load_compiled_graph() + input = MainInput_a(base_question=config.search_request.query, log_messages=[]) # Agent search is not a Tool per se, but this is helpful for the frontend yield ToolCallKickoff( @@ -202,11 +195,36 @@ if __name__ == "__main__": now_start = datetime.now() logger.debug(f"Start at {now_start}") - if GRAPH_VERSION_NAME == "a": - graph = main_graph_builder_a() - else: - graph = main_graph_builder_a() - compiled_graph = graph.compile() + graph = main_graph_builder_a() + now_start = datetime.now() + compiled_graph = graph.compile() + now_end = datetime.now() + logger.debug(f"Graph compiled in {now_end - now_start} seconds") + primary_llm, fast_llm = get_default_llms() + search_request = SearchRequest( + # 
query="what can you do with gitlab?", + # query="What are the guiding principles behind the development of cockroachDB", + # query="What are the temperatures in Munich, Hawaii, and New York?", + # query="When was Washington born?", + # query="What is Onyx?", + # query="What is the difference between astronomy and astrology?", + query="Do a search to tell me what is the difference between astronomy and astrology?", + ) + # Joachim custom persona + + with get_session_context_manager() as db_session: + config, search_tool = get_test_config( + db_session, primary_llm, fast_llm, search_request + ) + # search_request.persona = get_persona_by_id(1, None, db_session) + config.use_agentic_persistence = True + # config.perform_initial_search_path_decision = False + config.perform_initial_search_decomposition = True + + input = MainInput_a( + base_question=config.search_request.query, log_messages=[] + ) + now_end = datetime.now() logger.debug(f"Graph compiled in {now_end - now_start} seconds") primary_llm, fast_llm = get_default_llms() @@ -226,7 +244,7 @@ if __name__ == "__main__": db_session, primary_llm, fast_llm, search_request ) # search_request.persona = get_persona_by_id(1, None, db_session) - config.use_persistence = True + config.use_agentic_persistence = True # config.perform_initial_search_path_decision = False config.perform_initial_search_decomposition = True if GRAPH_VERSION_NAME == "a": diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py index f9cf6dd18..49200279c 100644 --- a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py @@ -248,7 +248,7 @@ def get_test_config( chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), # Joachim # chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"), # Evan message_id=1, - use_persistence=True, + use_agentic_persistence=True, db_session=db_session, tools=[search_tool], use_agentic_search=use_agentic_search, diff --git a/backend/onyx/chat/answer.py b/backend/onyx/chat/answer.py index fac0f744b..16c78b838 100644 --- a/backend/onyx/chat/answer.py +++ b/backend/onyx/chat/answer.py @@ -54,6 +54,7 @@ class Answer: is_connected: Callable[[], bool] | None = None, db_session: Session | None = None, use_agentic_search: bool = False, + use_agentic_persistence: bool = True, ) -> None: self.is_connected: Callable[[], bool] | None = is_connected @@ -90,14 +91,9 @@ class Answer: if len(search_tools) > 1: # TODO: handle multiple search tools - logger.warning("Multiple search tools found, using first one") - search_tool = search_tools[0] + raise ValueError("Multiple search tools found") elif len(search_tools) == 1: search_tool = search_tools[0] - else: - logger.warning("No search tool found") - if use_agentic_search: - raise ValueError("No search tool found, cannot use agentic search") using_tool_calling_llm = explicit_tool_calling_supported( llm.config.model_provider, llm.config.model_name @@ -111,7 +107,7 @@ class Answer: use_agentic_search=use_agentic_search, chat_session_id=chat_session_id, message_id=current_agent_message_id, - use_persistence=True, + use_agentic_persistence=use_agentic_persistence, allow_refinement=True, db_session=db_session, prompt_builder=prompt_builder, @@ -137,104 +133,20 @@ class Answer: logger.info(f"Forcefully using tool='{tool.name}'{args_str}") return [tool] - # TODO: delete the function and move the full body to processed_streamed_output - def 
_get_response(self) -> AnswerStream: - # current_llm_call = llm_calls[-1] + @property + def processed_streamed_output(self) -> AnswerStream: + if self._processed_stream is not None: + yield from self._processed_stream + return - # tool, tool_args = None, None - # # handle the case where no decision has to be made; we simply run the tool - # if ( - # current_llm_call.force_use_tool.force_use - # and current_llm_call.force_use_tool.args is not None - # ): - # tool_name, tool_args = ( - # current_llm_call.force_use_tool.tool_name, - # current_llm_call.force_use_tool.args, - # ) - # tool = get_tool_by_name(current_llm_call.tools, tool_name) - - # # special pre-logic for non-tool calling LLM case - # elif not self.using_tool_calling_llm and current_llm_call.tools: - # chosen_tool_and_args = ( - # ToolResponseHandler.get_tool_call_for_non_tool_calling_llm( - # current_llm_call, self.llm - # ) - # ) - # if chosen_tool_and_args: - # tool, tool_args = chosen_tool_and_args - - # if tool and tool_args: - # dummy_tool_call_chunk = AIMessageChunk(content="") - # dummy_tool_call_chunk.tool_calls = [ - # ToolCall(name=tool.name, args=tool_args, id=str(uuid4())) - # ] - - # response_handler_manager = LLMResponseHandlerManager( - # ToolResponseHandler([tool]), None, self.is_cancelled - # ) - # yield from response_handler_manager.handle_llm_response( - # iter([dummy_tool_call_chunk]) - # ) - - # tmp_call = response_handler_manager.next_llm_call(current_llm_call) - # if tmp_call is None: - # return # no more LLM calls to process - # current_llm_call = tmp_call - - # # if we're skipping gen ai answer generation, we should break - # # out unless we're forcing a tool call. If we don't, we might generate an - # # answer, which is a no-no! - # if ( - # self.skip_gen_ai_answer_generation - # and not current_llm_call.force_use_tool.force_use - # ): - # return - - # # set up "handlers" to listen to the LLM response stream and - # # feed back the processed results + handle tool call requests - # # + figure out what the next LLM call should be - # tool_call_handler = ToolResponseHandler(current_llm_call.tools) - - # final_search_results, displayed_search_results = SearchTool.get_search_result( - # current_llm_call - # ) or ([], []) - - # # NEXT: we still want to handle the LLM response stream, but it is now: - # # 1. handle the tool call requests - # # 2. feed back the processed results - # # 3. handle the citations - - # answer_handler = CitationResponseHandler( - # context_docs=final_search_results, - # final_doc_id_to_rank_map=map_document_id_order(final_search_results), - # display_doc_id_to_rank_map=map_document_id_order(displayed_search_results), - # ) - - # # At the moment, this wrapper class passes streamed stuff through citation and tool handlers. - # # In the future, we'll want to handle citations and tool calls in the langgraph graph. 
- # response_handler_manager = LLMResponseHandlerManager( - # tool_call_handler, answer_handler, self.is_cancelled - # ) - - # In langgraph, whether we do the basic thing (call llm stream) or pro search - # is based on a flag in the pro search config - - if self.agent_search_config.use_agentic_search: - if ( - self.agent_search_config.db_session is None - and self.agent_search_config.use_persistence - ): - raise ValueError( - "db_session must be provided for pro search when using persistence" - ) - - stream = run_main_graph( - config=self.agent_search_config, - ) - else: - stream = run_basic_graph( - config=self.agent_search_config, - ) + run_langgraph = ( + run_main_graph + if self.agent_search_config.use_agentic_search + else run_basic_graph + ) + stream = run_langgraph( + self.agent_search_config, + ) processed_stream = [] for packet in stream: @@ -244,62 +156,6 @@ class Answer: break processed_stream.append(packet) yield packet - self._processed_stream = processed_stream - return - # DEBUG: good breakpoint - # stream = self.llm.stream( - # # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM - # # may choose to not call any tools and just generate the answer, in which case the task prompt is needed. - # prompt=current_llm_call.prompt_builder.build(), - # tools=[tool.tool_definition() for tool in current_llm_call.tools] or None, - # tool_choice=( - # "required" - # if current_llm_call.tools and current_llm_call.force_use_tool.force_use - # else None - # ), - # structured_response_format=self.answer_style_config.structured_response_format, - # ) - # yield from response_handler_manager.handle_llm_response(stream) - - # new_llm_call = response_handler_manager.next_llm_call(current_llm_call) - # if new_llm_call: - # yield from self._get_response(llm_calls + [new_llm_call]) - - @property - def processed_streamed_output(self) -> AnswerStream: - if self._processed_stream is not None: - yield from self._processed_stream - return - - # prompt_builder = AnswerPromptBuilder( - # user_message=default_build_user_message( - # user_query=self.question, - # prompt_config=self.prompt_config, - # files=self.latest_query_files, - # single_message_history=self.single_message_history, - # ), - # message_history=self.message_history, - # llm_config=self.llm.config, - # raw_user_query=self.question, - # raw_user_uploaded_files=self.latest_query_files or [], - # single_message_history=self.single_message_history, - # ) - # prompt_builder.update_system_prompt( - # default_build_system_message(self.prompt_config) - # ) - # llm_call = LLMCall( - # prompt_builder=prompt_builder, - # tools=self._get_tools_list(), - # force_use_tool=self.force_use_tool, - # files=self.latest_query_files, - # tool_call_info=[], - # using_tool_calling_llm=self.using_tool_calling_llm, - # ) - - processed_stream = [] - for processed_packet in self._get_response(): - processed_stream.append(processed_packet) - yield processed_packet self._processed_stream = processed_stream @@ -343,6 +199,7 @@ class Answer: return citations + # TODO: replace tuple of ints with SubQuestionId EVERYWHERE def citations_by_subquestion(self) -> dict[tuple[int, int], list[CitationInfo]]: citations_by_subquestion: dict[ tuple[int, int], list[CitationInfo] diff --git a/backend/onyx/chat/stream_processing/answer_response_handler.py b/backend/onyx/chat/stream_processing/answer_response_handler.py index 055011ae3..59bfa2c8c 100644 --- a/backend/onyx/chat/stream_processing/answer_response_handler.py +++ 
b/backend/onyx/chat/stream_processing/answer_response_handler.py @@ -1,7 +1,5 @@ import abc from collections.abc import Generator -from typing import Any -from typing import cast from langchain_core.messages import BaseMessage @@ -26,10 +24,6 @@ class AnswerResponseHandler(abc.ABC): ) -> Generator[ResponsePart, None, None]: raise NotImplementedError - @abc.abstractmethod - def update(self, state_update: Any) -> None: - raise NotImplementedError - class PassThroughAnswerResponseHandler(AnswerResponseHandler): def handle_response_part( @@ -40,9 +34,6 @@ class PassThroughAnswerResponseHandler(AnswerResponseHandler): content = _message_to_str(response_item) yield OnyxAnswerPiece(answer_piece=content) - def update(self, state_update: Any) -> None: - pass - class DummyAnswerResponseHandler(AnswerResponseHandler): def handle_response_part( @@ -53,9 +44,6 @@ class DummyAnswerResponseHandler(AnswerResponseHandler): # This is a dummy handler that returns nothing yield from [] - def update(self, state_update: Any) -> None: - pass - class CitationResponseHandler(AnswerResponseHandler): def __init__( @@ -91,20 +79,6 @@ class CitationResponseHandler(AnswerResponseHandler): # Process the new content through the citation processor yield from self.citation_processor.process_token(content) - def update(self, state_update: Any) -> None: - state = cast( - tuple[list[LlmDoc], DocumentIdOrderMapping, DocumentIdOrderMapping], - state_update, - ) - self.context_docs = state[0] - self.final_doc_id_to_rank_map = state[1] - self.display_doc_id_to_rank_map = state[2] - self.citation_processor = CitationProcessor( - context_docs=self.context_docs, - final_doc_id_to_rank_map=self.final_doc_id_to_rank_map, - display_doc_id_to_rank_map=self.display_doc_id_to_rank_map, - ) - def _message_to_str(message: BaseMessage | str | None) -> str: if message is None: @@ -116,80 +90,3 @@ def _message_to_str(message: BaseMessage | str | None) -> str: logger.warning(f"Received non-string content: {type(content)}") content = str(content) if content is not None else "" return content - - -# class CitationMultiResponseHandler(AnswerResponseHandler): -# def __init__(self) -> None: -# self.channel_processors: dict[str, CitationProcessor] = {} -# self._default_channel = "__default__" - -# def register_default_channel( -# self, -# context_docs: list[LlmDoc], -# final_doc_id_to_rank_map: DocumentIdOrderMapping, -# display_doc_id_to_rank_map: DocumentIdOrderMapping, -# ) -> None: -# """Register the default channel with its associated documents and ranking maps.""" -# self.register_channel( -# channel_id=self._default_channel, -# context_docs=context_docs, -# final_doc_id_to_rank_map=final_doc_id_to_rank_map, -# display_doc_id_to_rank_map=display_doc_id_to_rank_map, -# ) - -# def register_channel( -# self, -# channel_id: str, -# context_docs: list[LlmDoc], -# final_doc_id_to_rank_map: DocumentIdOrderMapping, -# display_doc_id_to_rank_map: DocumentIdOrderMapping, -# ) -> None: -# """Register a new channel with its associated documents and ranking maps.""" -# self.channel_processors[channel_id] = CitationProcessor( -# context_docs=context_docs, -# final_doc_id_to_rank_map=final_doc_id_to_rank_map, -# display_doc_id_to_rank_map=display_doc_id_to_rank_map, -# ) - -# def handle_response_part( -# self, -# response_item: BaseMessage | str | None, -# previous_response_items: list[BaseMessage | str], -# ) -> Generator[ResponsePart, None, None]: -# """Default implementation that uses the default channel.""" - -# yield from self.handle_channel_response( -# 
response_item=content, -# previous_response_items=previous_response_items, -# channel_id=self._default_channel, -# ) - -# def handle_channel_response( -# self, -# response_item: ResponsePart | str | None, -# previous_response_items: list[ResponsePart | str], -# channel_id: str, -# ) -> Generator[ResponsePart, None, None]: -# """Process a response part for a specific channel.""" -# if channel_id not in self.channel_processors: -# raise ValueError(f"Attempted to process response for unregistered channel {channel_id}") - -# if response_item is None: -# return - -# content = ( -# response_item.content if isinstance(response_item, BaseMessage) else response_item -# ) - -# # Ensure content is a string -# if not isinstance(content, str): -# logger.warning(f"Received non-string content: {type(content)}") -# content = str(content) if content is not None else "" - -# # Process the new content through the channel's citation processor -# yield from self.channel_processors[channel_id].multi_process_token(content) - -# def remove_channel(self, channel_id: str) -> None: -# """Remove a channel and its associated processor.""" -# if channel_id in self.channel_processors: -# del self.channel_processors[channel_id] diff --git a/backend/onyx/chat/stream_processing/citation_processing.py b/backend/onyx/chat/stream_processing/citation_processing.py index 6f844646a..071b28c34 100644 --- a/backend/onyx/chat/stream_processing/citation_processing.py +++ b/backend/onyx/chat/stream_processing/citation_processing.py @@ -4,7 +4,6 @@ from collections.abc import Generator from onyx.chat.models import CitationInfo from onyx.chat.models import LlmDoc from onyx.chat.models import OnyxAnswerPiece -from onyx.chat.models import ResponsePart from onyx.chat.stream_processing.utils import DocumentIdOrderMapping from onyx.configs.chat_configs import STOP_STREAM_PAT from onyx.prompts.constants import TRIPLE_BACKTICK @@ -41,164 +40,6 @@ class CitationProcessor: self.current_citations: list[int] = [] self.past_cite_count = 0 - # TODO: should reference previous citation processing, rework previous, or completely use new one? 
- def multi_process_token( - self, parsed_object: ResponsePart - ) -> Generator[ResponsePart, None, None]: - # if isinstance(parsed_object,OnyxAnswerPiece): - # # standard citation processing - # yield from self.process_token(parsed_object.answer_piece) - - # elif isinstance(parsed_object, AgentAnswerPiece): - # # citation processing for agent answer pieces - # for token in self.process_token(parsed_object.answer_piece): - # if isinstance(token, CitationInfo): - # yield token - # else: - # yield AgentAnswerPiece(answer_piece=token.answer_piece or '', - # answer_type=parsed_object.answer_type, level=parsed_object.level, - # level_question_nr=parsed_object.level_question_nr) - - # level = getattr(parsed_object, "level", None) - # level_question_nr = getattr(parsed_object, "level_question_nr", None) - - # if isinstance(parsed_object, (AgentAnswerPiece, OnyxAnswerPiece)): - # # logger.debug(f"FA {parsed_object.answer_piece}") - # if isinstance(parsed_object, AgentAnswerPiece): - # token = parsed_object.answer_piece - # level = parsed_object.level - # level_question_nr = parsed_object.level_question_nr - # else: - # yield parsed_object - # return - # # raise ValueError( - # # f"Invalid parsed object type: {type(parsed_object)}" - # # ) - - # if not citation_potential[level][level_question_nr] and token: - # if token.startswith(" ["): - # citation_potential[level][level_question_nr] = True - # current_yield_components[level][level_question_nr] = [token] - # else: - # yield parsed_object - # elif token and citation_potential[level][level_question_nr]: - # current_yield_components[level][level_question_nr].append(token) - # current_yield_str[level][level_question_nr] = "".join( - # current_yield_components[level][level_question_nr] - # ) - - # if current_yield_str[level][level_question_nr].strip().startswith( - # "[D" - # ) or current_yield_str[level][level_question_nr].strip().startswith( - # "[Q" - # ): - # citation_potential[level][level_question_nr] = True - - # else: - # citation_potential[level][level_question_nr] = False - # parsed_object = _set_combined_token_value( - # current_yield_str[level][level_question_nr], parsed_object - # ) - # yield parsed_object - - # if ( - # len(current_yield_components[level][level_question_nr]) > 15 - # ): # ??? 15? - # citation_potential[level][level_question_nr] = False - # parsed_object = _set_combined_token_value( - # current_yield_str[level][level_question_nr], parsed_object - # ) - # yield parsed_object - # elif "]" in current_yield_str[level][level_question_nr]: - # section_split = current_yield_str[level][level_question_nr].split( - # "]" - # ) - # section_split[0] + "]" # dead code? 
- # start_of_next_section = "]".join(section_split[1:]) - # citation_string = current_yield_str[level][level_question_nr][ - # : -len(start_of_next_section) - # ] - # if "[D" in citation_string: - # cite_open_bracket_marker, cite_close_bracket_marker = ( - # "[", - # "]", - # ) - # cite_identifyer = "D" - - # try: - # cited_document = int( - # citation_string[level][level_question_nr][2:-1] - # ) - # if level and level_question_nr: - # link = agent_document_citations[int(level)][ - # int(level_question_nr) - # ][cited_document].link - # else: - # link = "" - # except (ValueError, IndexError): - # link = "" - # elif "[Q" in citation_string: - # cite_open_bracket_marker, cite_close_bracket_marker = ( - # "{", - # "}", - # ) - # cite_identifyer = "Q" - # else: - # pass - - # citation_string = citation_string.replace( - # "[" + cite_identifyer, - # cite_open_bracket_marker * 2, - # ).replace("]", cite_close_bracket_marker * 2) - - # if cite_identifyer == "D": - # citation_string += f"({link})" - - # parsed_object = _set_combined_token_value( - # citation_string, parsed_object - # ) - - # yield parsed_object - - # current_yield_components[level][level_question_nr] = [ - # start_of_next_section - # ] - # if not start_of_next_section.strip().startswith("["): - # citation_potential[level][level_question_nr] = False - - # elif isinstance(parsed_object, ExtendedToolResponse): - # if parsed_object.id == "search_response_summary": - # level = parsed_object.level - # level_question_nr = parsed_object.level_question_nr - # for inference_section in parsed_object.response.top_sections: - # doc_link = inference_section.center_chunk.source_links[0] - # doc_title = inference_section.center_chunk.title - # doc_id = inference_section.center_chunk.document_id - - # if ( - # doc_id - # not in agent_question_citations_used_docs[level][ - # level_question_nr - # ] - # ): - # if level not in agent_document_citations: - # agent_document_citations[level] = {} - # if level_question_nr not in agent_document_citations[level]: - # agent_document_citations[level][level_question_nr] = [] - - # agent_document_citations[level][level_question_nr].append( - # AgentDocumentCitations( - # document_id=doc_id, - # document_title=doc_title, - # link=doc_link, - # ) - # ) - # agent_question_citations_used_docs[level][ - # level_question_nr - # ].append(doc_id) - - yield parsed_object - def process_token( self, token: str | None ) -> Generator[OnyxAnswerPiece | CitationInfo, None, None]: diff --git a/backend/onyx/configs/constants.py b/backend/onyx/configs/constants.py index 523508655..605059081 100644 --- a/backend/onyx/configs/constants.py +++ b/backend/onyx/configs/constants.py @@ -41,6 +41,7 @@ DEFAULT_CC_PAIR_ID = 1 # subquestion level and question number for basic flow BASIC_KEY = (-1, -1) AGENT_SEARCH_INITIAL_KEY = (0, 0) +CANCEL_CHECK_INTERVAL = 20 # Postgres connection constants for application_name POSTGRES_WEB_APP_NAME = "web" POSTGRES_INDEXER_APP_NAME = "indexer" diff --git a/backend/tests/unit/onyx/chat/conftest.py b/backend/tests/unit/onyx/chat/conftest.py index d68627074..69b835c56 100644 --- a/backend/tests/unit/onyx/chat/conftest.py +++ b/backend/tests/unit/onyx/chat/conftest.py @@ -5,7 +5,6 @@ from unittest.mock import MagicMock import pytest from langchain_core.messages import SystemMessage -from onyx.agents.agent_search.models import AgentSearchConfig from onyx.chat.chat_utils import llm_doc_from_inference_section from onyx.chat.models import AnswerStyleConfig from onyx.chat.models import CitationConfig @@ -14,22 
+13,17 @@ from onyx.chat.models import OnyxContext from onyx.chat.models import OnyxContexts from onyx.chat.models import PromptConfig from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder -from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message -from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message from onyx.configs.constants import DocumentSource from onyx.context.search.models import InferenceChunk from onyx.context.search.models import InferenceSection -from onyx.context.search.models import SearchRequest from onyx.llm.interfaces import LLM from onyx.llm.interfaces import LLMConfig -from onyx.tools.force import ForceUseTool from onyx.tools.models import ToolResponse from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.tool_implementations.search_like_tool_utils import ( FINAL_CONTEXT_DOCUMENTS_ID, ) -from onyx.tools.utils import explicit_tool_calling_supported QUERY = "Test question" DEFAULT_SEARCH_ARGS = {"query": "search"} @@ -40,43 +34,6 @@ def answer_style_config() -> AnswerStyleConfig: return AnswerStyleConfig(citation_config=CitationConfig()) -@pytest.fixture -def agent_search_config( - mock_llm: LLM, mock_search_tool: SearchTool, prompt_config: PromptConfig -) -> AgentSearchConfig: - prompt_builder = AnswerPromptBuilder( - user_message=default_build_user_message( - user_query=QUERY, - prompt_config=prompt_config, - files=[], - single_message_history=None, - ), - message_history=[], - llm_config=mock_llm.config, - raw_user_query=QUERY, - raw_user_uploaded_files=[], - single_message_history=None, - ) - prompt_builder.update_system_prompt(default_build_system_message(prompt_config)) - using_tool_calling_llm = explicit_tool_calling_supported( - mock_llm.config.model_provider, mock_llm.config.model_name - ) - return AgentSearchConfig( - search_request=SearchRequest(query=QUERY), - primary_llm=mock_llm, - fast_llm=mock_llm, - search_tool=mock_search_tool, - force_use_tool=ForceUseTool(force_use=False, tool_name=""), - prompt_builder=prompt_builder, - chat_session_id=None, - message_id=1, - use_persistence=True, - db_session=None, - use_agentic_search=False, - using_tool_calling_llm=using_tool_calling_llm, - ) - - @pytest.fixture def prompt_config() -> PromptConfig: return PromptConfig( @@ -89,7 +46,7 @@ def prompt_config() -> PromptConfig: @pytest.fixture def mock_llm() -> MagicMock: - mock_llm_obj = MagicMock() + mock_llm_obj = MagicMock(spec=LLM) mock_llm_obj.config = LLMConfig( model_provider="openai", model_name="gpt-4o", diff --git a/backend/tests/unit/onyx/chat/test_answer.py b/backend/tests/unit/onyx/chat/test_answer.py index cb80be9e1..3c725fa27 100644 --- a/backend/tests/unit/onyx/chat/test_answer.py +++ b/backend/tests/unit/onyx/chat/test_answer.py @@ -65,6 +65,7 @@ def answer_instance( search_request=SearchRequest(query=QUERY), chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"), current_agent_message_id=0, + use_agentic_persistence=False, ) diff --git a/backend/tests/unit/onyx/chat/test_skip_gen_ai.py b/backend/tests/unit/onyx/chat/test_skip_gen_ai.py index 3aa937565..e00a61893 100644 --- a/backend/tests/unit/onyx/chat/test_skip_gen_ai.py +++ b/backend/tests/unit/onyx/chat/test_skip_gen_ai.py @@ -11,6 +11,7 @@ from onyx.chat.models import AnswerStyleConfig from onyx.chat.models import PromptConfig from onyx.chat.prompt_builder.answer_prompt_builder import 
AnswerPromptBuilder from onyx.context.search.models import SearchRequest +from onyx.llm.interfaces import LLM from onyx.tools.force import ForceUseTool from onyx.tools.tool_implementations.search.search_tool import SearchTool from tests.regression.answer_quality.run_qa import _process_and_write_query_results @@ -38,7 +39,7 @@ def test_skip_gen_ai_answer_generation_flag( question = config["question"] skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"] - mock_llm = Mock() + mock_llm = Mock(spec=LLM) mock_llm.config = Mock() mock_llm.config.model_name = "gpt-4o-mini" mock_llm.stream = Mock() @@ -66,6 +67,7 @@ def test_skip_gen_ai_answer_generation_flag( ), chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"), current_agent_message_id=0, + use_agentic_persistence=False, ) results = list(answer.processed_streamed_output) for res in results:
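
The rename from use_persistence to use_agentic_persistence keeps the existing model_validator invariant: persistence requires a db_session. That is why the touched unit tests now pass use_agentic_persistence=False rather than wiring up a session. Below is a minimal, self-contained sketch of that invariant; the class name and the Any-typed stand-in for the SQLAlchemy Session are illustrative assumptions, not code from this patch.

from typing import Any
from pydantic import BaseModel, model_validator


class PersistenceConfigSketch(BaseModel):
    # Hypothetical, reduced stand-in for AgentSearchConfig: only the two fields
    # involved in the persistence invariant are modeled here.
    use_agentic_persistence: bool = True
    db_session: Any = None  # the real field is a SQLAlchemy Session | None

    @model_validator(mode="after")
    def validate_db_session(self) -> "PersistenceConfigSketch":
        # Mirrors the validator retained by the patch: persistence needs a DB session.
        if self.use_agentic_persistence and self.db_session is None:
            raise ValueError("db_session must be provided when using persistence")
        return self


# Tests that opt out of persistence (as test_answer.py and test_skip_gen_ai.py now do)
# can construct the config without a database session:
PersistenceConfigSketch(use_agentic_persistence=False)  # validates fine
# PersistenceConfigSketch()  # would raise ValueError: persistence defaults to True, no db_session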