Fix Agent Slowness (#3979)

This commit is contained in:
Yuhong Sun
2025-02-13 15:54:34 -08:00
committed by GitHub
parent c6434db7eb
commit 1a7aca06b9
14 changed files with 87 additions and 28 deletions

View File

@@ -83,6 +83,7 @@ def handle_search_request(
user=user,
llm=llm,
fast_llm=fast_llm,
skip_query_analysis=False,
db_session=db_session,
bypass_acl=False,
)

View File

@@ -23,6 +23,7 @@ from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -67,9 +68,12 @@ def retrieve_documents(
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=not state.base_search,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:

View File

@@ -58,6 +58,7 @@ from onyx.prompts.agent_search import (
)
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
@@ -218,7 +219,10 @@ def get_test_config(
using_tool_calling_llm=using_tool_calling_llm,
)
chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID")
chat_session_id = (
os.environ.get("ONYX_AS_CHAT_SESSION_ID")
or "00000000-0000-0000-0000-000000000000"
)
assert (
chat_session_id is not None
), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
@@ -341,8 +345,12 @@ def retrieve_search_docs(
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=question,
force_no_rerank=True,
alternate_db_session=db_session,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=None,
skip_query_analysis=False,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:

View File

@@ -51,6 +51,7 @@ class SearchPipeline:
user: User | None,
llm: LLM,
fast_llm: LLM,
skip_query_analysis: bool,
db_session: Session,
bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION
retrieval_metrics_callback: (
@@ -67,6 +68,7 @@ class SearchPipeline:
self.user = user
self.llm = llm
self.fast_llm = fast_llm
self.skip_query_analysis = skip_query_analysis
self.db_session = db_session
self.bypass_acl = bypass_acl
self.retrieval_metrics_callback = retrieval_metrics_callback
@@ -108,6 +110,7 @@ class SearchPipeline:
search_request=self.search_request,
user=self.user,
llm=self.llm,
skip_query_analysis=self.skip_query_analysis,
db_session=self.db_session,
bypass_acl=self.bypass_acl,
)
@@ -162,6 +165,12 @@ class SearchPipeline:
that have a corresponding chunk.
This step should be fast for any document index implementation.
Current implementation timing is approximately broken down in timing as:
- 200 ms to get the embedding of the query
- 15 ms to get chunks from the document index
- possibly more to get additional surrounding chunks
- possibly more for query expansion (multilingual)
"""
if self._retrieved_sections is not None:
return self._retrieved_sections

View File

@@ -50,11 +50,11 @@ def retrieval_preprocessing(
search_request: SearchRequest,
user: User | None,
llm: LLM,
skip_query_analysis: bool,
db_session: Session,
bypass_acl: bool = False,
skip_query_analysis: bool = False,
base_recency_decay: float = BASE_RECENCY_DECAY,
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
base_recency_decay: float = BASE_RECENCY_DECAY,
bypass_acl: bool = False,
) -> SearchQuery:
"""Logic is as follows:
Any global disables apply first
@@ -146,7 +146,7 @@ def retrieval_preprocessing(
is_keyword, extracted_keywords = (
parallel_results[run_query_analysis.result_id]
if run_query_analysis
else (None, None)
else (False, None)
)
all_query_terms = query.split()

View File

@@ -99,7 +99,7 @@ def _check_tokenizer_cache(
if not tokenizer:
logger.info(
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
f"Falling back to default embedding model tokenizer: {DOCUMENT_ENCODER_MODEL}"
)
tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)

View File

@@ -76,6 +76,7 @@ def gpt_search(
user=None,
llm=llm,
fast_llm=fast_llm,
skip_query_analysis=True,
db_session=db_session,
).reranked_sections

View File

@@ -34,7 +34,7 @@ Now respond to the following:
""".strip()
class BaseTool(Tool):
class BaseTool(Tool[None]):
def build_next_prompt(
self,
prompt_builder: "AnswerPromptBuilder",

View File

@@ -1,11 +1,14 @@
from collections.abc import Callable
from typing import Any
from uuid import UUID
from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session
from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceSection
class ToolResponse(BaseModel):
@@ -57,5 +60,15 @@ class SearchQueryInfo(BaseModel):
recency_bias_multiplier: float
class SearchToolOverrideKwargs(BaseModel):
force_no_rerank: bool
alternate_db_session: Session | None
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
skip_query_analysis: bool
class Config:
arbitrary_types_allowed = True
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"

View File

@@ -1,7 +1,9 @@
import abc
from collections.abc import Generator
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
@@ -14,7 +16,10 @@ if TYPE_CHECKING:
from onyx.tools.models import ToolResponse
class Tool(abc.ABC):
OVERRIDE_T = TypeVar("OVERRIDE_T")
class Tool(abc.ABC, Generic[OVERRIDE_T]):
@property
@abc.abstractmethod
def name(self) -> str:
@@ -57,7 +62,9 @@ class Tool(abc.ABC):
"""Actual execution of the tool"""
@abc.abstractmethod
def run(self, **kwargs: Any) -> Generator["ToolResponse", None, None]:
def run(
self, override_kwargs: OVERRIDE_T | None = None, **llm_kwargs: Any
) -> Generator["ToolResponse", None, None]:
raise NotImplementedError
@abc.abstractmethod

View File

@@ -74,6 +74,7 @@ class CustomToolCallSummary(BaseModel):
tool_result: Any # The response data
# override_kwargs is not supported for custom tools
class CustomTool(BaseTool):
def __init__(
self,
@@ -235,7 +236,9 @@ class CustomTool(BaseTool):
"""Actual execution of the tool"""
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
def run(
self, override_kwargs: dict[str, Any] | None = None, **kwargs: Any
) -> Generator[ToolResponse, None, None]:
request_body = kwargs.get(REQUEST_BODY)
path_params = {}

View File

@@ -79,7 +79,8 @@ class ImageShape(str, Enum):
LANDSCAPE = "landscape"
class ImageGenerationTool(Tool):
# override_kwargs is not supported for image generation tools
class ImageGenerationTool(Tool[None]):
_NAME = "run_image_generation"
_DESCRIPTION = "Generate an image from a prompt."
_DISPLAY_NAME = "Image Generation"
@@ -255,7 +256,9 @@ class ImageGenerationTool(Tool):
"An error occurred during image generation. Please try again later."
)
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
def run(
self, override_kwargs: None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
prompt = cast(str, kwargs["prompt"])
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
format = self.output_format

View File

@@ -106,7 +106,8 @@ def internet_search_response_to_search_docs(
]
class InternetSearchTool(Tool):
# override_kwargs is not supported for internet search tools
class InternetSearchTool(Tool[None]):
_NAME = "run_internet_search"
_DISPLAY_NAME = "Internet Search"
_DESCRIPTION = "Perform an internet search for up-to-date information."
@@ -242,7 +243,9 @@ class InternetSearchTool(Tool):
],
)
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
def run(
self, override_kwargs: None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
query = cast(str, kwargs["internet_search_query"])
results = self._perform_search(query)

View File

@@ -39,6 +39,7 @@ from onyx.secondary_llm_flows.choose_search import check_if_need_search
from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
@@ -77,7 +78,7 @@ HINT: if you are unfamiliar with the user input OR think the user input is a typ
"""
class SearchTool(Tool):
class SearchTool(Tool[SearchToolOverrideKwargs]):
_NAME = "run_search"
_DISPLAY_NAME = "Search Tool"
_DESCRIPTION = SEARCH_TOOL_DESCRIPTION
@@ -275,14 +276,19 @@ class SearchTool(Tool):
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
query = cast(str, kwargs["query"])
force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False))
alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None))
retrieved_sections_callback = cast(
Callable[[list[InferenceSection]], None],
kwargs.get("retrieved_sections_callback"),
)
def run(
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
) -> Generator[ToolResponse, None, None]:
query = cast(str, llm_kwargs["query"])
force_no_rerank = False
alternate_db_session = None
retrieved_sections_callback = None
skip_query_analysis = False
if override_kwargs:
force_no_rerank = override_kwargs.force_no_rerank
alternate_db_session = override_kwargs.alternate_db_session
retrieved_sections_callback = override_kwargs.retrieved_sections_callback
skip_query_analysis = override_kwargs.skip_query_analysis
if self.selected_sections:
yield from self._build_response_for_specified_sections(query)
@@ -324,6 +330,7 @@ class SearchTool(Tool):
user=self.user,
llm=self.llm,
fast_llm=self.fast_llm,
skip_query_analysis=skip_query_analysis,
bypass_acl=self.bypass_acl,
db_session=alternate_db_session or self.db_session,
prompt_config=self.prompt_config,