diff --git a/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py b/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py
index 1269b3dd4..7a7c8ffc2 100644
--- a/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py
+++ b/backend/onyx/agents/agent_search/deep_search/initial/generate_initial_answer/nodes/generate_initial_answer.py
@@ -153,8 +153,9 @@ def generate_initial_answer(
         )
         for tool_response in yield_search_responses(
             query=question,
-            reranked_sections=answer_generation_documents.streaming_documents,
-            final_context_sections=answer_generation_documents.context_documents,
+            get_retrieved_sections=lambda: answer_generation_documents.context_documents,
+            get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
+            get_final_context_sections=lambda: answer_generation_documents.context_documents,
             search_query_info=query_info,
             get_section_relevance=lambda: relevance_list,
             search_tool=graph_config.tooling.search_tool,
diff --git a/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_validate_refined_answer.py b/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_validate_refined_answer.py
index 9782c1340..b17c39a6d 100644
--- a/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_validate_refined_answer.py
+++ b/backend/onyx/agents/agent_search/deep_search/main/nodes/generate_validate_refined_answer.py
@@ -179,8 +179,9 @@ def generate_validate_refined_answer(
         )
         for tool_response in yield_search_responses(
             query=question,
-            reranked_sections=answer_generation_documents.streaming_documents,
-            final_context_sections=answer_generation_documents.context_documents,
+            get_retrieved_sections=lambda: answer_generation_documents.context_documents,
+            get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
+            get_final_context_sections=lambda: answer_generation_documents.context_documents,
             search_query_info=query_info,
             get_section_relevance=lambda: relevance_list,
             search_tool=graph_config.tooling.search_tool,
diff --git a/backend/onyx/agents/agent_search/deep_search/main/operations.py b/backend/onyx/agents/agent_search/deep_search/main/operations.py
index 152581e10..46d41d477 100644
--- a/backend/onyx/agents/agent_search/deep_search/main/operations.py
+++ b/backend/onyx/agents/agent_search/deep_search/main/operations.py
@@ -13,7 +13,6 @@ from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
 from onyx.chat.models import StreamType
 from onyx.chat.models import SubQuestionPiece
-from onyx.context.search.models import IndexFilters
 from onyx.tools.models import SearchQueryInfo
 from onyx.utils.logger import setup_logger
 
@@ -144,8 +143,6 @@ def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
         if result.query_info is not None:
             query_info = result.query_info
             break
-    return query_info or SearchQueryInfo(
-        predicted_search=None,
-        final_filters=IndexFilters(access_control_list=None),
-        recency_bias_multiplier=1.0,
-    )
+
+    assert query_info is not None, "must have query info"
+    return query_info
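Note on the `yield_search_responses` call sites above (the matching signature change appears later in this diff): the section lists are now passed as zero-argument callables, so the underlying `SearchPipeline` properties are evaluated only if and when a consumer invokes them. A minimal sketch of the idea, with hypothetical names:

```python
from collections.abc import Callable


class Pipeline:
    @property
    def sections(self) -> list[str]:
        print("expensive retrieval runs now")
        return ["a", "b"]


def consume(get_sections: Callable[[], list[str]]) -> None:
    # Nothing is computed until the callable is invoked here.
    for section in get_sections():
        print(section)


pipeline = Pipeline()
# Passing `pipeline.sections` directly would trigger retrieval immediately;
# wrapping the access in a lambda defers it until `consume` needs the data.
consume(lambda: pipeline.sections)
```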
diff --git a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/format_results.py b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/format_results.py
index 5683f4c70..272f02e4a 100644
--- a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/format_results.py
+++ b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/format_results.py
@@ -56,8 +56,9 @@ def format_results(
     relevance_list = relevance_from_docs(reranked_documents)
     for tool_response in yield_search_responses(
         query=state.question,
-        reranked_sections=state.retrieved_documents,
-        final_context_sections=reranked_documents,
+        get_retrieved_sections=lambda: reranked_documents,
+        get_reranked_sections=lambda: state.retrieved_documents,
+        get_final_context_sections=lambda: reranked_documents,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,
         search_tool=graph_config.tooling.search_tool,
diff --git a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py
index 380761451..8de1eeae9 100644
--- a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py
+++ b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/retrieve_documents.py
@@ -91,7 +91,7 @@ def retrieve_documents(
     retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
 
     if AGENT_RETRIEVAL_STATS:
-        pre_rerank_docs = callback_container[0]
+        pre_rerank_docs = callback_container[0] if callback_container else []
         fit_scores = get_fit_scores(
             pre_rerank_docs,
             retrieved_docs,
diff --git a/backend/onyx/agents/agent_search/orchestration/nodes/call_tool.py b/backend/onyx/agents/agent_search/orchestration/nodes/call_tool.py
index 5265d5a61..8b596e778 100644
--- a/backend/onyx/agents/agent_search/orchestration/nodes/call_tool.py
+++ b/backend/onyx/agents/agent_search/orchestration/nodes/call_tool.py
@@ -44,7 +44,9 @@ def call_tool(
     tool = tool_choice.tool
     tool_args = tool_choice.tool_args
    tool_id = tool_choice.id
-    tool_runner = ToolRunner(tool, tool_args)
+    tool_runner = ToolRunner(
+        tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
+    )
     tool_kickoff = tool_runner.kickoff()
 
     emit_packet(tool_kickoff, writer)
diff --git a/backend/onyx/agents/agent_search/orchestration/nodes/choose_tool.py b/backend/onyx/agents/agent_search/orchestration/nodes/choose_tool.py
index 59094d12e..1562d5d72 100644
--- a/backend/onyx/agents/agent_search/orchestration/nodes/choose_tool.py
+++ b/backend/onyx/agents/agent_search/orchestration/nodes/choose_tool.py
@@ -15,8 +15,17 @@ from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
 from onyx.chat.tool_handling.tool_response_handler import (
     get_tool_call_for_non_tool_calling_llm_impl,
 )
+from onyx.context.search.preprocessing.preprocessing import query_analysis
+from onyx.context.search.retrieval.search_runner import get_query_embedding
+from onyx.tools.models import SearchToolOverrideKwargs
 from onyx.tools.tool import Tool
+from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.utils.logger import setup_logger
+from onyx.utils.threadpool_concurrency import run_in_background
+from onyx.utils.threadpool_concurrency import TimeoutThread
+from onyx.utils.threadpool_concurrency import wait_on_background
+from onyx.utils.timing import log_function_time
+from shared_configs.model_server_models import Embedding
 
 
 logger = setup_logger()
 
@@ -25,6 +34,7 @@ logger = setup_logger()
 # and a function that handles extracting the necessary fields
 # from the state and config
 # TODO: fan-out to multiple tool call nodes? Make this configurable?
+@log_function_time(print_only=True)
 def choose_tool(
     state: ToolChoiceState,
     config: RunnableConfig,
@@ -37,6 +47,31 @@ def choose_tool(
     should_stream_answer = state.should_stream_answer
 
     agent_config = cast(GraphConfig, config["metadata"]["config"])
+
+    force_use_tool = agent_config.tooling.force_use_tool
+
+    embedding_thread: TimeoutThread[Embedding] | None = None
+    keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
+    override_kwargs: SearchToolOverrideKwargs | None = None
+    if (
+        not agent_config.behavior.use_agentic_search
+        and agent_config.tooling.search_tool is not None
+        and (
+            not force_use_tool.force_use or force_use_tool.tool_name == SearchTool._NAME
+        )
+    ):
+        override_kwargs = SearchToolOverrideKwargs()
+        # Run in a background thread to avoid blocking the main thread
+        embedding_thread = run_in_background(
+            get_query_embedding,
+            agent_config.inputs.search_request.query,
+            agent_config.persistence.db_session,
+        )
+        keyword_thread = run_in_background(
+            query_analysis,
+            agent_config.inputs.search_request.query,
+        )
+
     using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
     prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
 
@@ -47,7 +82,6 @@
     tools = [
         tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
     ]
-    force_use_tool = agent_config.tooling.force_use_tool
 
     tool, tool_args = None, None
     if force_use_tool.force_use and force_use_tool.args is not None:
@@ -71,11 +105,22 @@
     # If we have a tool and tool args, we are ready to request a tool call.
     # This only happens if the tool call was forced or we are using a non-tool calling LLM.
     if tool and tool_args:
+        if embedding_thread and tool.name == SearchTool._NAME:
+            # Wait for the embedding thread to finish
+            embedding = wait_on_background(embedding_thread)
+            assert override_kwargs is not None, "must have override kwargs"
+            override_kwargs.precomputed_query_embedding = embedding
+        if keyword_thread and tool.name == SearchTool._NAME:
+            is_keyword, keywords = wait_on_background(keyword_thread)
+            assert override_kwargs is not None, "must have override kwargs"
+            override_kwargs.precomputed_is_keyword = is_keyword
+            override_kwargs.precomputed_keywords = keywords
         return ToolChoiceUpdate(
             tool_choice=ToolChoice(
                 tool=tool,
                 tool_args=tool_args,
                 id=str(uuid4()),
+                search_tool_override_kwargs=override_kwargs,
             ),
         )
 
@@ -153,10 +198,22 @@
     logger.debug(f"Selected tool: {selected_tool.name}")
     logger.debug(f"Selected tool call request: {selected_tool_call_request}")
 
+    if embedding_thread and selected_tool.name == SearchTool._NAME:
+        # Wait for the embedding thread to finish
+        embedding = wait_on_background(embedding_thread)
+        assert override_kwargs is not None, "must have override kwargs"
+        override_kwargs.precomputed_query_embedding = embedding
+    if keyword_thread and selected_tool.name == SearchTool._NAME:
+        is_keyword, keywords = wait_on_background(keyword_thread)
+        assert override_kwargs is not None, "must have override kwargs"
+        override_kwargs.precomputed_is_keyword = is_keyword
+        override_kwargs.precomputed_keywords = keywords
+
     return ToolChoiceUpdate(
         tool_choice=ToolChoice(
             tool=selected_tool,
             tool_args=selected_tool_call_request["args"],
             id=selected_tool_call_request["id"],
+            search_tool_override_kwargs=override_kwargs,
         ),
     )
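The scheduling pattern introduced in `choose_tool` above: kick off the I/O-bound precomputations (query embedding, query analysis) on background threads as soon as the node starts, let the LLM-bound tool-choice work proceed, and join the threads only if the search tool is actually selected. A stripped-down sketch of the same shape, using hypothetical stand-in functions rather than the real Onyx calls:

```python
import time

from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import wait_on_background


def embed(query: str) -> list[float]:
    time.sleep(0.5)  # stand-in for a model-server round trip
    return [0.1, 0.2]


def choose(query: str) -> str:
    # Start the embedding work immediately...
    embedding_thread = run_in_background(embed, query)

    # ...while the (slow) tool-selection logic runs on the main thread.
    tool = "search"

    if tool == "search":
        # Join only once we know the result is needed.
        embedding = wait_on_background(embedding_thread)
        return f"search with a {len(embedding)}-dim embedding"
    return tool
```

If a non-search tool wins, the background threads are simply never joined and their results are discarded, which is the trade-off this optimization accepts.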
diff --git a/backend/onyx/agents/agent_search/orchestration/nodes/use_tool_response.py b/backend/onyx/agents/agent_search/orchestration/nodes/use_tool_response.py
index 6874aae97..4bec51bc5 100644
--- a/backend/onyx/agents/agent_search/orchestration/nodes/use_tool_response.py
+++ b/backend/onyx/agents/agent_search/orchestration/nodes/use_tool_response.py
@@ -9,18 +9,23 @@ from onyx.agents.agent_search.basic.states import BasicState
 from onyx.agents.agent_search.basic.utils import process_llm_stream
 from onyx.agents.agent_search.models import GraphConfig
 from onyx.chat.models import LlmDoc
-from onyx.chat.models import OnyxContexts
 from onyx.tools.tool_implementations.search.search_tool import (
-    SEARCH_DOC_CONTENT_ID,
+    SEARCH_RESPONSE_SUMMARY_ID,
+)
+from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
+from onyx.tools.tool_implementations.search.search_utils import (
+    context_from_inference_section,
 )
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )
 from onyx.utils.logger import setup_logger
+from onyx.utils.timing import log_function_time
 
 logger = setup_logger()
 
 
+@log_function_time(print_only=True)
 def basic_use_tool_response(
     state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
 ) -> BasicOutput:
@@ -50,11 +55,13 @@ def basic_use_tool_response(
     for yield_item in tool_call_responses:
         if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
             final_search_results = cast(list[LlmDoc], yield_item.response)
-        elif yield_item.id == SEARCH_DOC_CONTENT_ID:
-            search_contexts = cast(OnyxContexts, yield_item.response).contexts
-            for doc in search_contexts:
-                if doc.document_id not in initial_search_results:
-                    initial_search_results.append(doc)
+        elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
+            search_response_summary = cast(SearchResponseSummary, yield_item.response)
+            for section in search_response_summary.top_sections:
+                if section.center_chunk.document_id not in initial_search_results:
+                    initial_search_results.append(
+                        context_from_inference_section(section)
+                    )
 
     new_tool_call_chunk = AIMessageChunk(content="")
     if not agent_config.behavior.skip_gen_ai_answer_generation:
diff --git a/backend/onyx/agents/agent_search/orchestration/states.py b/backend/onyx/agents/agent_search/orchestration/states.py
index 266e71cf2..917f60921 100644
--- a/backend/onyx/agents/agent_search/orchestration/states.py
+++ b/backend/onyx/agents/agent_search/orchestration/states.py
@@ -2,6 +2,7 @@ from pydantic import BaseModel
 
 from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
 from onyx.tools.message import ToolCallSummary
+from onyx.tools.models import SearchToolOverrideKwargs
 from onyx.tools.models import ToolCallFinalResult
 from onyx.tools.models import ToolCallKickoff
 from onyx.tools.models import ToolResponse
@@ -35,6 +36,7 @@ class ToolChoice(BaseModel):
     tool: Tool
     tool_args: dict
     id: str | None
+    search_tool_override_kwargs: SearchToolOverrideKwargs | None = None
 
     class Config:
         arbitrary_types_allowed = True
diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/constants.py b/backend/onyx/agents/agent_search/shared_graph_utils/constants.py
index 79ebcf338..ca7828cd8 100644
--- a/backend/onyx/agents/agent_search/shared_graph_utils/constants.py
+++ b/backend/onyx/agents/agent_search/shared_graph_utils/constants.py
@@ -13,6 +13,11 @@ AGENT_NEGATIVE_VALUE_STR = "no"
 
 AGENT_ANSWER_SEPARATOR = "Answer:"
 
+EMBEDDING_KEY = "embedding"
+IS_KEYWORD_KEY = "is_keyword"
+KEYWORDS_KEY = "keywords"
+
+
 class AgentLLMErrorType(str, Enum):
     TIMEOUT = "timeout"
     RATE_LIMIT = "rate_limit"
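With the `states.py` change above plus the earlier `call_tool.py` change, the precomputed values ride along on `ToolChoice.search_tool_override_kwargs` and reach the tool at execution time. Roughly, as a sketch of the data flow (field values here are purely illustrative):

```python
from onyx.tools.models import SearchToolOverrideKwargs

# Produced in choose_tool once the background threads are joined:
override_kwargs = SearchToolOverrideKwargs(
    precomputed_query_embedding=[0.1, 0.2],  # assuming Embedding is a list of floats
    precomputed_is_keyword=False,
    precomputed_keywords=["onyx"],
)

# call_tool then constructs
#     ToolRunner(tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs)
# and ToolRunner forwards it as tool.run(override_kwargs=..., **tool_args).
```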
diff --git a/backend/onyx/chat/llm_response_handler.py b/backend/onyx/chat/llm_response_handler.py
index 2bf3e8476..37a30b433 100644
--- a/backend/onyx/chat/llm_response_handler.py
+++ b/backend/onyx/chat/llm_response_handler.py
@@ -15,6 +15,8 @@ from onyx.chat.stream_processing.answer_response_handler import (
 from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
 
 
+# This is legacy code that is no longer used.
+# It is kept here for reference.
 class LLMResponseHandlerManager:
     """
     This class is responsible for postprocessing the LLM response stream.
diff --git a/backend/onyx/chat/stream_processing/citation_processing.py b/backend/onyx/chat/stream_processing/citation_processing.py
index 071b28c34..342472f5f 100644
--- a/backend/onyx/chat/stream_processing/citation_processing.py
+++ b/backend/onyx/chat/stream_processing/citation_processing.py
@@ -90,97 +90,97 @@ class CitationProcessor:
                     next(group for group in citation.groups() if group is not None)
                 )
 
-                if 1 <= numerical_value <= self.max_citation_num:
-                    context_llm_doc = self.context_docs[numerical_value - 1]
-                    final_citation_num = self.final_order_mapping[
+                if not (1 <= numerical_value <= self.max_citation_num):
+                    continue
+
+                context_llm_doc = self.context_docs[numerical_value - 1]
+                final_citation_num = self.final_order_mapping[
+                    context_llm_doc.document_id
+                ]
+
+                if final_citation_num not in self.citation_order:
+                    self.citation_order.append(final_citation_num)
+
+                citation_order_idx = self.citation_order.index(final_citation_num) + 1
+
+                # get the value that was displayed to user, should always
+                # be in the display_doc_order_dict. But check anyways
+                if context_llm_doc.document_id in self.display_order_mapping:
+                    displayed_citation_num = self.display_order_mapping[
                         context_llm_doc.document_id
                     ]
-
-                    if final_citation_num not in self.citation_order:
-                        self.citation_order.append(final_citation_num)
-
-                    citation_order_idx = (
-                        self.citation_order.index(final_citation_num) + 1
+                else:
+                    displayed_citation_num = final_citation_num
+                    logger.warning(
+                        f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
                     )
 
-                    # get the value that was displayed to user, should always
-                    # be in the display_doc_order_dict. But check anyways
-                    if context_llm_doc.document_id in self.display_order_mapping:
-                        displayed_citation_num = self.display_order_mapping[
-                            context_llm_doc.document_id
-                        ]
-                    else:
-                        displayed_citation_num = final_citation_num
-                        logger.warning(
-                            f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
-                        )
-
-                    # Skip consecutive citations of the same work
-                    if final_citation_num in self.current_citations:
-                        start, end = citation.span()
-                        real_start = length_to_add + start
-                        diff = end - start
-                        self.curr_segment = (
-                            self.curr_segment[: length_to_add + start]
-                            + self.curr_segment[real_start + diff :]
-                        )
-                        length_to_add -= diff
-                        continue
-
-                    # Handle edge case where LLM outputs citation itself
-                    if self.curr_segment.startswith("[["):
-                        match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
-                        if match:
-                            try:
-                                doc_id = int(match.group(1))
-                                context_llm_doc = self.context_docs[doc_id - 1]
-                                yield CitationInfo(
-                                    # citation_num is now the number post initial ranking, i.e. as displayed to user
-                                    citation_num=displayed_citation_num,
-                                    document_id=context_llm_doc.document_id,
-                                )
-                            except Exception as e:
-                                logger.warning(
-                                    f"Manual LLM citation didn't properly cite documents {e}"
-                                )
-                        else:
-                            logger.warning(
-                                "Manual LLM citation wasn't able to close brackets"
-                            )
-                        continue
-
-                    link = context_llm_doc.link
-
-                    self.past_cite_count = len(self.llm_out)
-                    self.current_citations.append(final_citation_num)
-
-                    if citation_order_idx not in self.cited_inds:
-                        self.cited_inds.add(citation_order_idx)
-                        yield CitationInfo(
-                            # citation number is now the one that was displayed to user
-                            citation_num=displayed_citation_num,
-                            document_id=context_llm_doc.document_id,
-                        )
-
+                # Skip consecutive citations of the same work
+                if final_citation_num in self.current_citations:
                     start, end = citation.span()
-                    if link:
-                        prev_length = len(self.curr_segment)
-                        self.curr_segment = (
-                            self.curr_segment[: start + length_to_add]
-                            + f"[[{displayed_citation_num}]]({link})"  # use the value that was displayed to user
-                            + self.curr_segment[end + length_to_add :]
-                        )
-                        length_to_add += len(self.curr_segment) - prev_length
-                    else:
-                        prev_length = len(self.curr_segment)
-                        self.curr_segment = (
-                            self.curr_segment[: start + length_to_add]
-                            + f"[[{displayed_citation_num}]]()"  # use the value that was displayed to user
-                            + self.curr_segment[end + length_to_add :]
-                        )
-                        length_to_add += len(self.curr_segment) - prev_length
+                    real_start = length_to_add + start
+                    diff = end - start
+                    self.curr_segment = (
+                        self.curr_segment[: length_to_add + start]
+                        + self.curr_segment[real_start + diff :]
+                    )
+                    length_to_add -= diff
+                    continue
 
-                    last_citation_end = end + length_to_add
+                # Handle edge case where LLM outputs citation itself
+                if self.curr_segment.startswith("[["):
+                    match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
+                    if match:
+                        try:
+                            doc_id = int(match.group(1))
+                            context_llm_doc = self.context_docs[doc_id - 1]
+                            yield CitationInfo(
+                                # citation_num is now the number post initial ranking, i.e. as displayed to user
+                                citation_num=displayed_citation_num,
+                                document_id=context_llm_doc.document_id,
+                            )
+                        except Exception as e:
+                            logger.warning(
+                                f"Manual LLM citation didn't properly cite documents {e}"
+                            )
+                    else:
+                        logger.warning(
+                            "Manual LLM citation wasn't able to close brackets"
+                        )
+                    continue
+
+                link = context_llm_doc.link
+
+                self.past_cite_count = len(self.llm_out)
+                self.current_citations.append(final_citation_num)
+
+                if citation_order_idx not in self.cited_inds:
+                    self.cited_inds.add(citation_order_idx)
+                    yield CitationInfo(
+                        # citation number is now the one that was displayed to user
+                        citation_num=displayed_citation_num,
+                        document_id=context_llm_doc.document_id,
+                    )
+
+                start, end = citation.span()
+                if link:
+                    prev_length = len(self.curr_segment)
+                    self.curr_segment = (
+                        self.curr_segment[: start + length_to_add]
+                        + f"[[{displayed_citation_num}]]({link})"  # use the value that was displayed to user
+                        + self.curr_segment[end + length_to_add :]
+                    )
+                    length_to_add += len(self.curr_segment) - prev_length
+                else:
+                    prev_length = len(self.curr_segment)
+                    self.curr_segment = (
+                        self.curr_segment[: start + length_to_add]
+                        + f"[[{displayed_citation_num}]]()"  # use the value that was displayed to user
+                        + self.curr_segment[end + length_to_add :]
+                    )
+                    length_to_add += len(self.curr_segment) - prev_length
+
+                last_citation_end = end + length_to_add
 
         if last_citation_end > 0:
             result += self.curr_segment[:last_citation_end]
diff --git a/backend/onyx/context/search/models.py b/backend/onyx/context/search/models.py
index 3d19db186..980ed9644 100644
--- a/backend/onyx/context/search/models.py
+++ b/backend/onyx/context/search/models.py
@@ -16,7 +16,7 @@ from onyx.db.models import SearchSettings
 from onyx.indexing.models import BaseChunk
 from onyx.indexing.models import IndexingSetting
 from shared_configs.enums import RerankerProvider
-
+from shared_configs.model_server_models import Embedding
 
 MAX_METRICS_CONTENT = (
     200  # Just need enough characters to identify where in the doc the chunk is
@@ -151,6 +151,10 @@ class SearchRequest(ChunkContext):
     evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
+    precomputed_query_embedding: Embedding | None = None
+    precomputed_is_keyword: bool | None = None
+    precomputed_keywords: list[str] | None = None
+
 
 class SearchQuery(ChunkContext):
     "Processed Request that is directly passed to the SearchPipeline"
@@ -175,6 +179,8 @@ class SearchQuery(ChunkContext):
     offset: int = 0
     model_config = ConfigDict(frozen=True)
 
+    precomputed_query_embedding: Embedding | None = None
+
 
 class RetrievalDetails(ChunkContext):
     # Use LLM to determine whether to do a retrieval or only rely on existing history
diff --git a/backend/onyx/context/search/pipeline.py b/backend/onyx/context/search/pipeline.py
index faf7a8988..c810b7089 100644
--- a/backend/onyx/context/search/pipeline.py
+++ b/backend/onyx/context/search/pipeline.py
@@ -331,6 +331,14 @@ class SearchPipeline:
         self._retrieved_sections = expanded_inference_sections
         return expanded_inference_sections
 
+    @property
+    def retrieved_sections(self) -> list[InferenceSection]:
+        if self._retrieved_sections is not None:
+            return self._retrieved_sections
+
+        self._retrieved_sections = self._get_sections()
+        return self._retrieved_sections
+
     @property
     def reranked_sections(self) -> list[InferenceSection]:
         """Reranking is always done at the chunk level since section merging could create arbitrarily
@@ -343,7 +351,7 @@
         if self._reranked_sections is not None:
             return self._reranked_sections
 
-        retrieved_sections = self._get_sections()
+        retrieved_sections = self.retrieved_sections
         if self.retrieved_sections_callback is not None:
             self.retrieved_sections_callback(retrieved_sections)
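The new `retrieved_sections` property is a small memoization wrapper, so both `reranked_sections` and outside callers (such as the lazy callables handed to `yield_search_responses`) can read the retrieval result without re-running `_get_sections()`. The pattern in isolation, as a self-contained sketch:

```python
class PipelineSketch:
    """Minimal sketch of the cached-property pattern used in SearchPipeline."""

    def __init__(self) -> None:
        self._retrieved_sections: list[str] | None = None

    def _get_sections(self) -> list[str]:
        print("running retrieval (only once)")
        return ["section-1", "section-2"]

    @property
    def retrieved_sections(self) -> list[str]:
        if self._retrieved_sections is not None:
            return self._retrieved_sections

        self._retrieved_sections = self._get_sections()
        return self._retrieved_sections


pipeline = PipelineSketch()
assert pipeline.retrieved_sections is pipeline.retrieved_sections  # computed once
```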
diff --git a/backend/onyx/context/search/preprocessing/preprocessing.py b/backend/onyx/context/search/preprocessing/preprocessing.py
index d18ddd32b..814579e58 100644
--- a/backend/onyx/context/search/preprocessing/preprocessing.py
+++ b/backend/onyx/context/search/preprocessing/preprocessing.py
@@ -117,8 +117,12 @@
         else None
     )
 
+    # Sometimes this is pre-computed in parallel with other heavy tasks to improve
+    # latency, and in that case we don't need to run the model again
     run_query_analysis = (
-        None if skip_query_analysis else FunctionCall(query_analysis, (query,), {})
+        None
+        if (skip_query_analysis or search_request.precomputed_is_keyword is not None)
+        else FunctionCall(query_analysis, (query,), {})
     )
 
     functions_to_run = [
@@ -143,11 +147,12 @@
 
     # The extracted keywords right now are not very reliable, not using for now
     # Can maybe use for highlighting
-    is_keyword, extracted_keywords = (
-        parallel_results[run_query_analysis.result_id]
-        if run_query_analysis
-        else (False, None)
-    )
+    is_keyword, _extracted_keywords = False, None
+    if search_request.precomputed_is_keyword is not None:
+        is_keyword = search_request.precomputed_is_keyword
+        _extracted_keywords = search_request.precomputed_keywords
+    elif run_query_analysis:
+        is_keyword, _extracted_keywords = parallel_results[run_query_analysis.result_id]
 
     all_query_terms = query.split()
     processed_keywords = (
@@ -247,4 +252,5 @@
         chunks_above=chunks_above,
         chunks_below=chunks_below,
         full_doc=search_request.full_doc,
+        precomputed_query_embedding=search_request.precomputed_query_embedding,
     )
diff --git a/backend/onyx/context/search/retrieval/search_runner.py b/backend/onyx/context/search/retrieval/search_runner.py
index 64491a20a..6c77167ad 100644
--- a/backend/onyx/context/search/retrieval/search_runner.py
+++ b/backend/onyx/context/search/retrieval/search_runner.py
@@ -31,7 +31,7 @@ from onyx.utils.timing import log_function_time
 from shared_configs.configs import MODEL_SERVER_HOST
 from shared_configs.configs import MODEL_SERVER_PORT
 from shared_configs.enums import EmbedTextType
-
+from shared_configs.model_server_models import Embedding
 
 logger = setup_logger()
 
@@ -109,6 +109,20 @@ def combine_retrieval_results(
     return sorted_chunks
 
 
+def get_query_embedding(query: str, db_session: Session) -> Embedding:
+    search_settings = get_current_search_settings(db_session)
+
+    model = EmbeddingModel.from_db_model(
+        search_settings=search_settings,
+        # The below are globally set, this flow always uses the indexing one
+        server_host=MODEL_SERVER_HOST,
+        server_port=MODEL_SERVER_PORT,
+    )
+
+    query_embedding = model.encode([query], text_type=EmbedTextType.QUERY)[0]
+    return query_embedding
+
+
 @log_function_time(print_only=True)
 def doc_index_retrieval(
     query: SearchQuery,
@@ -121,17 +135,10 @@
     from the large chunks to the referenced chunks,
     dedupes the chunks, and cleans the chunks.
     """
-    search_settings = get_current_search_settings(db_session)
-
-    model = EmbeddingModel.from_db_model(
-        search_settings=search_settings,
-        # The below are globally set, this flow always uses the indexing one
-        server_host=MODEL_SERVER_HOST,
-        server_port=MODEL_SERVER_PORT,
+    query_embedding = query.precomputed_query_embedding or get_query_embedding(
+        query.query, db_session
     )
 
-    query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]
-
     top_chunks = document_index.hybrid_retrieval(
         query=query.query,
         query_embedding=query_embedding,
@@ -250,6 +257,9 @@ def retrieve_chunks(
             simplified_queries.add(simplified_rephrase)
 
             q_copy = query.copy(update={"query": rephrase}, deep=True)
+            q_copy.precomputed_query_embedding = (
+                None  # need to recompute for each rephrase
+            )
             run_queries.append(
                 (
                     doc_index_retrieval,
diff --git a/backend/onyx/tools/models.py b/backend/onyx/tools/models.py
index 1e343e74c..c26e0b942 100644
--- a/backend/onyx/tools/models.py
+++ b/backend/onyx/tools/models.py
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
 from onyx.context.search.enums import SearchType
 from onyx.context.search.models import IndexFilters
 from onyx.context.search.models import InferenceSection
+from shared_configs.model_server_models import Embedding
 
 
 class ToolResponse(BaseModel):
@@ -60,11 +61,15 @@ class SearchQueryInfo(BaseModel):
     recency_bias_multiplier: float
 
 
+# None indicates that the default value should be used
 class SearchToolOverrideKwargs(BaseModel):
-    force_no_rerank: bool
-    alternate_db_session: Session | None
-    retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
-    skip_query_analysis: bool
+    force_no_rerank: bool | None = None
+    alternate_db_session: Session | None = None
+    retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None = None
+    skip_query_analysis: bool | None = None
+    precomputed_query_embedding: Embedding | None = None
+    precomputed_is_keyword: bool | None = None
+    precomputed_keywords: list[str] | None = None
 
     class Config:
         arbitrary_types_allowed = True
diff --git a/backend/onyx/tools/tool_implementations/search/search_tool.py b/backend/onyx/tools/tool_implementations/search/search_tool.py
index 4b556e471..2ca5d8bb6 100644
--- a/backend/onyx/tools/tool_implementations/search/search_tool.py
+++ b/backend/onyx/tools/tool_implementations/search/search_tool.py
@@ -3,6 +3,7 @@ from collections.abc import Callable
 from collections.abc import Generator
 from typing import Any
 from typing import cast
+from typing import TypeVar
 
 from sqlalchemy.orm import Session
@@ -11,7 +12,6 @@ from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import ContextualPruningConfig
 from onyx.chat.models import DocumentPruningConfig
 from onyx.chat.models import LlmDoc
-from onyx.chat.models import OnyxContext
 from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PromptConfig
 from onyx.chat.models import SectionRelevancePiece
@@ -42,6 +42,9 @@ from onyx.tools.models import SearchQueryInfo
 from onyx.tools.models import SearchToolOverrideKwargs
 from onyx.tools.models import ToolResponse
 from onyx.tools.tool import Tool
+from onyx.tools.tool_implementations.search.search_utils import (
+    context_from_inference_section,
+)
 from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     build_next_prompt_for_search_like_tool,
@@ -281,16 +284,23 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
         self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
     ) -> Generator[ToolResponse, None, None]:
         query = cast(str, llm_kwargs[QUERY_FIELD])
+        precomputed_query_embedding = None
+        precomputed_is_keyword = None
+        precomputed_keywords = None
         force_no_rerank = False
         alternate_db_session = None
         retrieved_sections_callback = None
         skip_query_analysis = False
         if override_kwargs:
-            force_no_rerank = override_kwargs.force_no_rerank
+            force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
             alternate_db_session = override_kwargs.alternate_db_session
             retrieved_sections_callback = override_kwargs.retrieved_sections_callback
-            skip_query_analysis = override_kwargs.skip_query_analysis
-
+            skip_query_analysis = use_alt_not_None(
+                override_kwargs.skip_query_analysis, False
+            )
+            precomputed_query_embedding = override_kwargs.precomputed_query_embedding
+            precomputed_is_keyword = override_kwargs.precomputed_is_keyword
+            precomputed_keywords = override_kwargs.precomputed_keywords
         if self.selected_sections:
             yield from self._build_response_for_specified_sections(query)
             return
@@ -327,6 +337,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
                     if self.retrieval_options
                     else None
                 ),
+                precomputed_query_embedding=precomputed_query_embedding,
+                precomputed_is_keyword=precomputed_is_keyword,
+                precomputed_keywords=precomputed_keywords,
             ),
             user=self.user,
             llm=self.llm,
@@ -345,8 +358,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
         )
         yield from yield_search_responses(
             query,
-            search_pipeline.reranked_sections,
-            search_pipeline.final_context_sections,
+            lambda: search_pipeline.retrieved_sections,
+            lambda: search_pipeline.reranked_sections,
+            lambda: search_pipeline.final_context_sections,
             search_query_info,
             lambda: search_pipeline.section_relevance,
             self,
@@ -383,10 +397,16 @@
 # SearchTool passed in to allow for access to SearchTool properties.
 # We can't just call SearchTool methods in the graph because we're operating on
 # the retrieved docs (reranking, deduping, etc.) after the SearchTool has run.
+#
+# The various inference sections are passed in as functions to allow for lazy
+# evaluation. The SearchPipeline object properties that they correspond to are
+# actually functions defined with @property decorators, and passing them into
+# this function causes them to get evaluated immediately which is undesirable.
 def yield_search_responses(
     query: str,
-    reranked_sections: list[InferenceSection],
-    final_context_sections: list[InferenceSection],
+    get_retrieved_sections: Callable[[], list[InferenceSection]],
+    get_reranked_sections: Callable[[], list[InferenceSection]],
+    get_final_context_sections: Callable[[], list[InferenceSection]],
     search_query_info: SearchQueryInfo,
     get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
     search_tool: SearchTool,
@@ -395,7 +415,7 @@
         id=SEARCH_RESPONSE_SUMMARY_ID,
         response=SearchResponseSummary(
             rephrased_query=query,
-            top_sections=final_context_sections,
+            top_sections=get_retrieved_sections(),
             predicted_flow=QueryFlow.QUESTION_ANSWER,
             predicted_search=search_query_info.predicted_search,
             final_filters=search_query_info.final_filters,
@@ -407,13 +427,8 @@
         id=SEARCH_DOC_CONTENT_ID,
         response=OnyxContexts(
             contexts=[
-                OnyxContext(
-                    content=section.combined_content,
-                    document_id=section.center_chunk.document_id,
-                    semantic_identifier=section.center_chunk.semantic_identifier,
-                    blurb=section.center_chunk.blurb,
-                )
-                for section in reranked_sections
+                context_from_inference_section(section)
+                for section in get_reranked_sections()
             ]
         ),
     )
@@ -424,6 +439,7 @@
             response=section_relevance,
         )
 
+    final_context_sections = get_final_context_sections()
     pruned_sections = prune_sections(
         sections=final_context_sections,
         section_relevance_list=section_relevance_list_impl(
@@ -438,3 +454,10 @@
 
     llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]
     yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
+
+
+T = TypeVar("T")
+
+
+def use_alt_not_None(value: T | None, alt: T) -> T:
+    return value if value is not None else alt
diff --git a/backend/onyx/tools/tool_implementations/search/search_utils.py b/backend/onyx/tools/tool_implementations/search/search_utils.py
index dd44ca033..7b6c6383e 100644
--- a/backend/onyx/tools/tool_implementations/search/search_utils.py
+++ b/backend/onyx/tools/tool_implementations/search/search_utils.py
@@ -1,4 +1,5 @@
 from onyx.chat.models import LlmDoc
+from onyx.chat.models import OnyxContext
 from onyx.context.search.models import InferenceSection
 from onyx.prompts.prompt_utils import clean_up_source
@@ -29,3 +30,12 @@ def section_to_dict(section: InferenceSection, section_num: int) -> dict:
         "%B %d, %Y %H:%M"
     )
     return doc_dict
+
+
+def context_from_inference_section(section: InferenceSection) -> OnyxContext:
+    return OnyxContext(
+        content=section.combined_content,
+        document_id=section.center_chunk.document_id,
+        semantic_identifier=section.center_chunk.semantic_identifier,
+        blurb=section.center_chunk.blurb,
+    )
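Making every field of `SearchToolOverrideKwargs` optional with a `None` default lets callers set only the overrides they care about, and `use_alt_not_None` (added at the bottom of `search_tool.py` above) collapses `None` back to the tool's own default at the point of use. Note that a plain `value or alt` would be wrong here, since an explicit `False` override must win over a truthy default:

```python
from typing import TypeVar

T = TypeVar("T")


def use_alt_not_None(value: T | None, alt: T) -> T:
    # Same helper as in search_tool.py: None means "no override supplied".
    return value if value is not None else alt


assert use_alt_not_None(None, False) is False  # caller did not override
assert use_alt_not_None(True, False) is True  # explicit override wins
assert use_alt_not_None(False, True) is False  # explicit False is respected
```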
diff --git a/backend/onyx/tools/tool_runner.py b/backend/onyx/tools/tool_runner.py
index af124da4c..c5c2a73b4 100644
--- a/backend/onyx/tools/tool_runner.py
+++ b/backend/onyx/tools/tool_runner.py
@@ -1,6 +1,8 @@
 from collections.abc import Callable
 from collections.abc import Generator
 from typing import Any
+from typing import Generic
+from typing import TypeVar
 
 from onyx.llm.interfaces import LLM
 from onyx.llm.models import PreviousMessage
@@ -11,10 +13,16 @@ from onyx.tools.tool import Tool
 from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
 
 
-class ToolRunner:
-    def __init__(self, tool: Tool, args: dict[str, Any]):
+R = TypeVar("R")
+
+
+class ToolRunner(Generic[R]):
+    def __init__(
+        self, tool: Tool[R], args: dict[str, Any], override_kwargs: R | None = None
+    ):
         self.tool = tool
         self.args = args
+        self.override_kwargs = override_kwargs
 
         self._tool_responses: list[ToolResponse] | None = None
 
@@ -27,7 +35,9 @@
             return
 
         tool_responses: list[ToolResponse] = []
-        for tool_response in self.tool.run(**self.args):
+        for tool_response in self.tool.run(
+            override_kwargs=self.override_kwargs, **self.args
+        ):
             yield tool_response
             tool_responses.append(tool_response)
 
diff --git a/backend/onyx/utils/threadpool_concurrency.py b/backend/onyx/utils/threadpool_concurrency.py
index 4ef87348f..fd8b70174 100644
--- a/backend/onyx/utils/threadpool_concurrency.py
+++ b/backend/onyx/utils/threadpool_concurrency.py
@@ -118,7 +118,7 @@ def run_functions_in_parallel(
     return results
 
 
-class TimeoutThread(threading.Thread):
+class TimeoutThread(threading.Thread, Generic[R]):
     def __init__(
         self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
     ):
@@ -159,3 +159,34 @@ def run_with_timeout(
     task.end()
 
     return task.result
+
+
+# NOTE: this function should really only be used when run_functions_tuples_in_parallel is
+# difficult to use. It's up to the programmer to call wait_on_background on the thread after
+# the code you want to run in parallel is finished. As with all Python thread parallelism,
+# this is only useful for I/O bound tasks.
+def run_in_background(
+    func: Callable[..., R], *args: Any, **kwargs: Any
+) -> TimeoutThread[R]:
+    """
+    Runs a function in a background thread. Returns a TimeoutThread object that can be used
+    to wait for the function to finish with wait_on_background.
+    """
+    context = contextvars.copy_context()
+    # Timeout not used in the non-blocking case
+    task = TimeoutThread(-1, context.run, func, *args, **kwargs)
+    task.start()
+    return task
+
+
+def wait_on_background(task: TimeoutThread[R]) -> R:
+    """
+    Used in conjunction with run_in_background. Blocks until the task is finished,
+    then returns the result of the task.
+    """
+    task.join()
+
+    if task.exception is not None:
+        raise task.exception
+
+    return task.result
diff --git a/backend/tests/unit/onyx/utils/test_threadpool_concurrency.py b/backend/tests/unit/onyx/utils/test_threadpool_concurrency.py
index 74399e4d3..8b9505bbc 100644
--- a/backend/tests/unit/onyx/utils/test_threadpool_concurrency.py
+++ b/backend/tests/unit/onyx/utils/test_threadpool_concurrency.py
@@ -1,8 +1,14 @@
+import contextvars
 import time
 
 import pytest
 
+from onyx.utils.threadpool_concurrency import run_in_background
 from onyx.utils.threadpool_concurrency import run_with_timeout
+from onyx.utils.threadpool_concurrency import wait_on_background
+
+# Create a context variable for testing
+test_context_var = contextvars.ContextVar("test_var", default="default")
 
 
 def test_run_with_timeout_completes() -> None:
@@ -59,3 +65,86 @@ def test_run_with_timeout_with_args_and_kwargs() -> None:
     # Test with positional and keyword args
     result2 = run_with_timeout(1.0, complex_function, x=5, y=3, multiply=True)
     assert result2 == 15
+
+
+def test_run_in_background_and_wait_success() -> None:
+    """Test that run_in_background and wait_on_background work correctly for successful execution"""
+
+    def background_function(x: int) -> int:
+        time.sleep(0.1)  # Small delay to ensure it's actually running in background
+        return x * 2
+
+    # Start the background task
+    task = run_in_background(background_function, 21)
+
+    # Verify we can do other work while task is running
+    start_time = time.time()
+    result = wait_on_background(task)
+    elapsed = time.time() - start_time
+
+    assert result == 42
+    assert elapsed >= 0.1  # Verify we actually waited for the sleep
+
+
+@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
+def test_run_in_background_propagates_exceptions() -> None:
+    """Test that exceptions in background tasks are properly propagated"""
+
+    def error_function() -> None:
+        time.sleep(0.1)  # Small delay to ensure it's actually running in background
+        raise ValueError("Test background error")
+
+    task = run_in_background(error_function)
+
+    with pytest.raises(ValueError) as exc_info:
+        wait_on_background(task)
+
+    assert "Test background error" in str(exc_info.value)
+
+
+def test_run_in_background_with_args_and_kwargs() -> None:
+    """Test that args and kwargs are properly passed to the background function"""
+
+    def complex_function(x: int, y: int, multiply: bool = False) -> int:
+        time.sleep(0.1)  # Small delay to ensure it's actually running in background
+        if multiply:
+            return x * y
+        return x + y
+
+    # Test with args
+    task1 = run_in_background(complex_function, 5, 3)
+    result1 = wait_on_background(task1)
+    assert result1 == 8
+
+    # Test with args and kwargs
+    task2 = run_in_background(complex_function, 5, 3, multiply=True)
+    result2 = wait_on_background(task2)
+    assert result2 == 15
+
+
+def test_multiple_background_tasks() -> None:
+    """Test running multiple background tasks concurrently"""
+
+    def slow_add(x: int, y: int) -> int:
+        time.sleep(0.2)  # Make each task take some time
+        return x + y
+
+    # Start multiple tasks
+    start_time = time.time()
+    task1 = run_in_background(slow_add, 1, 2)
+    task2 = run_in_background(slow_add, 3, 4)
+    task3 = run_in_background(slow_add, 5, 6)
+
+    # Wait for all results
+    result1 = wait_on_background(task1)
+    result2 = wait_on_background(task2)
+    result3 = wait_on_background(task3)
+    elapsed = time.time() - start_time
+
+    # Verify results
+    assert result1 == 3
+    assert result2 == 7
+    assert result3 == 11
+
+    # Verify tasks ran in parallel (total time should be ~0.2s, not ~0.6s)
+    assert 0.2 <= elapsed < 0.4  # Allow some buffer for test environment variations
diff --git a/backend/tests/unit/onyx/utils/test_threadpool_contextvars.py b/backend/tests/unit/onyx/utils/test_threadpool_contextvars.py
index 4d6d9a6a3..ab92b4e55 100644
--- a/backend/tests/unit/onyx/utils/test_threadpool_contextvars.py
+++ b/backend/tests/unit/onyx/utils/test_threadpool_contextvars.py
@@ -4,7 +4,9 @@ import time
 from onyx.utils.threadpool_concurrency import FunctionCall
 from onyx.utils.threadpool_concurrency import run_functions_in_parallel
 from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
+from onyx.utils.threadpool_concurrency import run_in_background
 from onyx.utils.threadpool_concurrency import run_with_timeout
+from onyx.utils.threadpool_concurrency import wait_on_background
 
 # Create a test contextvar
 test_var = contextvars.ContextVar("test_var", default="default")
@@ -129,3 +131,39 @@ def test_contextvar_isolation_between_runs() -> None:
 
     # Verify second run results
     assert all(result in ["thread3", "thread4"] for result in second_results)
+
+
+def test_run_in_background_preserves_contextvar() -> None:
+    """Test that run_in_background preserves contextvar values and modifications are isolated"""
+
+    def modify_and_sleep() -> tuple[str, str]:
+        """Modifies the contextvar, sleeps, and returns the original and final values"""
+        original = test_var.get()
+        test_var.set("modified_in_background")
+        time.sleep(0.1)  # Ensure we can check main thread during execution
+        final = test_var.get()
+        return original, final
+
+    # Set initial value in main thread
+    token = test_var.set("initial_value")
+    try:
+        # Start background task
+        task = run_in_background(modify_and_sleep)
+
+        # Verify main thread value remains unchanged while task runs
+        assert test_var.get() == "initial_value"
+
+        # Get results from background thread
+        original, modified = wait_on_background(task)
+
+        # Verify the background thread:
+        # 1. Saw the initial value
+        assert original == "initial_value"
+        # 2. Successfully modified its own copy
+        assert modified == "modified_in_background"
+
+        # Verify main thread value is still unchanged after task completion
+        assert test_var.get() == "initial_value"
+    finally:
+        # Clean up
+        test_var.reset(token)