improved basic search latency (#4186)

* improved basic search latency

* address PR comments + minor cleanup
This commit is contained in:
evan-danswer 2025-03-06 14:22:59 -08:00 committed by GitHub
parent 29382656fc
commit b7da91e3ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 460 additions and 149 deletions

View File

@ -153,8 +153,9 @@ def generate_initial_answer(
)
for tool_response in yield_search_responses(
query=question,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,

View File

@ -179,8 +179,9 @@ def generate_validate_refined_answer(
)
for tool_response in yield_search_responses(
query=question,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,

View File

@ -13,7 +13,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import IndexFilters
from onyx.tools.models import SearchQueryInfo
from onyx.utils.logger import setup_logger
@ -144,8 +143,6 @@ def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
if result.query_info is not None:
query_info = result.query_info
break
return query_info or SearchQueryInfo(
predicted_search=None,
final_filters=IndexFilters(access_control_list=None),
recency_bias_multiplier=1.0,
)
assert query_info is not None, "must have query info"
return query_info

View File

@ -56,8 +56,9 @@ def format_results(
relevance_list = relevance_from_docs(reranked_documents)
for tool_response in yield_search_responses(
query=state.question,
reranked_sections=state.retrieved_documents,
final_context_sections=reranked_documents,
get_retrieved_sections=lambda: reranked_documents,
get_reranked_sections=lambda: state.retrieved_documents,
get_final_context_sections=lambda: reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,

View File

@ -91,7 +91,7 @@ def retrieve_documents(
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
if AGENT_RETRIEVAL_STATS:
pre_rerank_docs = callback_container[0]
pre_rerank_docs = callback_container[0] if callback_container else []
fit_scores = get_fit_scores(
pre_rerank_docs,
retrieved_docs,

View File

@ -44,7 +44,9 @@ def call_tool(
tool = tool_choice.tool
tool_args = tool_choice.tool_args
tool_id = tool_choice.id
tool_runner = ToolRunner(tool, tool_args)
tool_runner = ToolRunner(
tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
)
tool_kickoff = tool_runner.kickoff()
emit_packet(tool_kickoff, writer)

View File

@ -15,8 +15,17 @@ from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import TimeoutThread
from onyx.utils.threadpool_concurrency import wait_on_background
from onyx.utils.timing import log_function_time
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@ -25,6 +34,7 @@ logger = setup_logger()
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
@log_function_time(print_only=True)
def choose_tool(
state: ToolChoiceState,
config: RunnableConfig,
@ -37,6 +47,31 @@ def choose_tool(
should_stream_answer = state.should_stream_answer
agent_config = cast(GraphConfig, config["metadata"]["config"])
force_use_tool = agent_config.tooling.force_use_tool
embedding_thread: TimeoutThread[Embedding] | None = None
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
override_kwargs: SearchToolOverrideKwargs | None = None
if (
not agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and (
not force_use_tool.force_use or force_use_tool.tool_name == SearchTool.name
)
):
override_kwargs = SearchToolOverrideKwargs()
# Run in a background thread to avoid blocking the main thread
embedding_thread = run_in_background(
get_query_embedding,
agent_config.inputs.search_request.query,
agent_config.persistence.db_session,
)
keyword_thread = run_in_background(
query_analysis,
agent_config.inputs.search_request.query,
)
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
@ -47,7 +82,6 @@ def choose_tool(
tools = [
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
]
force_use_tool = agent_config.tooling.force_use_tool
tool, tool_args = None, None
if force_use_tool.force_use and force_use_tool.args is not None:
@ -71,11 +105,22 @@ def choose_tool(
# If we have a tool and tool args, we are ready to request a tool call.
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
if tool and tool_args:
if embedding_thread and tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=tool,
tool_args=tool_args,
id=str(uuid4()),
search_tool_override_kwargs=override_kwargs,
),
)
@ -153,10 +198,22 @@ def choose_tool(
logger.debug(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
if embedding_thread and selected_tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and selected_tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=selected_tool,
tool_args=selected_tool_call_request["args"],
id=selected_tool_call_request["id"],
search_tool_override_kwargs=override_kwargs,
),
)

View File

@ -9,18 +9,23 @@ from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContexts
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_DOC_CONTENT_ID,
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
@log_function_time(print_only=True)
def basic_use_tool_response(
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BasicOutput:
@ -50,11 +55,13 @@ def basic_use_tool_response(
for yield_item in tool_call_responses:
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_search_results = cast(list[LlmDoc], yield_item.response)
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
search_contexts = cast(OnyxContexts, yield_item.response).contexts
for doc in search_contexts:
if doc.document_id not in initial_search_results:
initial_search_results.append(doc)
elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
search_response_summary = cast(SearchResponseSummary, yield_item.response)
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_results:
initial_search_results.append(
context_from_inference_section(section)
)
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

View File

@ -2,6 +2,7 @@ from pydantic import BaseModel
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
@ -35,6 +36,7 @@ class ToolChoice(BaseModel):
tool: Tool
tool_args: dict
id: str | None
search_tool_override_kwargs: SearchToolOverrideKwargs | None = None
class Config:
arbitrary_types_allowed = True

View File

@ -13,6 +13,11 @@ AGENT_NEGATIVE_VALUE_STR = "no"
AGENT_ANSWER_SEPARATOR = "Answer:"
EMBEDDING_KEY = "embedding"
IS_KEYWORD_KEY = "is_keyword"
KEYWORDS_KEY = "keywords"
class AgentLLMErrorType(str, Enum):
TIMEOUT = "timeout"
RATE_LIMIT = "rate_limit"

View File

@ -15,6 +15,8 @@ from onyx.chat.stream_processing.answer_response_handler import (
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
# This is Legacy code that is not used anymore.
# It is kept here for reference.
class LLMResponseHandlerManager:
"""
This class is responsible for postprocessing the LLM response stream.

View File

@ -90,97 +90,97 @@ class CitationProcessor:
next(group for group in citation.groups() if group is not None)
)
if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
final_citation_num = self.final_order_mapping[
if not (1 <= numerical_value <= self.max_citation_num):
continue
context_llm_doc = self.context_docs[numerical_value - 1]
final_citation_num = self.final_order_mapping[
context_llm_doc.document_id
]
if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)
citation_order_idx = self.citation_order.index(final_citation_num) + 1
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
context_llm_doc.document_id
]
if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)
citation_order_idx = (
self.citation_order.index(final_citation_num) + 1
else:
displayed_citation_num = final_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
context_llm_doc.document_id
]
else:
displayed_citation_num = final_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)
# Skip consecutive citations of the same work
if final_citation_num in self.current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
)
length_to_add -= diff
continue
# Handle edge case where LLM outputs citation itself
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
except Exception as e:
logger.warning(
f"Manual LLM citation didn't properly cite documents {e}"
)
else:
logger.warning(
"Manual LLM citation wasn't able to close brackets"
)
continue
link = context_llm_doc.link
self.past_cite_count = len(self.llm_out)
self.current_citations.append(final_citation_num)
if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
yield CitationInfo(
# citation number is now the one that was displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
# Skip consecutive citations of the same work
if final_citation_num in self.current_citations:
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
else:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
real_start = length_to_add + start
diff = end - start
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
)
length_to_add -= diff
continue
last_citation_end = end + length_to_add
# Handle edge case where LLM outputs citation itself
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
except Exception as e:
logger.warning(
f"Manual LLM citation didn't properly cite documents {e}"
)
else:
logger.warning(
"Manual LLM citation wasn't able to close brackets"
)
continue
link = context_llm_doc.link
self.past_cite_count = len(self.llm_out)
self.current_citations.append(final_citation_num)
if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
yield CitationInfo(
# citation number is now the one that was displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
else:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
last_citation_end = end + length_to_add
if last_citation_end > 0:
result += self.curr_segment[:last_citation_end]

View File

@ -16,7 +16,7 @@ from onyx.db.models import SearchSettings
from onyx.indexing.models import BaseChunk
from onyx.indexing.models import IndexingSetting
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
MAX_METRICS_CONTENT = (
200 # Just need enough characters to identify where in the doc the chunk is
@ -151,6 +151,10 @@ class SearchRequest(ChunkContext):
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
model_config = ConfigDict(arbitrary_types_allowed=True)
precomputed_query_embedding: Embedding | None = None
precomputed_is_keyword: bool | None = None
precomputed_keywords: list[str] | None = None
class SearchQuery(ChunkContext):
"Processed Request that is directly passed to the SearchPipeline"
@ -175,6 +179,8 @@ class SearchQuery(ChunkContext):
offset: int = 0
model_config = ConfigDict(frozen=True)
precomputed_query_embedding: Embedding | None = None
class RetrievalDetails(ChunkContext):
# Use LLM to determine whether to do a retrieval or only rely on existing history

View File

@ -331,6 +331,14 @@ class SearchPipeline:
self._retrieved_sections = expanded_inference_sections
return expanded_inference_sections
@property
def retrieved_sections(self) -> list[InferenceSection]:
if self._retrieved_sections is not None:
return self._retrieved_sections
self._retrieved_sections = self._get_sections()
return self._retrieved_sections
@property
def reranked_sections(self) -> list[InferenceSection]:
"""Reranking is always done at the chunk level since section merging could create arbitrarily
@ -343,7 +351,7 @@ class SearchPipeline:
if self._reranked_sections is not None:
return self._reranked_sections
retrieved_sections = self._get_sections()
retrieved_sections = self.retrieved_sections
if self.retrieved_sections_callback is not None:
self.retrieved_sections_callback(retrieved_sections)

View File

@ -117,8 +117,12 @@ def retrieval_preprocessing(
else None
)
# Sometimes this is pre-computed in parallel with other heavy tasks to improve
# latency, and in that case we don't need to run the model again
run_query_analysis = (
None if skip_query_analysis else FunctionCall(query_analysis, (query,), {})
None
if (skip_query_analysis or search_request.precomputed_is_keyword is not None)
else FunctionCall(query_analysis, (query,), {})
)
functions_to_run = [
@ -143,11 +147,12 @@ def retrieval_preprocessing(
# The extracted keywords right now are not very reliable, not using for now
# Can maybe use for highlighting
is_keyword, extracted_keywords = (
parallel_results[run_query_analysis.result_id]
if run_query_analysis
else (False, None)
)
is_keyword, _extracted_keywords = False, None
if search_request.precomputed_is_keyword is not None:
is_keyword = search_request.precomputed_is_keyword
_extracted_keywords = search_request.precomputed_keywords
elif run_query_analysis:
is_keyword, _extracted_keywords = parallel_results[run_query_analysis.result_id]
all_query_terms = query.split()
processed_keywords = (
@ -247,4 +252,5 @@ def retrieval_preprocessing(
chunks_above=chunks_above,
chunks_below=chunks_below,
full_doc=search_request.full_doc,
precomputed_query_embedding=search_request.precomputed_query_embedding,
)

View File

@ -31,7 +31,7 @@ from onyx.utils.timing import log_function_time
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@ -109,6 +109,20 @@ def combine_retrieval_results(
return sorted_chunks
def get_query_embedding(query: str, db_session: Session) -> Embedding:
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode([query], text_type=EmbedTextType.QUERY)[0]
return query_embedding
@log_function_time(print_only=True)
def doc_index_retrieval(
query: SearchQuery,
@ -121,17 +135,10 @@ def doc_index_retrieval(
from the large chunks to the referenced chunks,
dedupes the chunks, and cleans the chunks.
"""
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
query_embedding = query.precomputed_query_embedding or get_query_embedding(
query.query, db_session
)
query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]
top_chunks = document_index.hybrid_retrieval(
query=query.query,
query_embedding=query_embedding,
@ -250,6 +257,9 @@ def retrieve_chunks(
simplified_queries.add(simplified_rephrase)
q_copy = query.copy(update={"query": rephrase}, deep=True)
q_copy.precomputed_query_embedding = (
None # need to recompute for each rephrase
)
run_queries.append(
(
doc_index_retrieval,

View File

@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceSection
from shared_configs.model_server_models import Embedding
class ToolResponse(BaseModel):
@ -60,11 +61,15 @@ class SearchQueryInfo(BaseModel):
recency_bias_multiplier: float
# None indicates that the default value should be used
class SearchToolOverrideKwargs(BaseModel):
force_no_rerank: bool
alternate_db_session: Session | None
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
skip_query_analysis: bool
force_no_rerank: bool | None = None
alternate_db_session: Session | None = None
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None = None
skip_query_analysis: bool | None = None
precomputed_query_embedding: Embedding | None = None
precomputed_is_keyword: bool | None = None
precomputed_keywords: list[str] | None = None
class Config:
arbitrary_types_allowed = True

View File

@ -3,6 +3,7 @@ from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import cast
from typing import TypeVar
from sqlalchemy.orm import Session
@ -11,7 +12,6 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ContextualPruningConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
@ -42,6 +42,9 @@ from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
from onyx.tools.tool_implementations.search_like_tool_utils import (
build_next_prompt_for_search_like_tool,
@ -281,16 +284,23 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
) -> Generator[ToolResponse, None, None]:
query = cast(str, llm_kwargs[QUERY_FIELD])
precomputed_query_embedding = None
precomputed_is_keyword = None
precomputed_keywords = None
force_no_rerank = False
alternate_db_session = None
retrieved_sections_callback = None
skip_query_analysis = False
if override_kwargs:
force_no_rerank = override_kwargs.force_no_rerank
force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
alternate_db_session = override_kwargs.alternate_db_session
retrieved_sections_callback = override_kwargs.retrieved_sections_callback
skip_query_analysis = override_kwargs.skip_query_analysis
skip_query_analysis = use_alt_not_None(
override_kwargs.skip_query_analysis, False
)
precomputed_query_embedding = override_kwargs.precomputed_query_embedding
precomputed_is_keyword = override_kwargs.precomputed_is_keyword
precomputed_keywords = override_kwargs.precomputed_keywords
if self.selected_sections:
yield from self._build_response_for_specified_sections(query)
return
@ -327,6 +337,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
if self.retrieval_options
else None
),
precomputed_query_embedding=precomputed_query_embedding,
precomputed_is_keyword=precomputed_is_keyword,
precomputed_keywords=precomputed_keywords,
),
user=self.user,
llm=self.llm,
@ -345,8 +358,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
)
yield from yield_search_responses(
query,
search_pipeline.reranked_sections,
search_pipeline.final_context_sections,
lambda: search_pipeline.retrieved_sections,
lambda: search_pipeline.reranked_sections,
lambda: search_pipeline.final_context_sections,
search_query_info,
lambda: search_pipeline.section_relevance,
self,
@ -383,10 +397,16 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
# SearchTool passed in to allow for access to SearchTool properties.
# We can't just call SearchTool methods in the graph because we're operating on
# the retrieved docs (reranking, deduping, etc.) after the SearchTool has run.
#
# The various inference sections are passed in as functions to allow for lazy
# evaluation. The SearchPipeline object properties that they correspond to are
# actually functions defined with @property decorators, and passing them into
# this function causes them to get evaluated immediately which is undesirable.
def yield_search_responses(
query: str,
reranked_sections: list[InferenceSection],
final_context_sections: list[InferenceSection],
get_retrieved_sections: Callable[[], list[InferenceSection]],
get_reranked_sections: Callable[[], list[InferenceSection]],
get_final_context_sections: Callable[[], list[InferenceSection]],
search_query_info: SearchQueryInfo,
get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
search_tool: SearchTool,
@ -395,7 +415,7 @@ def yield_search_responses(
id=SEARCH_RESPONSE_SUMMARY_ID,
response=SearchResponseSummary(
rephrased_query=query,
top_sections=final_context_sections,
top_sections=get_retrieved_sections(),
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=search_query_info.predicted_search,
final_filters=search_query_info.final_filters,
@ -407,13 +427,8 @@ def yield_search_responses(
id=SEARCH_DOC_CONTENT_ID,
response=OnyxContexts(
contexts=[
OnyxContext(
content=section.combined_content,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
blurb=section.center_chunk.blurb,
)
for section in reranked_sections
context_from_inference_section(section)
for section in get_reranked_sections()
]
),
)
@ -424,6 +439,7 @@ def yield_search_responses(
response=section_relevance,
)
final_context_sections = get_final_context_sections()
pruned_sections = prune_sections(
sections=final_context_sections,
section_relevance_list=section_relevance_list_impl(
@ -438,3 +454,10 @@ def yield_search_responses(
llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
T = TypeVar("T")
def use_alt_not_None(value: T | None, alt: T) -> T:
return value if value is not None else alt

View File

@ -1,4 +1,5 @@
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceSection
from onyx.prompts.prompt_utils import clean_up_source
@ -29,3 +30,12 @@ def section_to_dict(section: InferenceSection, section_num: int) -> dict:
"%B %d, %Y %H:%M"
)
return doc_dict
def context_from_inference_section(section: InferenceSection) -> OnyxContext:
return OnyxContext(
content=section.combined_content,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
blurb=section.center_chunk.blurb,
)

View File

@ -1,6 +1,8 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import Generic
from typing import TypeVar
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
@ -11,10 +13,16 @@ from onyx.tools.tool import Tool
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
class ToolRunner:
def __init__(self, tool: Tool, args: dict[str, Any]):
R = TypeVar("R")
class ToolRunner(Generic[R]):
def __init__(
self, tool: Tool[R], args: dict[str, Any], override_kwargs: R | None = None
):
self.tool = tool
self.args = args
self.override_kwargs = override_kwargs
self._tool_responses: list[ToolResponse] | None = None
@ -27,7 +35,9 @@ class ToolRunner:
return
tool_responses: list[ToolResponse] = []
for tool_response in self.tool.run(**self.args):
for tool_response in self.tool.run(
override_kwargs=self.override_kwargs, **self.args
):
yield tool_response
tool_responses.append(tool_response)

View File

@ -118,7 +118,7 @@ def run_functions_in_parallel(
return results
class TimeoutThread(threading.Thread):
class TimeoutThread(threading.Thread, Generic[R]):
def __init__(
self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
):
@ -159,3 +159,34 @@ def run_with_timeout(
task.end()
return task.result
# NOTE: this function should really only be used when run_functions_tuples_in_parallel is
# difficult to use. It's up to the programmer to call wait_on_background on the thread after
# the code you want to run in parallel is finished. As with all python thread parallelism,
# this is only useful for I/O bound tasks.
def run_in_background(
func: Callable[..., R], *args: Any, **kwargs: Any
) -> TimeoutThread[R]:
"""
Runs a function in a background thread. Returns a TimeoutThread object that can be used
to wait for the function to finish with wait_on_background.
"""
context = contextvars.copy_context()
# Timeout not used in the non-blocking case
task = TimeoutThread(-1, context.run, func, *args, **kwargs)
task.start()
return task
def wait_on_background(task: TimeoutThread[R]) -> R:
"""
Used in conjunction with run_in_background. blocks until the task is finished,
then returns the result of the task.
"""
task.join()
if task.exception is not None:
raise task.exception
return task.result

View File

@ -1,8 +1,14 @@
import contextvars
import time
import pytest
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.threadpool_concurrency import wait_on_background
# Create a context variable for testing
test_context_var = contextvars.ContextVar("test_var", default="default")
def test_run_with_timeout_completes() -> None:
@ -59,3 +65,86 @@ def test_run_with_timeout_with_args_and_kwargs() -> None:
# Test with positional and keyword args
result2 = run_with_timeout(1.0, complex_function, x=5, y=3, multiply=True)
assert result2 == 15
def test_run_in_background_and_wait_success() -> None:
"""Test that run_in_background and wait_on_background work correctly for successful execution"""
def background_function(x: int) -> int:
time.sleep(0.1) # Small delay to ensure it's actually running in background
return x * 2
# Start the background task
task = run_in_background(background_function, 21)
# Verify we can do other work while task is running
start_time = time.time()
result = wait_on_background(task)
elapsed = time.time() - start_time
assert result == 42
assert elapsed >= 0.1 # Verify we actually waited for the sleep
@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
def test_run_in_background_propagates_exceptions() -> None:
"""Test that exceptions in background tasks are properly propagated"""
def error_function() -> None:
time.sleep(0.1) # Small delay to ensure it's actually running in background
raise ValueError("Test background error")
task = run_in_background(error_function)
with pytest.raises(ValueError) as exc_info:
wait_on_background(task)
assert "Test background error" in str(exc_info.value)
def test_run_in_background_with_args_and_kwargs() -> None:
"""Test that args and kwargs are properly passed to the background function"""
def complex_function(x: int, y: int, multiply: bool = False) -> int:
time.sleep(0.1) # Small delay to ensure it's actually running in background
if multiply:
return x * y
return x + y
# Test with args
task1 = run_in_background(complex_function, 5, 3)
result1 = wait_on_background(task1)
assert result1 == 8
# Test with args and kwargs
task2 = run_in_background(complex_function, 5, 3, multiply=True)
result2 = wait_on_background(task2)
assert result2 == 15
def test_multiple_background_tasks() -> None:
"""Test running multiple background tasks concurrently"""
def slow_add(x: int, y: int) -> int:
time.sleep(0.2) # Make each task take some time
return x + y
# Start multiple tasks
start_time = time.time()
task1 = run_in_background(slow_add, 1, 2)
task2 = run_in_background(slow_add, 3, 4)
task3 = run_in_background(slow_add, 5, 6)
# Wait for all results
result1 = wait_on_background(task1)
result2 = wait_on_background(task2)
result3 = wait_on_background(task3)
elapsed = time.time() - start_time
# Verify results
assert result1 == 3
assert result2 == 7
assert result3 == 11
# Verify tasks ran in parallel (total time should be ~0.2s, not ~0.6s)
assert 0.2 <= elapsed < 0.4 # Allow some buffer for test environment variations

View File

@ -4,7 +4,9 @@ import time
from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.threadpool_concurrency import wait_on_background
# Create a test contextvar
test_var = contextvars.ContextVar("test_var", default="default")
@ -129,3 +131,39 @@ def test_contextvar_isolation_between_runs() -> None:
# Verify second run results
assert all(result in ["thread3", "thread4"] for result in second_results)
def test_run_in_background_preserves_contextvar() -> None:
"""Test that run_in_background preserves contextvar values and modifications are isolated"""
def modify_and_sleep() -> tuple[str, str]:
"""Modifies contextvar, sleeps, and returns original, modified, and final values"""
original = test_var.get()
test_var.set("modified_in_background")
time.sleep(0.1) # Ensure we can check main thread during execution
final = test_var.get()
return original, final
# Set initial value in main thread
token = test_var.set("initial_value")
try:
# Start background task
task = run_in_background(modify_and_sleep)
# Verify main thread value remains unchanged while task runs
assert test_var.get() == "initial_value"
# Get results from background thread
original, modified = wait_on_background(task)
# Verify the background thread:
# 1. Saw the initial value
assert original == "initial_value"
# 2. Successfully modified its own copy
assert modified == "modified_in_background"
# Verify main thread value is still unchanged after task completion
assert test_var.get() == "initial_value"
finally:
# Clean up
test_var.reset(token)