Fix Agent Slowness (#3979)

Yuhong Sun
2025-02-13 15:54:34 -08:00
committed by GitHub
parent c6434db7eb
commit 1a7aca06b9
14 changed files with 87 additions and 28 deletions

View File

@@ -83,6 +83,7 @@ def handle_search_request(
        user=user,
        llm=llm,
        fast_llm=fast_llm,
+       skip_query_analysis=False,
        db_session=db_session,
        bypass_acl=False,
    )

View File

@@ -23,6 +23,7 @@ from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.tools.models import SearchQueryInfo
+from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
    SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -67,9 +68,12 @@ def retrieve_documents(
    with get_session_context_manager() as db_session:
        for tool_response in search_tool.run(
            query=query_to_retrieve,
-           force_no_rerank=True,
-           alternate_db_session=db_session,
-           retrieved_sections_callback=callback_container.append,
+           override_kwargs=SearchToolOverrideKwargs(
+               force_no_rerank=True,
+               alternate_db_session=db_session,
+               retrieved_sections_callback=callback_container.append,
+               skip_query_analysis=not state.base_search,
+           ),
        ):
            # get retrieved docs to send to the rest of the graph
            if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:

View File

@@ -58,6 +58,7 @@ from onyx.prompts.agent_search import (
)
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
+from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import (
    SEARCH_RESPONSE_SUMMARY_ID,
@@ -218,7 +219,10 @@ def get_test_config(
        using_tool_calling_llm=using_tool_calling_llm,
    )
-   chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID")
+   chat_session_id = (
+       os.environ.get("ONYX_AS_CHAT_SESSION_ID")
+       or "00000000-0000-0000-0000-000000000000"
+   )
    assert (
        chat_session_id is not None
    ), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
@@ -341,8 +345,12 @@ def retrieve_search_docs(
    with get_session_context_manager() as db_session:
        for tool_response in search_tool.run(
            query=question,
-           force_no_rerank=True,
-           alternate_db_session=db_session,
+           override_kwargs=SearchToolOverrideKwargs(
+               force_no_rerank=True,
+               alternate_db_session=db_session,
+               retrieved_sections_callback=None,
+               skip_query_analysis=False,
+           ),
        ):
            # get retrieved docs to send to the rest of the graph
            if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:

View File

@@ -51,6 +51,7 @@ class SearchPipeline:
        user: User | None,
        llm: LLM,
        fast_llm: LLM,
+       skip_query_analysis: bool,
        db_session: Session,
        bypass_acl: bool = False,  # NOTE: VERY DANGEROUS, USE WITH CAUTION
        retrieval_metrics_callback: (
@@ -67,6 +68,7 @@ class SearchPipeline:
        self.user = user
        self.llm = llm
        self.fast_llm = fast_llm
+       self.skip_query_analysis = skip_query_analysis
        self.db_session = db_session
        self.bypass_acl = bypass_acl
        self.retrieval_metrics_callback = retrieval_metrics_callback
@@ -108,6 +110,7 @@ class SearchPipeline:
            search_request=self.search_request,
            user=self.user,
            llm=self.llm,
+           skip_query_analysis=self.skip_query_analysis,
            db_session=self.db_session,
            bypass_acl=self.bypass_acl,
        )
@@ -162,6 +165,12 @@ class SearchPipeline:
        that have a corresponding chunk.
        This step should be fast for any document index implementation.
+
+       Current implementation timing is approximately broken down as:
+       - 200 ms to get the embedding of the query
+       - 15 ms to get chunks from the document index
+       - possibly more to get additional surrounding chunks
+       - possibly more for query expansion (multilingual)
        """
        if self._retrieved_sections is not None:
            return self._retrieved_sections
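
SearchPipeline now takes skip_query_analysis as a required constructor argument rather than an optional preprocessing default. A minimal caller sketch, assuming the search request, LLMs, and db_session are already set up as elsewhere in the codebase:

    # Hedged sketch: argument names mirror the constructor above; the request,
    # LLM, and session objects are assumed to exist in the caller's scope.
    pipeline = SearchPipeline(
        search_request=search_request,
        user=None,
        llm=llm,
        fast_llm=fast_llm,
        skip_query_analysis=True,  # skip keyword/query analysis to cut latency
        db_session=db_session,
    )
    sections = pipeline.reranked_sections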

View File

@@ -50,11 +50,11 @@ def retrieval_preprocessing(
    search_request: SearchRequest,
    user: User | None,
    llm: LLM,
+   skip_query_analysis: bool,
    db_session: Session,
-   bypass_acl: bool = False,
-   skip_query_analysis: bool = False,
-   base_recency_decay: float = BASE_RECENCY_DECAY,
    favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
+   base_recency_decay: float = BASE_RECENCY_DECAY,
+   bypass_acl: bool = False,
) -> SearchQuery:
    """Logic is as follows:
    Any global disables apply first
@@ -146,7 +146,7 @@ def retrieval_preprocessing(
    is_keyword, extracted_keywords = (
        parallel_results[run_query_analysis.result_id]
        if run_query_analysis
-       else (None, None)
+       else (False, None)
    )
    all_query_terms = query.split()

View File

@@ -99,7 +99,7 @@ def _check_tokenizer_cache(
    if not tokenizer:
        logger.info(
-           f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
+           f"Falling back to default embedding model tokenizer: {DOCUMENT_ENCODER_MODEL}"
        )
        tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)

View File

@@ -76,6 +76,7 @@ def gpt_search(
        user=None,
        llm=llm,
        fast_llm=fast_llm,
+       skip_query_analysis=True,
        db_session=db_session,
    ).reranked_sections

View File

@@ -34,7 +34,7 @@ Now respond to the following:
""".strip() """.strip()
class BaseTool(Tool): class BaseTool(Tool[None]):
def build_next_prompt( def build_next_prompt(
self, self,
prompt_builder: "AnswerPromptBuilder", prompt_builder: "AnswerPromptBuilder",

View File

@@ -1,11 +1,14 @@
+from collections.abc import Callable
from typing import Any
from uuid import UUID

from pydantic import BaseModel
from pydantic import model_validator
+from sqlalchemy.orm import Session

from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
+from onyx.context.search.models import InferenceSection


class ToolResponse(BaseModel):
@@ -57,5 +60,15 @@ class SearchQueryInfo(BaseModel):
    recency_bias_multiplier: float


+class SearchToolOverrideKwargs(BaseModel):
+   force_no_rerank: bool
+   alternate_db_session: Session | None
+   retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
+   skip_query_analysis: bool
+
+   class Config:
+       arbitrary_types_allowed = True
+
+
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"

MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"
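
The new SearchToolOverrideKwargs model bundles what were previously loose keyword arguments to SearchTool.run into one typed object; arbitrary_types_allowed is needed because Session and the callback Callable are not types pydantic validates natively. A minimal construction sketch, with illustrative values:

    # Values here are illustrative; None leaves the optional overrides unset.
    override = SearchToolOverrideKwargs(
        force_no_rerank=True,
        alternate_db_session=None,
        retrieved_sections_callback=None,
        skip_query_analysis=True,
    )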

View File

@@ -1,7 +1,9 @@
import abc
from collections.abc import Generator
from typing import Any
+from typing import Generic
from typing import TYPE_CHECKING
+from typing import TypeVar

from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
@@ -14,7 +16,10 @@ if TYPE_CHECKING:
    from onyx.tools.models import ToolResponse


-class Tool(abc.ABC):
+OVERRIDE_T = TypeVar("OVERRIDE_T")
+
+
+class Tool(abc.ABC, Generic[OVERRIDE_T]):
    @property
    @abc.abstractmethod
    def name(self) -> str:
@@ -57,7 +62,9 @@ class Tool(abc.ABC):
"""Actual execution of the tool""" """Actual execution of the tool"""
@abc.abstractmethod @abc.abstractmethod
def run(self, **kwargs: Any) -> Generator["ToolResponse", None, None]: def run(
self, override_kwargs: OVERRIDE_T | None = None, **llm_kwargs: Any
) -> Generator["ToolResponse", None, None]:
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
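
Tool is now generic over the type of its override object, so each subclass declares what run may receive: SearchTool binds Tool[SearchToolOverrideKwargs] below, while tools without overrides bind Tool[None]. A minimal sketch of the latter (class name hypothetical, other abstract members omitted, so it is not instantiable as written):

    # Hedged sketch only: a tool that accepts no overrides parameterizes the
    # base class with None, so type checkers reject any other override_kwargs.
    class EchoTool(Tool[None]):
        def run(
            self, override_kwargs: None = None, **llm_kwargs: Any
        ) -> Generator["ToolResponse", None, None]:
            yield ToolResponse(id="echo_response", response=llm_kwargs)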

View File

@@ -74,6 +74,7 @@ class CustomToolCallSummary(BaseModel):
    tool_result: Any  # The response data


+# override_kwargs is not supported for custom tools
class CustomTool(BaseTool):
    def __init__(
        self,
@@ -235,7 +236,9 @@ class CustomTool(BaseTool):
"""Actual execution of the tool""" """Actual execution of the tool"""
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: def run(
self, override_kwargs: dict[str, Any] | None = None, **kwargs: Any
) -> Generator[ToolResponse, None, None]:
request_body = kwargs.get(REQUEST_BODY) request_body = kwargs.get(REQUEST_BODY)
path_params = {} path_params = {}

View File

@@ -79,7 +79,8 @@ class ImageShape(str, Enum):
    LANDSCAPE = "landscape"


-class ImageGenerationTool(Tool):
+# override_kwargs is not supported for image generation tools
+class ImageGenerationTool(Tool[None]):
    _NAME = "run_image_generation"
    _DESCRIPTION = "Generate an image from a prompt."
    _DISPLAY_NAME = "Image Generation"
@@ -255,7 +256,9 @@ class ImageGenerationTool(Tool):
"An error occurred during image generation. Please try again later." "An error occurred during image generation. Please try again later."
) )
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: def run(
self, override_kwargs: None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
prompt = cast(str, kwargs["prompt"]) prompt = cast(str, kwargs["prompt"])
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE)) shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
format = self.output_format format = self.output_format

View File

@@ -106,7 +106,8 @@ def internet_search_response_to_search_docs(
]


-class InternetSearchTool(Tool):
+# override_kwargs is not supported for internet search tools
+class InternetSearchTool(Tool[None]):
    _NAME = "run_internet_search"
    _DISPLAY_NAME = "Internet Search"
    _DESCRIPTION = "Perform an internet search for up-to-date information."
@@ -242,7 +243,9 @@ class InternetSearchTool(Tool):
        ],
    )

-   def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
+   def run(
+       self, override_kwargs: None = None, **kwargs: str
+   ) -> Generator[ToolResponse, None, None]:
        query = cast(str, kwargs["internet_search_query"])
        results = self._perform_search(query)

View File

@@ -39,6 +39,7 @@ from onyx.secondary_llm_flows.choose_search import check_if_need_search
from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchQueryInfo
+from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
@@ -77,7 +78,7 @@ HINT: if you are unfamiliar with the user input OR think the user input is a typ
""" """
class SearchTool(Tool): class SearchTool(Tool[SearchToolOverrideKwargs]):
_NAME = "run_search" _NAME = "run_search"
_DISPLAY_NAME = "Search Tool" _DISPLAY_NAME = "Search Tool"
_DESCRIPTION = SEARCH_TOOL_DESCRIPTION _DESCRIPTION = SEARCH_TOOL_DESCRIPTION
@@ -275,14 +276,19 @@ class SearchTool(Tool):
        yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)

-   def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
-       query = cast(str, kwargs["query"])
-       force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False))
-       alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None))
-       retrieved_sections_callback = cast(
-           Callable[[list[InferenceSection]], None],
-           kwargs.get("retrieved_sections_callback"),
-       )
+   def run(
+       self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
+   ) -> Generator[ToolResponse, None, None]:
+       query = cast(str, llm_kwargs["query"])
+       force_no_rerank = False
+       alternate_db_session = None
+       retrieved_sections_callback = None
+       skip_query_analysis = False
+       if override_kwargs:
+           force_no_rerank = override_kwargs.force_no_rerank
+           alternate_db_session = override_kwargs.alternate_db_session
+           retrieved_sections_callback = override_kwargs.retrieved_sections_callback
+           skip_query_analysis = override_kwargs.skip_query_analysis

        if self.selected_sections:
            yield from self._build_response_for_specified_sections(query)
@@ -324,6 +330,7 @@ class SearchTool(Tool):
            user=self.user,
            llm=self.llm,
            fast_llm=self.fast_llm,
+           skip_query_analysis=skip_query_analysis,
            bypass_acl=self.bypass_acl,
            db_session=alternate_db_session or self.db_session,
            prompt_config=self.prompt_config,
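
With the overrides folded into one object, run falls back to the previous defaults (reranking on, query analysis on, the tool's own db_session) whenever override_kwargs is None, while callers such as the agent graph pass an explicit object. A hedged usage sketch based on the call sites in this commit; the query string is illustrative:

    for tool_response in search_tool.run(
        query="example user question",
        override_kwargs=SearchToolOverrideKwargs(
            force_no_rerank=True,
            alternate_db_session=None,
            retrieved_sections_callback=None,
            skip_query_analysis=True,
        ),
    ):
        if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
            ...  # consume the retrieved sections as in the graph node above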