mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-18 19:43:26 +02:00
Fix Agent Slowness (#3979)
This commit is contained in:
@@ -83,6 +83,7 @@ def handle_search_request(
|
||||
user=user,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
skip_query_analysis=False,
|
||||
db_session=db_session,
|
||||
bypass_acl=False,
|
||||
)
|
||||
|
@@ -23,6 +23,7 @@ from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
@@ -67,9 +68,12 @@ def retrieve_documents(
|
||||
with get_session_context_manager() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=query_to_retrieve,
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
retrieved_sections_callback=callback_container.append,
|
||||
override_kwargs=SearchToolOverrideKwargs(
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
retrieved_sections_callback=callback_container.append,
|
||||
skip_query_analysis=not state.base_search,
|
||||
),
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
|
@@ -58,6 +58,7 @@ from onyx.prompts.agent_search import (
|
||||
)
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
@@ -218,7 +219,10 @@ def get_test_config(
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
)
|
||||
|
||||
chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID")
|
||||
chat_session_id = (
|
||||
os.environ.get("ONYX_AS_CHAT_SESSION_ID")
|
||||
or "00000000-0000-0000-0000-000000000000"
|
||||
)
|
||||
assert (
|
||||
chat_session_id is not None
|
||||
), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
|
||||
@@ -341,8 +345,12 @@ def retrieve_search_docs(
|
||||
with get_session_context_manager() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=question,
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
override_kwargs=SearchToolOverrideKwargs(
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
retrieved_sections_callback=None,
|
||||
skip_query_analysis=False,
|
||||
),
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
|
@@ -51,6 +51,7 @@ class SearchPipeline:
|
||||
user: User | None,
|
||||
llm: LLM,
|
||||
fast_llm: LLM,
|
||||
skip_query_analysis: bool,
|
||||
db_session: Session,
|
||||
bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION
|
||||
retrieval_metrics_callback: (
|
||||
@@ -67,6 +68,7 @@ class SearchPipeline:
|
||||
self.user = user
|
||||
self.llm = llm
|
||||
self.fast_llm = fast_llm
|
||||
self.skip_query_analysis = skip_query_analysis
|
||||
self.db_session = db_session
|
||||
self.bypass_acl = bypass_acl
|
||||
self.retrieval_metrics_callback = retrieval_metrics_callback
|
||||
@@ -108,6 +110,7 @@ class SearchPipeline:
|
||||
search_request=self.search_request,
|
||||
user=self.user,
|
||||
llm=self.llm,
|
||||
skip_query_analysis=self.skip_query_analysis,
|
||||
db_session=self.db_session,
|
||||
bypass_acl=self.bypass_acl,
|
||||
)
|
||||
@@ -162,6 +165,12 @@ class SearchPipeline:
|
||||
that have a corresponding chunk.
|
||||
|
||||
This step should be fast for any document index implementation.
|
||||
|
||||
Current implementation timing is approximately broken down in timing as:
|
||||
- 200 ms to get the embedding of the query
|
||||
- 15 ms to get chunks from the document index
|
||||
- possibly more to get additional surrounding chunks
|
||||
- possibly more for query expansion (multilingual)
|
||||
"""
|
||||
if self._retrieved_sections is not None:
|
||||
return self._retrieved_sections
|
||||
|
@@ -50,11 +50,11 @@ def retrieval_preprocessing(
|
||||
search_request: SearchRequest,
|
||||
user: User | None,
|
||||
llm: LLM,
|
||||
skip_query_analysis: bool,
|
||||
db_session: Session,
|
||||
bypass_acl: bool = False,
|
||||
skip_query_analysis: bool = False,
|
||||
base_recency_decay: float = BASE_RECENCY_DECAY,
|
||||
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
|
||||
base_recency_decay: float = BASE_RECENCY_DECAY,
|
||||
bypass_acl: bool = False,
|
||||
) -> SearchQuery:
|
||||
"""Logic is as follows:
|
||||
Any global disables apply first
|
||||
@@ -146,7 +146,7 @@ def retrieval_preprocessing(
|
||||
is_keyword, extracted_keywords = (
|
||||
parallel_results[run_query_analysis.result_id]
|
||||
if run_query_analysis
|
||||
else (None, None)
|
||||
else (False, None)
|
||||
)
|
||||
|
||||
all_query_terms = query.split()
|
||||
|
@@ -99,7 +99,7 @@ def _check_tokenizer_cache(
|
||||
|
||||
if not tokenizer:
|
||||
logger.info(
|
||||
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
|
||||
f"Falling back to default embedding model tokenizer: {DOCUMENT_ENCODER_MODEL}"
|
||||
)
|
||||
tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
|
||||
|
||||
|
@@ -76,6 +76,7 @@ def gpt_search(
|
||||
user=None,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
skip_query_analysis=True,
|
||||
db_session=db_session,
|
||||
).reranked_sections
|
||||
|
||||
|
@@ -34,7 +34,7 @@ Now respond to the following:
|
||||
""".strip()
|
||||
|
||||
|
||||
class BaseTool(Tool):
|
||||
class BaseTool(Tool[None]):
|
||||
def build_next_prompt(
|
||||
self,
|
||||
prompt_builder: "AnswerPromptBuilder",
|
||||
|
@@ -1,11 +1,14 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
@@ -57,5 +60,15 @@ class SearchQueryInfo(BaseModel):
|
||||
recency_bias_multiplier: float
|
||||
|
||||
|
||||
class SearchToolOverrideKwargs(BaseModel):
|
||||
force_no_rerank: bool
|
||||
alternate_db_session: Session | None
|
||||
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
|
||||
skip_query_analysis: bool
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
|
||||
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"
|
||||
|
@@ -1,7 +1,9 @@
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
@@ -14,7 +16,10 @@ if TYPE_CHECKING:
|
||||
from onyx.tools.models import ToolResponse
|
||||
|
||||
|
||||
class Tool(abc.ABC):
|
||||
OVERRIDE_T = TypeVar("OVERRIDE_T")
|
||||
|
||||
|
||||
class Tool(abc.ABC, Generic[OVERRIDE_T]):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
@@ -57,7 +62,9 @@ class Tool(abc.ABC):
|
||||
"""Actual execution of the tool"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self, **kwargs: Any) -> Generator["ToolResponse", None, None]:
|
||||
def run(
|
||||
self, override_kwargs: OVERRIDE_T | None = None, **llm_kwargs: Any
|
||||
) -> Generator["ToolResponse", None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
|
@@ -74,6 +74,7 @@ class CustomToolCallSummary(BaseModel):
|
||||
tool_result: Any # The response data
|
||||
|
||||
|
||||
# override_kwargs is not supported for custom tools
|
||||
class CustomTool(BaseTool):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -235,7 +236,9 @@ class CustomTool(BaseTool):
|
||||
|
||||
"""Actual execution of the tool"""
|
||||
|
||||
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
|
||||
def run(
|
||||
self, override_kwargs: dict[str, Any] | None = None, **kwargs: Any
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
request_body = kwargs.get(REQUEST_BODY)
|
||||
|
||||
path_params = {}
|
||||
|
@@ -79,7 +79,8 @@ class ImageShape(str, Enum):
|
||||
LANDSCAPE = "landscape"
|
||||
|
||||
|
||||
class ImageGenerationTool(Tool):
|
||||
# override_kwargs is not supported for image generation tools
|
||||
class ImageGenerationTool(Tool[None]):
|
||||
_NAME = "run_image_generation"
|
||||
_DESCRIPTION = "Generate an image from a prompt."
|
||||
_DISPLAY_NAME = "Image Generation"
|
||||
@@ -255,7 +256,9 @@ class ImageGenerationTool(Tool):
|
||||
"An error occurred during image generation. Please try again later."
|
||||
)
|
||||
|
||||
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
|
||||
def run(
|
||||
self, override_kwargs: None = None, **kwargs: str
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
prompt = cast(str, kwargs["prompt"])
|
||||
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
|
||||
format = self.output_format
|
||||
|
@@ -106,7 +106,8 @@ def internet_search_response_to_search_docs(
|
||||
]
|
||||
|
||||
|
||||
class InternetSearchTool(Tool):
|
||||
# override_kwargs is not supported for internet search tools
|
||||
class InternetSearchTool(Tool[None]):
|
||||
_NAME = "run_internet_search"
|
||||
_DISPLAY_NAME = "Internet Search"
|
||||
_DESCRIPTION = "Perform an internet search for up-to-date information."
|
||||
@@ -242,7 +243,9 @@ class InternetSearchTool(Tool):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
|
||||
def run(
|
||||
self, override_kwargs: None = None, **kwargs: str
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
query = cast(str, kwargs["internet_search_query"])
|
||||
|
||||
results = self._perform_search(query)
|
||||
|
@@ -39,6 +39,7 @@ from onyx.secondary_llm_flows.choose_search import check_if_need_search
|
||||
from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
|
||||
@@ -77,7 +78,7 @@ HINT: if you are unfamiliar with the user input OR think the user input is a typ
|
||||
"""
|
||||
|
||||
|
||||
class SearchTool(Tool):
|
||||
class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
_NAME = "run_search"
|
||||
_DISPLAY_NAME = "Search Tool"
|
||||
_DESCRIPTION = SEARCH_TOOL_DESCRIPTION
|
||||
@@ -275,14 +276,19 @@ class SearchTool(Tool):
|
||||
|
||||
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
|
||||
|
||||
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
|
||||
query = cast(str, kwargs["query"])
|
||||
force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False))
|
||||
alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None))
|
||||
retrieved_sections_callback = cast(
|
||||
Callable[[list[InferenceSection]], None],
|
||||
kwargs.get("retrieved_sections_callback"),
|
||||
)
|
||||
def run(
|
||||
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
query = cast(str, llm_kwargs["query"])
|
||||
force_no_rerank = False
|
||||
alternate_db_session = None
|
||||
retrieved_sections_callback = None
|
||||
skip_query_analysis = False
|
||||
if override_kwargs:
|
||||
force_no_rerank = override_kwargs.force_no_rerank
|
||||
alternate_db_session = override_kwargs.alternate_db_session
|
||||
retrieved_sections_callback = override_kwargs.retrieved_sections_callback
|
||||
skip_query_analysis = override_kwargs.skip_query_analysis
|
||||
|
||||
if self.selected_sections:
|
||||
yield from self._build_response_for_specified_sections(query)
|
||||
@@ -324,6 +330,7 @@ class SearchTool(Tool):
|
||||
user=self.user,
|
||||
llm=self.llm,
|
||||
fast_llm=self.fast_llm,
|
||||
skip_query_analysis=skip_query_analysis,
|
||||
bypass_acl=self.bypass_acl,
|
||||
db_session=alternate_db_session or self.db_session,
|
||||
prompt_config=self.prompt_config,
|
||||
|
Reference in New Issue
Block a user