mirror of https://github.com/danswer-ai/danswer.git
synced 2025-03-17 13:22:42 +01:00

improved basic search latency (#4186)

* improved basic search latency
* address PR comments + minor cleanup

This commit is contained in:
parent 29382656fc
commit b7da91e3ae
@@ -153,8 +153,9 @@ def generate_initial_answer(
     )
     for tool_response in yield_search_responses(
         query=question,
-        reranked_sections=answer_generation_documents.streaming_documents,
-        final_context_sections=answer_generation_documents.context_documents,
+        get_retrieved_sections=lambda: answer_generation_documents.context_documents,
+        get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
+        get_final_context_sections=lambda: answer_generation_documents.context_documents,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,
         search_tool=graph_config.tooling.search_tool,
@@ -179,8 +179,9 @@ def generate_validate_refined_answer(
     )
     for tool_response in yield_search_responses(
         query=question,
-        reranked_sections=answer_generation_documents.streaming_documents,
-        final_context_sections=answer_generation_documents.context_documents,
+        get_retrieved_sections=lambda: answer_generation_documents.context_documents,
+        get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
+        get_final_context_sections=lambda: answer_generation_documents.context_documents,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,
         search_tool=graph_config.tooling.search_tool,
@@ -13,7 +13,6 @@ from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
 from onyx.chat.models import StreamType
 from onyx.chat.models import SubQuestionPiece
-from onyx.context.search.models import IndexFilters
 from onyx.tools.models import SearchQueryInfo
 from onyx.utils.logger import setup_logger
 
@@ -144,8 +143,6 @@ def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
         if result.query_info is not None:
             query_info = result.query_info
             break
-    return query_info or SearchQueryInfo(
-        predicted_search=None,
-        final_filters=IndexFilters(access_control_list=None),
-        recency_bias_multiplier=1.0,
-    )
+
+    assert query_info is not None, "must have query info"
+    return query_info
@@ -56,8 +56,9 @@ def format_results(
     relevance_list = relevance_from_docs(reranked_documents)
     for tool_response in yield_search_responses(
         query=state.question,
-        reranked_sections=state.retrieved_documents,
-        final_context_sections=reranked_documents,
+        get_retrieved_sections=lambda: reranked_documents,
+        get_reranked_sections=lambda: state.retrieved_documents,
+        get_final_context_sections=lambda: reranked_documents,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,
         search_tool=graph_config.tooling.search_tool,
@@ -91,7 +91,7 @@ def retrieve_documents(
         retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
 
     if AGENT_RETRIEVAL_STATS:
-        pre_rerank_docs = callback_container[0]
+        pre_rerank_docs = callback_container[0] if callback_container else []
         fit_scores = get_fit_scores(
             pre_rerank_docs,
             retrieved_docs,
@@ -44,7 +44,9 @@ def call_tool(
     tool = tool_choice.tool
     tool_args = tool_choice.tool_args
     tool_id = tool_choice.id
-    tool_runner = ToolRunner(tool, tool_args)
+    tool_runner = ToolRunner(
+        tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
+    )
     tool_kickoff = tool_runner.kickoff()
 
     emit_packet(tool_kickoff, writer)
@@ -15,8 +15,17 @@ from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
 from onyx.chat.tool_handling.tool_response_handler import (
     get_tool_call_for_non_tool_calling_llm_impl,
 )
+from onyx.context.search.preprocessing.preprocessing import query_analysis
+from onyx.context.search.retrieval.search_runner import get_query_embedding
+from onyx.tools.models import SearchToolOverrideKwargs
 from onyx.tools.tool import Tool
+from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.utils.logger import setup_logger
+from onyx.utils.threadpool_concurrency import run_in_background
+from onyx.utils.threadpool_concurrency import TimeoutThread
+from onyx.utils.threadpool_concurrency import wait_on_background
+from onyx.utils.timing import log_function_time
+from shared_configs.model_server_models import Embedding
 
 logger = setup_logger()
 
@@ -25,6 +34,7 @@ logger = setup_logger()
 # and a function that handles extracting the necessary fields
 # from the state and config
 # TODO: fan-out to multiple tool call nodes? Make this configurable?
+@log_function_time(print_only=True)
 def choose_tool(
     state: ToolChoiceState,
     config: RunnableConfig,
@@ -37,6 +47,31 @@ def choose_tool(
     should_stream_answer = state.should_stream_answer
 
     agent_config = cast(GraphConfig, config["metadata"]["config"])
+
+    force_use_tool = agent_config.tooling.force_use_tool
+
+    embedding_thread: TimeoutThread[Embedding] | None = None
+    keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
+    override_kwargs: SearchToolOverrideKwargs | None = None
+    if (
+        not agent_config.behavior.use_agentic_search
+        and agent_config.tooling.search_tool is not None
+        and (
+            not force_use_tool.force_use or force_use_tool.tool_name == SearchTool.name
+        )
+    ):
+        override_kwargs = SearchToolOverrideKwargs()
+        # Run in a background thread to avoid blocking the main thread
+        embedding_thread = run_in_background(
+            get_query_embedding,
+            agent_config.inputs.search_request.query,
+            agent_config.persistence.db_session,
+        )
+        keyword_thread = run_in_background(
+            query_analysis,
+            agent_config.inputs.search_request.query,
+        )
+
     using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
     prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
 
@@ -47,7 +82,6 @@ def choose_tool(
     tools = [
         tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
     ]
-    force_use_tool = agent_config.tooling.force_use_tool
 
     tool, tool_args = None, None
     if force_use_tool.force_use and force_use_tool.args is not None:
@@ -71,11 +105,22 @@ def choose_tool(
     # If we have a tool and tool args, we are ready to request a tool call.
     # This only happens if the tool call was forced or we are using a non-tool calling LLM.
     if tool and tool_args:
+        if embedding_thread and tool.name == SearchTool._NAME:
+            # Wait for the embedding thread to finish
+            embedding = wait_on_background(embedding_thread)
+            assert override_kwargs is not None, "must have override kwargs"
+            override_kwargs.precomputed_query_embedding = embedding
+        if keyword_thread and tool.name == SearchTool._NAME:
+            is_keyword, keywords = wait_on_background(keyword_thread)
+            assert override_kwargs is not None, "must have override kwargs"
+            override_kwargs.precomputed_is_keyword = is_keyword
+            override_kwargs.precomputed_keywords = keywords
         return ToolChoiceUpdate(
             tool_choice=ToolChoice(
                 tool=tool,
                 tool_args=tool_args,
                 id=str(uuid4()),
+                search_tool_override_kwargs=override_kwargs,
             ),
         )
 
@@ -153,10 +198,22 @@ def choose_tool(
     logger.debug(f"Selected tool: {selected_tool.name}")
     logger.debug(f"Selected tool call request: {selected_tool_call_request}")
 
+    if embedding_thread and selected_tool.name == SearchTool._NAME:
+        # Wait for the embedding thread to finish
+        embedding = wait_on_background(embedding_thread)
+        assert override_kwargs is not None, "must have override kwargs"
+        override_kwargs.precomputed_query_embedding = embedding
+    if keyword_thread and selected_tool.name == SearchTool._NAME:
+        is_keyword, keywords = wait_on_background(keyword_thread)
+        assert override_kwargs is not None, "must have override kwargs"
+        override_kwargs.precomputed_is_keyword = is_keyword
+        override_kwargs.precomputed_keywords = keywords
+
     return ToolChoiceUpdate(
         tool_choice=ToolChoice(
             tool=selected_tool,
             tool_args=selected_tool_call_request["args"],
            id=selected_tool_call_request["id"],
+            search_tool_override_kwargs=override_kwargs,
        ),
    )
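The two hunks above are the heart of the latency win: the query embedding and keyword analysis are kicked off in background threads as soon as choose_tool starts, and the threads are joined only on the paths that actually run the search tool. A minimal, self-contained sketch of the same pattern; the embed and analyze stubs are hypothetical stand-ins for get_query_embedding and query_analysis:

    # Sketch of the precompute-then-join pattern from choose_tool. The stubs
    # simulate I/O-bound work; run_in_background/wait_on_background are the
    # helpers this commit adds to onyx.utils.threadpool_concurrency.
    import time

    from onyx.utils.threadpool_concurrency import run_in_background
    from onyx.utils.threadpool_concurrency import wait_on_background

    def embed(query: str) -> list[float]:  # hypothetical stand-in
        time.sleep(0.1)  # simulate a model-server round trip
        return [0.1, 0.2]

    def analyze(query: str) -> tuple[bool, list[str]]:  # hypothetical stand-in
        time.sleep(0.1)  # simulate keyword analysis
        return False, query.split()

    # Start both immediately so they overlap the slow tool-selection LLM call.
    embedding_thread = run_in_background(embed, "what is onyx")
    keyword_thread = run_in_background(analyze, "what is onyx")

    # ... tool-selection LLM call would happen here ...

    # Join only on the path that actually runs the search tool.
    embedding = wait_on_background(embedding_thread)
    is_keyword, keywords = wait_on_background(keyword_thread)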
@@ -9,18 +9,23 @@ from onyx.agents.agent_search.basic.states import BasicState
 from onyx.agents.agent_search.basic.utils import process_llm_stream
 from onyx.agents.agent_search.models import GraphConfig
 from onyx.chat.models import LlmDoc
-from onyx.chat.models import OnyxContexts
 from onyx.tools.tool_implementations.search.search_tool import (
-    SEARCH_DOC_CONTENT_ID,
+    SEARCH_RESPONSE_SUMMARY_ID,
 )
+from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
+from onyx.tools.tool_implementations.search.search_utils import (
+    context_from_inference_section,
+)
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )
 from onyx.utils.logger import setup_logger
+from onyx.utils.timing import log_function_time
 
 logger = setup_logger()
 
 
+@log_function_time(print_only=True)
 def basic_use_tool_response(
     state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
 ) -> BasicOutput:
@@ -50,11 +55,13 @@ def basic_use_tool_response(
     for yield_item in tool_call_responses:
         if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
             final_search_results = cast(list[LlmDoc], yield_item.response)
-        elif yield_item.id == SEARCH_DOC_CONTENT_ID:
-            search_contexts = cast(OnyxContexts, yield_item.response).contexts
-            for doc in search_contexts:
-                if doc.document_id not in initial_search_results:
-                    initial_search_results.append(doc)
+        elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
+            search_response_summary = cast(SearchResponseSummary, yield_item.response)
+            for section in search_response_summary.top_sections:
+                if section.center_chunk.document_id not in initial_search_results:
+                    initial_search_results.append(
+                        context_from_inference_section(section)
+                    )
 
     new_tool_call_chunk = AIMessageChunk(content="")
     if not agent_config.behavior.skip_gen_ai_answer_generation:
@@ -2,6 +2,7 @@ from pydantic import BaseModel
 
 from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
 from onyx.tools.message import ToolCallSummary
+from onyx.tools.models import SearchToolOverrideKwargs
 from onyx.tools.models import ToolCallFinalResult
 from onyx.tools.models import ToolCallKickoff
 from onyx.tools.models import ToolResponse
@@ -35,6 +36,7 @@ class ToolChoice(BaseModel):
     tool: Tool
     tool_args: dict
     id: str | None
+    search_tool_override_kwargs: SearchToolOverrideKwargs | None = None
 
     class Config:
         arbitrary_types_allowed = True
@@ -13,6 +13,11 @@ AGENT_NEGATIVE_VALUE_STR = "no"
 AGENT_ANSWER_SEPARATOR = "Answer:"
 
 
+EMBEDDING_KEY = "embedding"
+IS_KEYWORD_KEY = "is_keyword"
+KEYWORDS_KEY = "keywords"
+
+
 class AgentLLMErrorType(str, Enum):
     TIMEOUT = "timeout"
     RATE_LIMIT = "rate_limit"
@@ -15,6 +15,8 @@ from onyx.chat.stream_processing.answer_response_handler import (
 from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
 
 
+# This is Legacy code that is not used anymore.
+# It is kept here for reference.
 class LLMResponseHandlerManager:
     """
     This class is responsible for postprocessing the LLM response stream.
@@ -90,97 +90,97 @@ class CitationProcessor:
                 next(group for group in citation.groups() if group is not None)
             )
 
-            if 1 <= numerical_value <= self.max_citation_num:
-                context_llm_doc = self.context_docs[numerical_value - 1]
-                final_citation_num = self.final_order_mapping[
-                    context_llm_doc.document_id
-                ]
+            if not (1 <= numerical_value <= self.max_citation_num):
+                continue
+
+            context_llm_doc = self.context_docs[numerical_value - 1]
+            final_citation_num = self.final_order_mapping[
+                context_llm_doc.document_id
+            ]
 
-                if final_citation_num not in self.citation_order:
-                    self.citation_order.append(final_citation_num)
+            if final_citation_num not in self.citation_order:
+                self.citation_order.append(final_citation_num)
 
-                citation_order_idx = (
-                    self.citation_order.index(final_citation_num) + 1
-                )
+            citation_order_idx = self.citation_order.index(final_citation_num) + 1
 
-                # get the value that was displayed to user, should always
-                # be in the display_doc_order_dict. But check anyways
-                if context_llm_doc.document_id in self.display_order_mapping:
-                    displayed_citation_num = self.display_order_mapping[
-                        context_llm_doc.document_id
-                    ]
-                else:
-                    displayed_citation_num = final_citation_num
-                    logger.warning(
-                        f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
-                    )
+            # get the value that was displayed to user, should always
+            # be in the display_doc_order_dict. But check anyways
+            if context_llm_doc.document_id in self.display_order_mapping:
+                displayed_citation_num = self.display_order_mapping[
+                    context_llm_doc.document_id
+                ]
+            else:
+                displayed_citation_num = final_citation_num
+                logger.warning(
+                    f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
+                )
 
-                # Skip consecutive citations of the same work
-                if final_citation_num in self.current_citations:
-                    start, end = citation.span()
-                    real_start = length_to_add + start
-                    diff = end - start
-                    self.curr_segment = (
-                        self.curr_segment[: length_to_add + start]
-                        + self.curr_segment[real_start + diff :]
-                    )
-                    length_to_add -= diff
-                    continue
+            # Skip consecutive citations of the same work
+            if final_citation_num in self.current_citations:
+                start, end = citation.span()
+                real_start = length_to_add + start
+                diff = end - start
+                self.curr_segment = (
+                    self.curr_segment[: length_to_add + start]
+                    + self.curr_segment[real_start + diff :]
+                )
+                length_to_add -= diff
+                continue
 
-                # Handle edge case where LLM outputs citation itself
-                if self.curr_segment.startswith("[["):
-                    match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
-                    if match:
-                        try:
-                            doc_id = int(match.group(1))
-                            context_llm_doc = self.context_docs[doc_id - 1]
-                            yield CitationInfo(
-                                # citation_num is now the number post initial ranking, i.e. as displayed to user
-                                citation_num=displayed_citation_num,
-                                document_id=context_llm_doc.document_id,
-                            )
-                        except Exception as e:
-                            logger.warning(
-                                f"Manual LLM citation didn't properly cite documents {e}"
-                            )
-                    else:
-                        logger.warning(
-                            "Manual LLM citation wasn't able to close brackets"
-                        )
-                    continue
+            # Handle edge case where LLM outputs citation itself
+            if self.curr_segment.startswith("[["):
+                match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
+                if match:
+                    try:
+                        doc_id = int(match.group(1))
+                        context_llm_doc = self.context_docs[doc_id - 1]
+                        yield CitationInfo(
+                            # citation_num is now the number post initial ranking, i.e. as displayed to user
+                            citation_num=displayed_citation_num,
+                            document_id=context_llm_doc.document_id,
+                        )
+                    except Exception as e:
+                        logger.warning(
+                            f"Manual LLM citation didn't properly cite documents {e}"
+                        )
+                else:
+                    logger.warning(
+                        "Manual LLM citation wasn't able to close brackets"
+                    )
+                continue
 
-                link = context_llm_doc.link
+            link = context_llm_doc.link
 
-                self.past_cite_count = len(self.llm_out)
-                self.current_citations.append(final_citation_num)
+            self.past_cite_count = len(self.llm_out)
+            self.current_citations.append(final_citation_num)
 
-                if citation_order_idx not in self.cited_inds:
-                    self.cited_inds.add(citation_order_idx)
-                    yield CitationInfo(
-                        # citation number is now the one that was displayed to user
-                        citation_num=displayed_citation_num,
-                        document_id=context_llm_doc.document_id,
-                    )
+            if citation_order_idx not in self.cited_inds:
+                self.cited_inds.add(citation_order_idx)
+                yield CitationInfo(
+                    # citation number is now the one that was displayed to user
+                    citation_num=displayed_citation_num,
+                    document_id=context_llm_doc.document_id,
+                )
 
-                start, end = citation.span()
-                if link:
-                    prev_length = len(self.curr_segment)
-                    self.curr_segment = (
-                        self.curr_segment[: start + length_to_add]
-                        + f"[[{displayed_citation_num}]]({link})"  # use the value that was displayed to user
-                        + self.curr_segment[end + length_to_add :]
-                    )
-                    length_to_add += len(self.curr_segment) - prev_length
-                else:
-                    prev_length = len(self.curr_segment)
-                    self.curr_segment = (
-                        self.curr_segment[: start + length_to_add]
-                        + f"[[{displayed_citation_num}]]()"  # use the value that was displayed to user
-                        + self.curr_segment[end + length_to_add :]
-                    )
-                    length_to_add += len(self.curr_segment) - prev_length
+            start, end = citation.span()
+            if link:
+                prev_length = len(self.curr_segment)
+                self.curr_segment = (
+                    self.curr_segment[: start + length_to_add]
+                    + f"[[{displayed_citation_num}]]({link})"  # use the value that was displayed to user
+                    + self.curr_segment[end + length_to_add :]
+                )
+                length_to_add += len(self.curr_segment) - prev_length
+            else:
+                prev_length = len(self.curr_segment)
+                self.curr_segment = (
+                    self.curr_segment[: start + length_to_add]
+                    + f"[[{displayed_citation_num}]]()"  # use the value that was displayed to user
+                    + self.curr_segment[end + length_to_add :]
+                )
+                length_to_add += len(self.curr_segment) - prev_length
 
-                last_citation_end = end + length_to_add
+            last_citation_end = end + length_to_add
 
         if last_citation_end > 0:
             result += self.curr_segment[:last_citation_end]
@@ -16,7 +16,7 @@ from onyx.db.models import SearchSettings
 from onyx.indexing.models import BaseChunk
 from onyx.indexing.models import IndexingSetting
 from shared_configs.enums import RerankerProvider
-
+from shared_configs.model_server_models import Embedding
 
 MAX_METRICS_CONTENT = (
     200  # Just need enough characters to identify where in the doc the chunk is
@@ -151,6 +151,10 @@ class SearchRequest(ChunkContext):
     evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
+    precomputed_query_embedding: Embedding | None = None
+    precomputed_is_keyword: bool | None = None
+    precomputed_keywords: list[str] | None = None
+
 
 class SearchQuery(ChunkContext):
     "Processed Request that is directly passed to the SearchPipeline"
@@ -175,6 +179,8 @@ class SearchQuery(ChunkContext):
     offset: int = 0
     model_config = ConfigDict(frozen=True)
 
+    precomputed_query_embedding: Embedding | None = None
+
 
 class RetrievalDetails(ChunkContext):
     # Use LLM to determine whether to do a retrieval or only rely on existing history
@@ -331,6 +331,14 @@ class SearchPipeline:
         self._retrieved_sections = expanded_inference_sections
         return expanded_inference_sections
 
+    @property
+    def retrieved_sections(self) -> list[InferenceSection]:
+        if self._retrieved_sections is not None:
+            return self._retrieved_sections
+
+        self._retrieved_sections = self._get_sections()
+        return self._retrieved_sections
+
     @property
     def reranked_sections(self) -> list[InferenceSection]:
         """Reranking is always done at the chunk level since section merging could create arbitrarily
@@ -343,7 +351,7 @@ class SearchPipeline:
         if self._reranked_sections is not None:
             return self._reranked_sections
 
-        retrieved_sections = self._get_sections()
+        retrieved_sections = self.retrieved_sections
         if self.retrieved_sections_callback is not None:
             self.retrieved_sections_callback(retrieved_sections)
 
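For readers skimming the diff: retrieved_sections is the standard memoize-on-the-instance @property shape, so reranked_sections can reuse a retrieval that already happened instead of calling _get_sections() a second time. A stripped-down sketch of that shape; the class and method names here are illustrative, not the real SearchPipeline internals:

    # Illustrative memoized-property sketch; not the real SearchPipeline.
    class Pipeline:
        def __init__(self) -> None:
            self._retrieved_sections: list[str] | None = None

        def _get_sections(self) -> list[str]:
            print("expensive retrieval runs exactly once")
            return ["section-a", "section-b"]

        @property
        def retrieved_sections(self) -> list[str]:
            # Return the cached result when present; compute and cache otherwise.
            if self._retrieved_sections is not None:
                return self._retrieved_sections
            self._retrieved_sections = self._get_sections()
            return self._retrieved_sections

    pipeline = Pipeline()
    assert pipeline.retrieved_sections is pipeline.retrieved_sections  # second access hits the cache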
@@ -117,8 +117,12 @@ def retrieval_preprocessing(
         else None
     )
 
+    # Sometimes this is pre-computed in parallel with other heavy tasks to improve
+    # latency, and in that case we don't need to run the model again
     run_query_analysis = (
-        None if skip_query_analysis else FunctionCall(query_analysis, (query,), {})
+        None
+        if (skip_query_analysis or search_request.precomputed_is_keyword is not None)
+        else FunctionCall(query_analysis, (query,), {})
    )
 
     functions_to_run = [
@@ -143,11 +147,12 @@ def retrieval_preprocessing(
 
     # The extracted keywords right now are not very reliable, not using for now
     # Can maybe use for highlighting
-    is_keyword, extracted_keywords = (
-        parallel_results[run_query_analysis.result_id]
-        if run_query_analysis
-        else (False, None)
-    )
+    is_keyword, _extracted_keywords = False, None
+    if search_request.precomputed_is_keyword is not None:
+        is_keyword = search_request.precomputed_is_keyword
+        _extracted_keywords = search_request.precomputed_keywords
+    elif run_query_analysis:
+        is_keyword, _extracted_keywords = parallel_results[run_query_analysis.result_id]
 
     all_query_terms = query.split()
     processed_keywords = (
@@ -247,4 +252,5 @@ def retrieval_preprocessing(
         chunks_above=chunks_above,
         chunks_below=chunks_below,
         full_doc=search_request.full_doc,
+        precomputed_query_embedding=search_request.precomputed_query_embedding,
     )
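The preprocessing change establishes a simple precedence: a precomputed result supplied by the caller wins, otherwise the analysis runs (unless skipped), otherwise defaults apply. A sketch of that precedence in isolation; the function and parameter names are illustrative:

    # Illustrative precedence: precomputed > freshly computed > default.
    def resolve_is_keyword(
        precomputed: bool | None,  # e.g. SearchRequest.precomputed_is_keyword
        computed: bool | None,  # None when query analysis was skipped
    ) -> bool:
        if precomputed is not None:
            return precomputed
        if computed is not None:
            return computed
        return False

    assert resolve_is_keyword(True, None) is True    # precomputed wins
    assert resolve_is_keyword(None, True) is True    # falls back to analysis
    assert resolve_is_keyword(None, None) is False   # default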
@@ -31,7 +31,7 @@ from onyx.utils.timing import log_function_time
 from shared_configs.configs import MODEL_SERVER_HOST
 from shared_configs.configs import MODEL_SERVER_PORT
 from shared_configs.enums import EmbedTextType
-
+from shared_configs.model_server_models import Embedding
 
 logger = setup_logger()
 
@@ -109,6 +109,20 @@ def combine_retrieval_results(
     return sorted_chunks
 
 
+def get_query_embedding(query: str, db_session: Session) -> Embedding:
+    search_settings = get_current_search_settings(db_session)
+
+    model = EmbeddingModel.from_db_model(
+        search_settings=search_settings,
+        # The below are globally set, this flow always uses the indexing one
+        server_host=MODEL_SERVER_HOST,
+        server_port=MODEL_SERVER_PORT,
+    )
+
+    query_embedding = model.encode([query], text_type=EmbedTextType.QUERY)[0]
+    return query_embedding
+
+
 @log_function_time(print_only=True)
 def doc_index_retrieval(
     query: SearchQuery,
@@ -121,17 +135,10 @@ def doc_index_retrieval(
     from the large chunks to the referenced chunks,
     dedupes the chunks, and cleans the chunks.
     """
-    search_settings = get_current_search_settings(db_session)
-
-    model = EmbeddingModel.from_db_model(
-        search_settings=search_settings,
-        # The below are globally set, this flow always uses the indexing one
-        server_host=MODEL_SERVER_HOST,
-        server_port=MODEL_SERVER_PORT,
+    query_embedding = query.precomputed_query_embedding or get_query_embedding(
+        query.query, db_session
     )
 
-    query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]
-
     top_chunks = document_index.hybrid_retrieval(
         query=query.query,
         query_embedding=query_embedding,
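Note the fallback idiom here: query.precomputed_query_embedding or get_query_embedding(...). Using `or` treats any falsy value as missing; that is safe for embeddings, which are non-empty vectors, but worth keeping in mind if the pattern is copied for other field types. A tiny illustrative sketch with stand-in names:

    # Illustrative fallback; compute only when no precomputed value exists.
    def compute_embedding(query: str) -> list[float]:  # hypothetical stand-in
        return [0.1, 0.2, 0.3]

    def embedding_for(query: str, precomputed: list[float] | None) -> list[float]:
        # `or` is fine here because a real embedding is never empty/falsy.
        return precomputed or compute_embedding(query)

    assert embedding_for("q", [1.0]) == [1.0]            # precomputed reused
    assert embedding_for("q", None) == [0.1, 0.2, 0.3]   # computed on demand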
@@ -250,6 +257,9 @@ def retrieve_chunks(
             simplified_queries.add(simplified_rephrase)
 
             q_copy = query.copy(update={"query": rephrase}, deep=True)
+            q_copy.precomputed_query_embedding = (
+                None  # need to recompute for each rephrase
+            )
             run_queries.append(
                 (
                     doc_index_retrieval,
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
 from onyx.context.search.enums import SearchType
 from onyx.context.search.models import IndexFilters
 from onyx.context.search.models import InferenceSection
+from shared_configs.model_server_models import Embedding
 
 
 class ToolResponse(BaseModel):
@@ -60,11 +61,15 @@ class SearchQueryInfo(BaseModel):
     recency_bias_multiplier: float
 
 
+# None indicates that the default value should be used
 class SearchToolOverrideKwargs(BaseModel):
-    force_no_rerank: bool
-    alternate_db_session: Session | None
-    retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
-    skip_query_analysis: bool
+    force_no_rerank: bool | None = None
+    alternate_db_session: Session | None = None
+    retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None = None
+    skip_query_analysis: bool | None = None
+    precomputed_query_embedding: Embedding | None = None
+    precomputed_is_keyword: bool | None = None
+    precomputed_keywords: list[str] | None = None
 
     class Config:
         arbitrary_types_allowed = True
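Every field of SearchToolOverrideKwargs is now optional with a None default, so callers can construct an empty override object up front and fill in only what they precomputed; None consistently means "use the tool's default". A small sketch of how a consumer resolves such a field against its default (model and helper here are simplified stand-ins, mirroring the use_alt_not_None helper added later in this commit):

    # Sketch: resolving an optional override against a default.
    from pydantic import BaseModel

    class Overrides(BaseModel):
        force_no_rerank: bool | None = None  # None means "use the default"

    def resolve(override: bool | None, default: bool) -> bool:
        return override if override is not None else default

    assert resolve(Overrides().force_no_rerank, False) is False                   # unset
    assert resolve(Overrides(force_no_rerank=True).force_no_rerank, False) is True
    assert resolve(Overrides(force_no_rerank=False).force_no_rerank, True) is False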
@@ -3,6 +3,7 @@ from collections.abc import Callable
 from collections.abc import Generator
 from typing import Any
 from typing import cast
+from typing import TypeVar
 
 from sqlalchemy.orm import Session
 
@@ -11,7 +12,6 @@ from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import ContextualPruningConfig
 from onyx.chat.models import DocumentPruningConfig
 from onyx.chat.models import LlmDoc
-from onyx.chat.models import OnyxContext
 from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PromptConfig
 from onyx.chat.models import SectionRelevancePiece
@@ -42,6 +42,9 @@ from onyx.tools.models import SearchQueryInfo
 from onyx.tools.models import SearchToolOverrideKwargs
 from onyx.tools.models import ToolResponse
 from onyx.tools.tool import Tool
+from onyx.tools.tool_implementations.search.search_utils import (
+    context_from_inference_section,
+)
 from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     build_next_prompt_for_search_like_tool,
@@ -281,16 +284,23 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
         self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
     ) -> Generator[ToolResponse, None, None]:
         query = cast(str, llm_kwargs[QUERY_FIELD])
+        precomputed_query_embedding = None
+        precomputed_is_keyword = None
+        precomputed_keywords = None
         force_no_rerank = False
         alternate_db_session = None
         retrieved_sections_callback = None
         skip_query_analysis = False
         if override_kwargs:
-            force_no_rerank = override_kwargs.force_no_rerank
+            force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
             alternate_db_session = override_kwargs.alternate_db_session
             retrieved_sections_callback = override_kwargs.retrieved_sections_callback
-            skip_query_analysis = override_kwargs.skip_query_analysis
-
+            skip_query_analysis = use_alt_not_None(
+                override_kwargs.skip_query_analysis, False
+            )
+            precomputed_query_embedding = override_kwargs.precomputed_query_embedding
+            precomputed_is_keyword = override_kwargs.precomputed_is_keyword
+            precomputed_keywords = override_kwargs.precomputed_keywords
         if self.selected_sections:
             yield from self._build_response_for_specified_sections(query)
             return
@@ -327,6 +337,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
                     if self.retrieval_options
                     else None
                 ),
+                precomputed_query_embedding=precomputed_query_embedding,
+                precomputed_is_keyword=precomputed_is_keyword,
+                precomputed_keywords=precomputed_keywords,
             ),
             user=self.user,
             llm=self.llm,
@@ -345,8 +358,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
         )
         yield from yield_search_responses(
             query,
-            search_pipeline.reranked_sections,
-            search_pipeline.final_context_sections,
+            lambda: search_pipeline.retrieved_sections,
+            lambda: search_pipeline.reranked_sections,
+            lambda: search_pipeline.final_context_sections,
             search_query_info,
             lambda: search_pipeline.section_relevance,
             self,
@@ -383,10 +397,16 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
 # SearchTool passed in to allow for access to SearchTool properties.
 # We can't just call SearchTool methods in the graph because we're operating on
 # the retrieved docs (reranking, deduping, etc.) after the SearchTool has run.
+#
+# The various inference sections are passed in as functions to allow for lazy
+# evaluation. The SearchPipeline object properties that they correspond to are
+# actually functions defined with @property decorators, and passing them into
+# this function causes them to get evaluated immediately which is undesirable.
 def yield_search_responses(
     query: str,
-    reranked_sections: list[InferenceSection],
-    final_context_sections: list[InferenceSection],
+    get_retrieved_sections: Callable[[], list[InferenceSection]],
+    get_reranked_sections: Callable[[], list[InferenceSection]],
+    get_final_context_sections: Callable[[], list[InferenceSection]],
     search_query_info: SearchQueryInfo,
     get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
     search_tool: SearchTool,
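The comment block added above explains the signature change: the section arguments became zero-argument callables because the values they wrap are @property accessors on SearchPipeline, and evaluating them eagerly at the call site would trigger retrieval and reranking even on paths that never use the results. A minimal illustration of the difference, using a stand-in class rather than the real pipeline:

    # Passing a value evaluates the property now; passing a lambda defers it.
    from collections.abc import Callable

    class LazyPipeline:
        @property
        def reranked_sections(self) -> list[str]:
            print("expensive rerank happens now")
            return ["s1", "s2"]

    def consume(get_sections: Callable[[], list[str]], needed: bool) -> None:
        if needed:
            print(get_sections())  # property runs only on this branch

    pipe = LazyPipeline()
    consume(lambda: pipe.reranked_sections, needed=False)  # prints nothing
    consume(lambda: pipe.reranked_sections, needed=True)   # rerank runs here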
@@ -395,7 +415,7 @@ def yield_search_responses(
         id=SEARCH_RESPONSE_SUMMARY_ID,
         response=SearchResponseSummary(
             rephrased_query=query,
-            top_sections=final_context_sections,
+            top_sections=get_retrieved_sections(),
             predicted_flow=QueryFlow.QUESTION_ANSWER,
             predicted_search=search_query_info.predicted_search,
             final_filters=search_query_info.final_filters,
@@ -407,13 +427,8 @@ def yield_search_responses(
         id=SEARCH_DOC_CONTENT_ID,
         response=OnyxContexts(
             contexts=[
-                OnyxContext(
-                    content=section.combined_content,
-                    document_id=section.center_chunk.document_id,
-                    semantic_identifier=section.center_chunk.semantic_identifier,
-                    blurb=section.center_chunk.blurb,
-                )
-                for section in reranked_sections
+                context_from_inference_section(section)
+                for section in get_reranked_sections()
             ]
         ),
     )
@@ -424,6 +439,7 @@ def yield_search_responses(
         response=section_relevance,
     )
 
+    final_context_sections = get_final_context_sections()
     pruned_sections = prune_sections(
         sections=final_context_sections,
         section_relevance_list=section_relevance_list_impl(
@@ -438,3 +454,10 @@ def yield_search_responses(
     llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]
 
     yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
+
+
+T = TypeVar("T")
+
+
+def use_alt_not_None(value: T | None, alt: T) -> T:
+    return value if value is not None else alt
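use_alt_not_None exists because the override fields need to distinguish "explicitly False" from "unset": a plain `value or alt` would silently replace an explicit False with the fallback. A short usage sketch (the helper is redefined here so the snippet stands alone):

    # Why `or` is not enough for optional booleans:
    def use_alt_not_None(value, alt):
        return value if value is not None else alt

    assert use_alt_not_None(None, False) is False  # unset -> default
    assert use_alt_not_None(False, True) is False  # explicit False preserved
    assert (False or True) is True                 # plain `or` would lose the False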
@@ -1,4 +1,5 @@
 from onyx.chat.models import LlmDoc
+from onyx.chat.models import OnyxContext
 from onyx.context.search.models import InferenceSection
 from onyx.prompts.prompt_utils import clean_up_source
 
@@ -29,3 +30,12 @@ def section_to_dict(section: InferenceSection, section_num: int) -> dict:
             "%B %d, %Y %H:%M"
         )
     return doc_dict
+
+
+def context_from_inference_section(section: InferenceSection) -> OnyxContext:
+    return OnyxContext(
+        content=section.combined_content,
+        document_id=section.center_chunk.document_id,
+        semantic_identifier=section.center_chunk.semantic_identifier,
+        blurb=section.center_chunk.blurb,
+    )
@@ -1,6 +1,8 @@
 from collections.abc import Callable
 from collections.abc import Generator
 from typing import Any
+from typing import Generic
+from typing import TypeVar
 
 from onyx.llm.interfaces import LLM
 from onyx.llm.models import PreviousMessage
@@ -11,10 +13,16 @@ from onyx.tools.tool import Tool
 from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
 
 
-class ToolRunner:
-    def __init__(self, tool: Tool, args: dict[str, Any]):
+R = TypeVar("R")
+
+
+class ToolRunner(Generic[R]):
+    def __init__(
+        self, tool: Tool[R], args: dict[str, Any], override_kwargs: R | None = None
+    ):
         self.tool = tool
         self.args = args
+        self.override_kwargs = override_kwargs
 
         self._tool_responses: list[ToolResponse] | None = None
 
@@ -27,7 +35,9 @@ class ToolRunner:
             return
 
         tool_responses: list[ToolResponse] = []
-        for tool_response in self.tool.run(**self.args):
+        for tool_response in self.tool.run(
+            override_kwargs=self.override_kwargs, **self.args
+        ):
            yield tool_response
            tool_responses.append(tool_response)
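ToolRunner is now generic over R, the tool's override-kwargs type, so the type checker ties a Tool[SearchToolOverrideKwargs] to runners that can only be handed SearchToolOverrideKwargs. A condensed sketch of the pairing, with simplified stand-ins for the real Tool and ToolRunner:

    # Sketch: Generic[R] links a tool to its override-kwargs type.
    from typing import Generic, TypeVar

    R = TypeVar("R")

    class Tool(Generic[R]):
        def run(self, override_kwargs: R | None = None) -> str:
            return f"ran with {override_kwargs!r}"

    class SearchOverrides:  # stand-in for SearchToolOverrideKwargs
        def __repr__(self) -> str:
            return "SearchOverrides()"

    class ToolRunner(Generic[R]):
        def __init__(self, tool: Tool[R], override_kwargs: R | None = None) -> None:
            self.tool = tool
            self.override_kwargs = override_kwargs

        def kickoff(self) -> str:
            return self.tool.run(override_kwargs=self.override_kwargs)

    runner = ToolRunner(Tool[SearchOverrides](), SearchOverrides())
    print(runner.kickoff())  # a mismatched override type would be flagged by mypy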
@@ -118,7 +118,7 @@ def run_functions_in_parallel(
     return results
 
 
-class TimeoutThread(threading.Thread):
+class TimeoutThread(threading.Thread, Generic[R]):
     def __init__(
         self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
     ):
@@ -159,3 +159,34 @@ def run_with_timeout(
         task.end()
 
     return task.result
+
+
+# NOTE: this function should really only be used when run_functions_tuples_in_parallel is
+# difficult to use. It's up to the programmer to call wait_on_background on the thread after
+# the code you want to run in parallel is finished. As with all python thread parallelism,
+# this is only useful for I/O bound tasks.
+def run_in_background(
+    func: Callable[..., R], *args: Any, **kwargs: Any
+) -> TimeoutThread[R]:
+    """
+    Runs a function in a background thread. Returns a TimeoutThread object that can be used
+    to wait for the function to finish with wait_on_background.
+    """
+    context = contextvars.copy_context()
+    # Timeout not used in the non-blocking case
+    task = TimeoutThread(-1, context.run, func, *args, **kwargs)
+    task.start()
+    return task
+
+
+def wait_on_background(task: TimeoutThread[R]) -> R:
+    """
+    Used in conjunction with run_in_background. Blocks until the task is finished,
+    then returns the result of the task.
+    """
+    task.join()
+
+    if task.exception is not None:
+        raise task.exception
+
+    return task.result
@@ -1,8 +1,14 @@
+import contextvars
 import time
 
 import pytest
 
+from onyx.utils.threadpool_concurrency import run_in_background
 from onyx.utils.threadpool_concurrency import run_with_timeout
+from onyx.utils.threadpool_concurrency import wait_on_background
+
+# Create a context variable for testing
+test_context_var = contextvars.ContextVar("test_var", default="default")
 
 
 def test_run_with_timeout_completes() -> None:
@@ -59,3 +65,86 @@ def test_run_with_timeout_with_args_and_kwargs() -> None:
     # Test with positional and keyword args
     result2 = run_with_timeout(1.0, complex_function, x=5, y=3, multiply=True)
     assert result2 == 15
+
+
+def test_run_in_background_and_wait_success() -> None:
+    """Test that run_in_background and wait_on_background work correctly for successful execution"""
+
+    def background_function(x: int) -> int:
+        time.sleep(0.1)  # Small delay to ensure it's actually running in background
+        return x * 2
+
+    # Start the background task
+    task = run_in_background(background_function, 21)
+
+    # Verify we can do other work while task is running
+    start_time = time.time()
+    result = wait_on_background(task)
+    elapsed = time.time() - start_time
+
+    assert result == 42
+    assert elapsed >= 0.1  # Verify we actually waited for the sleep
+
+
+@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
+def test_run_in_background_propagates_exceptions() -> None:
+    """Test that exceptions in background tasks are properly propagated"""
+
+    def error_function() -> None:
+        time.sleep(0.1)  # Small delay to ensure it's actually running in background
+        raise ValueError("Test background error")
+
+    task = run_in_background(error_function)
+
+    with pytest.raises(ValueError) as exc_info:
+        wait_on_background(task)
+
+    assert "Test background error" in str(exc_info.value)
+
+
+def test_run_in_background_with_args_and_kwargs() -> None:
+    """Test that args and kwargs are properly passed to the background function"""
+
+    def complex_function(x: int, y: int, multiply: bool = False) -> int:
+        time.sleep(0.1)  # Small delay to ensure it's actually running in background
+        if multiply:
+            return x * y
+        return x + y
+
+    # Test with args
+    task1 = run_in_background(complex_function, 5, 3)
+    result1 = wait_on_background(task1)
+    assert result1 == 8
+
+    # Test with args and kwargs
+    task2 = run_in_background(complex_function, 5, 3, multiply=True)
+    result2 = wait_on_background(task2)
+    assert result2 == 15
+
+
+def test_multiple_background_tasks() -> None:
+    """Test running multiple background tasks concurrently"""
+
+    def slow_add(x: int, y: int) -> int:
+        time.sleep(0.2)  # Make each task take some time
+        return x + y
+
+    # Start multiple tasks
+    start_time = time.time()
+    task1 = run_in_background(slow_add, 1, 2)
+    task2 = run_in_background(slow_add, 3, 4)
+    task3 = run_in_background(slow_add, 5, 6)
+
+    # Wait for all results
+    result1 = wait_on_background(task1)
+    result2 = wait_on_background(task2)
+    result3 = wait_on_background(task3)
+    elapsed = time.time() - start_time
+
+    # Verify results
+    assert result1 == 3
+    assert result2 == 7
+    assert result3 == 11
+
+    # Verify tasks ran in parallel (total time should be ~0.2s, not ~0.6s)
+    assert 0.2 <= elapsed < 0.4  # Allow some buffer for test environment variations
@@ -4,7 +4,9 @@ import time
 from onyx.utils.threadpool_concurrency import FunctionCall
 from onyx.utils.threadpool_concurrency import run_functions_in_parallel
 from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
+from onyx.utils.threadpool_concurrency import run_in_background
 from onyx.utils.threadpool_concurrency import run_with_timeout
+from onyx.utils.threadpool_concurrency import wait_on_background
 
 # Create a test contextvar
 test_var = contextvars.ContextVar("test_var", default="default")
@@ -129,3 +131,39 @@ def test_contextvar_isolation_between_runs() -> None:
 
     # Verify second run results
     assert all(result in ["thread3", "thread4"] for result in second_results)
+
+
+def test_run_in_background_preserves_contextvar() -> None:
+    """Test that run_in_background preserves contextvar values and modifications are isolated"""
+
+    def modify_and_sleep() -> tuple[str, str]:
+        """Modifies the contextvar, sleeps, and returns the original and final values"""
+        original = test_var.get()
+        test_var.set("modified_in_background")
+        time.sleep(0.1)  # Ensure we can check main thread during execution
+        final = test_var.get()
+        return original, final
+
+    # Set initial value in main thread
+    token = test_var.set("initial_value")
+    try:
+        # Start background task
+        task = run_in_background(modify_and_sleep)
+
+        # Verify main thread value remains unchanged while task runs
+        assert test_var.get() == "initial_value"
+
+        # Get results from background thread
+        original, modified = wait_on_background(task)
+
+        # Verify the background thread:
+        # 1. Saw the initial value
+        assert original == "initial_value"
+        # 2. Successfully modified its own copy
+        assert modified == "modified_in_background"
+
+        # Verify main thread value is still unchanged after task completion
+        assert test_var.get() == "initial_value"
+    finally:
+        # Clean up
+        test_var.reset(token)