mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-10-10 13:15:18 +02:00
Fix Agent Slowness (#3979)
This commit is contained in:
@@ -83,6 +83,7 @@ def handle_search_request(
|
|||||||
user=user,
|
user=user,
|
||||||
llm=llm,
|
llm=llm,
|
||||||
fast_llm=fast_llm,
|
fast_llm=fast_llm,
|
||||||
|
skip_query_analysis=False,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
bypass_acl=False,
|
bypass_acl=False,
|
||||||
)
|
)
|
||||||
|
@@ -23,6 +23,7 @@ from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
|
|||||||
from onyx.context.search.models import InferenceSection
|
from onyx.context.search.models import InferenceSection
|
||||||
from onyx.db.engine import get_session_context_manager
|
from onyx.db.engine import get_session_context_manager
|
||||||
from onyx.tools.models import SearchQueryInfo
|
from onyx.tools.models import SearchQueryInfo
|
||||||
|
from onyx.tools.models import SearchToolOverrideKwargs
|
||||||
from onyx.tools.tool_implementations.search.search_tool import (
|
from onyx.tools.tool_implementations.search.search_tool import (
|
||||||
SEARCH_RESPONSE_SUMMARY_ID,
|
SEARCH_RESPONSE_SUMMARY_ID,
|
||||||
)
|
)
|
||||||
@@ -67,9 +68,12 @@ def retrieve_documents(
|
|||||||
with get_session_context_manager() as db_session:
|
with get_session_context_manager() as db_session:
|
||||||
for tool_response in search_tool.run(
|
for tool_response in search_tool.run(
|
||||||
query=query_to_retrieve,
|
query=query_to_retrieve,
|
||||||
force_no_rerank=True,
|
override_kwargs=SearchToolOverrideKwargs(
|
||||||
alternate_db_session=db_session,
|
force_no_rerank=True,
|
||||||
retrieved_sections_callback=callback_container.append,
|
alternate_db_session=db_session,
|
||||||
|
retrieved_sections_callback=callback_container.append,
|
||||||
|
skip_query_analysis=not state.base_search,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
# get retrieved docs to send to the rest of the graph
|
# get retrieved docs to send to the rest of the graph
|
||||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||||
|
@@ -58,6 +58,7 @@ from onyx.prompts.agent_search import (
|
|||||||
)
|
)
|
||||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||||
from onyx.tools.force import ForceUseTool
|
from onyx.tools.force import ForceUseTool
|
||||||
|
from onyx.tools.models import SearchToolOverrideKwargs
|
||||||
from onyx.tools.tool_constructor import SearchToolConfig
|
from onyx.tools.tool_constructor import SearchToolConfig
|
||||||
from onyx.tools.tool_implementations.search.search_tool import (
|
from onyx.tools.tool_implementations.search.search_tool import (
|
||||||
SEARCH_RESPONSE_SUMMARY_ID,
|
SEARCH_RESPONSE_SUMMARY_ID,
|
||||||
@@ -218,7 +219,10 @@ def get_test_config(
|
|||||||
using_tool_calling_llm=using_tool_calling_llm,
|
using_tool_calling_llm=using_tool_calling_llm,
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID")
|
chat_session_id = (
|
||||||
|
os.environ.get("ONYX_AS_CHAT_SESSION_ID")
|
||||||
|
or "00000000-0000-0000-0000-000000000000"
|
||||||
|
)
|
||||||
assert (
|
assert (
|
||||||
chat_session_id is not None
|
chat_session_id is not None
|
||||||
), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
|
), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
|
||||||
@@ -341,8 +345,12 @@ def retrieve_search_docs(
|
|||||||
with get_session_context_manager() as db_session:
|
with get_session_context_manager() as db_session:
|
||||||
for tool_response in search_tool.run(
|
for tool_response in search_tool.run(
|
||||||
query=question,
|
query=question,
|
||||||
force_no_rerank=True,
|
override_kwargs=SearchToolOverrideKwargs(
|
||||||
alternate_db_session=db_session,
|
force_no_rerank=True,
|
||||||
|
alternate_db_session=db_session,
|
||||||
|
retrieved_sections_callback=None,
|
||||||
|
skip_query_analysis=False,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
# get retrieved docs to send to the rest of the graph
|
# get retrieved docs to send to the rest of the graph
|
||||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||||
|
@@ -51,6 +51,7 @@ class SearchPipeline:
|
|||||||
user: User | None,
|
user: User | None,
|
||||||
llm: LLM,
|
llm: LLM,
|
||||||
fast_llm: LLM,
|
fast_llm: LLM,
|
||||||
|
skip_query_analysis: bool,
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION
|
bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION
|
||||||
retrieval_metrics_callback: (
|
retrieval_metrics_callback: (
|
||||||
@@ -67,6 +68,7 @@ class SearchPipeline:
|
|||||||
self.user = user
|
self.user = user
|
||||||
self.llm = llm
|
self.llm = llm
|
||||||
self.fast_llm = fast_llm
|
self.fast_llm = fast_llm
|
||||||
|
self.skip_query_analysis = skip_query_analysis
|
||||||
self.db_session = db_session
|
self.db_session = db_session
|
||||||
self.bypass_acl = bypass_acl
|
self.bypass_acl = bypass_acl
|
||||||
self.retrieval_metrics_callback = retrieval_metrics_callback
|
self.retrieval_metrics_callback = retrieval_metrics_callback
|
||||||
@@ -108,6 +110,7 @@ class SearchPipeline:
|
|||||||
search_request=self.search_request,
|
search_request=self.search_request,
|
||||||
user=self.user,
|
user=self.user,
|
||||||
llm=self.llm,
|
llm=self.llm,
|
||||||
|
skip_query_analysis=self.skip_query_analysis,
|
||||||
db_session=self.db_session,
|
db_session=self.db_session,
|
||||||
bypass_acl=self.bypass_acl,
|
bypass_acl=self.bypass_acl,
|
||||||
)
|
)
|
||||||
@@ -162,6 +165,12 @@ class SearchPipeline:
|
|||||||
that have a corresponding chunk.
|
that have a corresponding chunk.
|
||||||
|
|
||||||
This step should be fast for any document index implementation.
|
This step should be fast for any document index implementation.
|
||||||
|
|
||||||
|
Current implementation timing is approximately broken down in timing as:
|
||||||
|
- 200 ms to get the embedding of the query
|
||||||
|
- 15 ms to get chunks from the document index
|
||||||
|
- possibly more to get additional surrounding chunks
|
||||||
|
- possibly more for query expansion (multilingual)
|
||||||
"""
|
"""
|
||||||
if self._retrieved_sections is not None:
|
if self._retrieved_sections is not None:
|
||||||
return self._retrieved_sections
|
return self._retrieved_sections
|
||||||
|
@@ -50,11 +50,11 @@ def retrieval_preprocessing(
|
|||||||
search_request: SearchRequest,
|
search_request: SearchRequest,
|
||||||
user: User | None,
|
user: User | None,
|
||||||
llm: LLM,
|
llm: LLM,
|
||||||
|
skip_query_analysis: bool,
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
bypass_acl: bool = False,
|
|
||||||
skip_query_analysis: bool = False,
|
|
||||||
base_recency_decay: float = BASE_RECENCY_DECAY,
|
|
||||||
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
|
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
|
||||||
|
base_recency_decay: float = BASE_RECENCY_DECAY,
|
||||||
|
bypass_acl: bool = False,
|
||||||
) -> SearchQuery:
|
) -> SearchQuery:
|
||||||
"""Logic is as follows:
|
"""Logic is as follows:
|
||||||
Any global disables apply first
|
Any global disables apply first
|
||||||
@@ -146,7 +146,7 @@ def retrieval_preprocessing(
|
|||||||
is_keyword, extracted_keywords = (
|
is_keyword, extracted_keywords = (
|
||||||
parallel_results[run_query_analysis.result_id]
|
parallel_results[run_query_analysis.result_id]
|
||||||
if run_query_analysis
|
if run_query_analysis
|
||||||
else (None, None)
|
else (False, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
all_query_terms = query.split()
|
all_query_terms = query.split()
|
||||||
|
@@ -99,7 +99,7 @@ def _check_tokenizer_cache(
|
|||||||
|
|
||||||
if not tokenizer:
|
if not tokenizer:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
|
f"Falling back to default embedding model tokenizer: {DOCUMENT_ENCODER_MODEL}"
|
||||||
)
|
)
|
||||||
tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
|
tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
|
||||||
|
|
||||||
|
@@ -76,6 +76,7 @@ def gpt_search(
|
|||||||
user=None,
|
user=None,
|
||||||
llm=llm,
|
llm=llm,
|
||||||
fast_llm=fast_llm,
|
fast_llm=fast_llm,
|
||||||
|
skip_query_analysis=True,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
).reranked_sections
|
).reranked_sections
|
||||||
|
|
||||||
|
@@ -34,7 +34,7 @@ Now respond to the following:
|
|||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
class BaseTool(Tool):
|
class BaseTool(Tool[None]):
|
||||||
def build_next_prompt(
|
def build_next_prompt(
|
||||||
self,
|
self,
|
||||||
prompt_builder: "AnswerPromptBuilder",
|
prompt_builder: "AnswerPromptBuilder",
|
||||||
|
@@ -1,11 +1,14 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic import model_validator
|
from pydantic import model_validator
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from onyx.context.search.enums import SearchType
|
from onyx.context.search.enums import SearchType
|
||||||
from onyx.context.search.models import IndexFilters
|
from onyx.context.search.models import IndexFilters
|
||||||
|
from onyx.context.search.models import InferenceSection
|
||||||
|
|
||||||
|
|
||||||
class ToolResponse(BaseModel):
|
class ToolResponse(BaseModel):
|
||||||
@@ -57,5 +60,15 @@ class SearchQueryInfo(BaseModel):
|
|||||||
recency_bias_multiplier: float
|
recency_bias_multiplier: float
|
||||||
|
|
||||||
|
|
||||||
|
class SearchToolOverrideKwargs(BaseModel):
|
||||||
|
force_no_rerank: bool
|
||||||
|
alternate_db_session: Session | None
|
||||||
|
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
|
||||||
|
skip_query_analysis: bool
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
|
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
|
||||||
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"
|
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"
|
||||||
|
@@ -1,7 +1,9 @@
|
|||||||
import abc
|
import abc
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from typing import Generic
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
from onyx.llm.interfaces import LLM
|
from onyx.llm.interfaces import LLM
|
||||||
from onyx.llm.models import PreviousMessage
|
from onyx.llm.models import PreviousMessage
|
||||||
@@ -14,7 +16,10 @@ if TYPE_CHECKING:
|
|||||||
from onyx.tools.models import ToolResponse
|
from onyx.tools.models import ToolResponse
|
||||||
|
|
||||||
|
|
||||||
class Tool(abc.ABC):
|
OVERRIDE_T = TypeVar("OVERRIDE_T")
|
||||||
|
|
||||||
|
|
||||||
|
class Tool(abc.ABC, Generic[OVERRIDE_T]):
|
||||||
@property
|
@property
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@@ -57,7 +62,9 @@ class Tool(abc.ABC):
|
|||||||
"""Actual execution of the tool"""
|
"""Actual execution of the tool"""
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def run(self, **kwargs: Any) -> Generator["ToolResponse", None, None]:
|
def run(
|
||||||
|
self, override_kwargs: OVERRIDE_T | None = None, **llm_kwargs: Any
|
||||||
|
) -> Generator["ToolResponse", None, None]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
@@ -74,6 +74,7 @@ class CustomToolCallSummary(BaseModel):
|
|||||||
tool_result: Any # The response data
|
tool_result: Any # The response data
|
||||||
|
|
||||||
|
|
||||||
|
# override_kwargs is not supported for custom tools
|
||||||
class CustomTool(BaseTool):
|
class CustomTool(BaseTool):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -235,7 +236,9 @@ class CustomTool(BaseTool):
|
|||||||
|
|
||||||
"""Actual execution of the tool"""
|
"""Actual execution of the tool"""
|
||||||
|
|
||||||
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
|
def run(
|
||||||
|
self, override_kwargs: dict[str, Any] | None = None, **kwargs: Any
|
||||||
|
) -> Generator[ToolResponse, None, None]:
|
||||||
request_body = kwargs.get(REQUEST_BODY)
|
request_body = kwargs.get(REQUEST_BODY)
|
||||||
|
|
||||||
path_params = {}
|
path_params = {}
|
||||||
|
@@ -79,7 +79,8 @@ class ImageShape(str, Enum):
|
|||||||
LANDSCAPE = "landscape"
|
LANDSCAPE = "landscape"
|
||||||
|
|
||||||
|
|
||||||
class ImageGenerationTool(Tool):
|
# override_kwargs is not supported for image generation tools
|
||||||
|
class ImageGenerationTool(Tool[None]):
|
||||||
_NAME = "run_image_generation"
|
_NAME = "run_image_generation"
|
||||||
_DESCRIPTION = "Generate an image from a prompt."
|
_DESCRIPTION = "Generate an image from a prompt."
|
||||||
_DISPLAY_NAME = "Image Generation"
|
_DISPLAY_NAME = "Image Generation"
|
||||||
@@ -255,7 +256,9 @@ class ImageGenerationTool(Tool):
|
|||||||
"An error occurred during image generation. Please try again later."
|
"An error occurred during image generation. Please try again later."
|
||||||
)
|
)
|
||||||
|
|
||||||
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
|
def run(
|
||||||
|
self, override_kwargs: None = None, **kwargs: str
|
||||||
|
) -> Generator[ToolResponse, None, None]:
|
||||||
prompt = cast(str, kwargs["prompt"])
|
prompt = cast(str, kwargs["prompt"])
|
||||||
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
|
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
|
||||||
format = self.output_format
|
format = self.output_format
|
||||||
|
@@ -106,7 +106,8 @@ def internet_search_response_to_search_docs(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class InternetSearchTool(Tool):
|
# override_kwargs is not supported for internet search tools
|
||||||
|
class InternetSearchTool(Tool[None]):
|
||||||
_NAME = "run_internet_search"
|
_NAME = "run_internet_search"
|
||||||
_DISPLAY_NAME = "Internet Search"
|
_DISPLAY_NAME = "Internet Search"
|
||||||
_DESCRIPTION = "Perform an internet search for up-to-date information."
|
_DESCRIPTION = "Perform an internet search for up-to-date information."
|
||||||
@@ -242,7 +243,9 @@ class InternetSearchTool(Tool):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
|
def run(
|
||||||
|
self, override_kwargs: None = None, **kwargs: str
|
||||||
|
) -> Generator[ToolResponse, None, None]:
|
||||||
query = cast(str, kwargs["internet_search_query"])
|
query = cast(str, kwargs["internet_search_query"])
|
||||||
|
|
||||||
results = self._perform_search(query)
|
results = self._perform_search(query)
|
||||||
|
@@ -39,6 +39,7 @@ from onyx.secondary_llm_flows.choose_search import check_if_need_search
|
|||||||
from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase
|
from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase
|
||||||
from onyx.tools.message import ToolCallSummary
|
from onyx.tools.message import ToolCallSummary
|
||||||
from onyx.tools.models import SearchQueryInfo
|
from onyx.tools.models import SearchQueryInfo
|
||||||
|
from onyx.tools.models import SearchToolOverrideKwargs
|
||||||
from onyx.tools.models import ToolResponse
|
from onyx.tools.models import ToolResponse
|
||||||
from onyx.tools.tool import Tool
|
from onyx.tools.tool import Tool
|
||||||
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
|
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
|
||||||
@@ -77,7 +78,7 @@ HINT: if you are unfamiliar with the user input OR think the user input is a typ
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class SearchTool(Tool):
|
class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||||
_NAME = "run_search"
|
_NAME = "run_search"
|
||||||
_DISPLAY_NAME = "Search Tool"
|
_DISPLAY_NAME = "Search Tool"
|
||||||
_DESCRIPTION = SEARCH_TOOL_DESCRIPTION
|
_DESCRIPTION = SEARCH_TOOL_DESCRIPTION
|
||||||
@@ -275,14 +276,19 @@ class SearchTool(Tool):
|
|||||||
|
|
||||||
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
|
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
|
||||||
|
|
||||||
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
|
def run(
|
||||||
query = cast(str, kwargs["query"])
|
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
|
||||||
force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False))
|
) -> Generator[ToolResponse, None, None]:
|
||||||
alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None))
|
query = cast(str, llm_kwargs["query"])
|
||||||
retrieved_sections_callback = cast(
|
force_no_rerank = False
|
||||||
Callable[[list[InferenceSection]], None],
|
alternate_db_session = None
|
||||||
kwargs.get("retrieved_sections_callback"),
|
retrieved_sections_callback = None
|
||||||
)
|
skip_query_analysis = False
|
||||||
|
if override_kwargs:
|
||||||
|
force_no_rerank = override_kwargs.force_no_rerank
|
||||||
|
alternate_db_session = override_kwargs.alternate_db_session
|
||||||
|
retrieved_sections_callback = override_kwargs.retrieved_sections_callback
|
||||||
|
skip_query_analysis = override_kwargs.skip_query_analysis
|
||||||
|
|
||||||
if self.selected_sections:
|
if self.selected_sections:
|
||||||
yield from self._build_response_for_specified_sections(query)
|
yield from self._build_response_for_specified_sections(query)
|
||||||
@@ -324,6 +330,7 @@ class SearchTool(Tool):
|
|||||||
user=self.user,
|
user=self.user,
|
||||||
llm=self.llm,
|
llm=self.llm,
|
||||||
fast_llm=self.fast_llm,
|
fast_llm=self.fast_llm,
|
||||||
|
skip_query_analysis=skip_query_analysis,
|
||||||
bypass_acl=self.bypass_acl,
|
bypass_acl=self.bypass_acl,
|
||||||
db_session=alternate_db_session or self.db_session,
|
db_session=alternate_db_session or self.db_session,
|
||||||
prompt_config=self.prompt_config,
|
prompt_config=self.prompt_config,
|
||||||
|
Reference in New Issue
Block a user