no reranking if local model w/o GPU for Agent Search (#4011)

* no reranking if local model w/o GPU

* more efficient gpu status calling

* fix unit tests

---------

Co-authored-by: Evan Lohn <evan@danswer.ai>
This commit is contained in:
joachim-danswer 2025-02-17 06:13:24 -08:00 committed by GitHub
parent 9324f426c0
commit 86bd121806
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 99 additions and 12 deletions

View File

@ -55,6 +55,7 @@ def rerank_documents(
# Note that these are passed in values from the API and are overrides which are typically None
rerank_settings = graph_config.inputs.search_request.rerank_settings
allow_agent_reranking = graph_config.behavior.allow_agent_reranking
if rerank_settings is None:
with get_session_context_manager() as db_session:
@ -62,23 +63,31 @@ def rerank_documents(
if not search_settings.disable_rerank_for_streaming:
rerank_settings = RerankingDetails.from_db_model(search_settings)
# Initial default: no reranking. Will be overwritten below if reranking is warranted
reranked_documents = verified_documents
if should_rerank(rerank_settings) and len(verified_documents) > 0:
if len(verified_documents) > 1:
reranked_documents = rerank_sections(
query_str=question,
# if runnable, then rerank_settings is not None
rerank_settings=cast(RerankingDetails, rerank_settings),
sections_to_rerank=verified_documents,
)
if not allow_agent_reranking:
logger.info("Use of local rerank model without GPU, skipping reranking")
# No reranking, stay with verified_documents as default
else:
# Reranking is warranted, use the rerank_sections function
reranked_documents = rerank_sections(
query_str=question,
# if runnable, then rerank_settings is not None
rerank_settings=cast(RerankingDetails, rerank_settings),
sections_to_rerank=verified_documents,
)
else:
logger.warning(
f"{len(verified_documents)} verified document(s) found, skipping reranking"
)
reranked_documents = verified_documents
# No reranking, stay with verified_documents as default
else:
logger.warning("No reranking settings found, using unranked documents")
reranked_documents = verified_documents
# No reranking, stay with verified_documents as default
if AGENT_RERANKING_STATS:
fit_scores = get_fit_scores(verified_documents, reranked_documents)
else:

View File

@ -67,6 +67,7 @@ class GraphSearchConfig(BaseModel):
# Whether to allow creation of refinement questions (and entity extraction, etc.)
allow_refinement: bool = True
skip_gen_ai_answer_generation: bool = False
allow_agent_reranking: bool = False
class GraphConfig(BaseModel):

View File

@ -29,6 +29,7 @@ from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.gpu_utils import gpu_status_request
from onyx.utils.logger import setup_logger
logger = setup_logger()
@ -80,6 +81,14 @@ class Answer:
and not skip_explicit_tool_calling
)
rerank_settings = search_request.rerank_settings
using_cloud_reranking = (
rerank_settings is not None
and rerank_settings.rerank_provider_type is not None
)
allow_agent_reranking = gpu_status_request() or using_cloud_reranking
self.graph_inputs = GraphInputs(
search_request=search_request,
prompt_builder=prompt_builder,
@ -94,7 +103,6 @@ class Answer:
force_use_tool=force_use_tool,
using_tool_calling_llm=using_tool_calling_llm,
)
assert db_session, "db_session must be provided for agentic persistence"
self.graph_persistence = GraphPersistence(
db_session=db_session,
chat_session_id=chat_session_id,
@ -104,6 +112,7 @@ class Answer:
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
allow_refinement=True,
allow_agent_reranking=allow_agent_reranking,
)
self.graph_config = GraphConfig(
inputs=self.graph_inputs,

View File

@ -11,6 +11,7 @@ from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from langchain_core.messages import ToolCall
from langchain_core.messages import ToolCallChunk
from pytest_mock import MockerFixture
from sqlalchemy.orm import Session
from onyx.chat.answer import Answer
@ -25,6 +26,7 @@ from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import SearchRequest
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
@ -35,6 +37,7 @@ from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTEN
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from shared_configs.enums import RerankerProvider
from tests.unit.onyx.chat.conftest import DEFAULT_SEARCH_ARGS
from tests.unit.onyx.chat.conftest import QUERY
@ -44,6 +47,20 @@ def answer_instance(
mock_llm: LLM,
answer_style_config: AnswerStyleConfig,
prompt_config: PromptConfig,
mocker: MockerFixture,
) -> Answer:
mocker.patch(
"onyx.chat.answer.gpu_status_request",
return_value=True,
)
return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)
def _answer_fixture_impl(
mock_llm: LLM,
answer_style_config: AnswerStyleConfig,
prompt_config: PromptConfig,
rerank_settings: RerankingDetails | None = None,
) -> Answer:
return Answer(
prompt_builder=AnswerPromptBuilder(
@ -64,13 +81,13 @@ def answer_instance(
llm=mock_llm,
fast_llm=mock_llm,
force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
search_request=SearchRequest(query=QUERY),
search_request=SearchRequest(query=QUERY, rerank_settings=rerank_settings),
chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
current_agent_message_id=0,
)
def test_basic_answer(answer_instance: Answer) -> None:
def test_basic_answer(answer_instance: Answer, mocker: MockerFixture) -> None:
mock_llm = cast(Mock, answer_instance.graph_config.tooling.primary_llm)
mock_llm.stream.return_value = [
AIMessageChunk(content="This is a "),
@ -363,3 +380,49 @@ def test_is_cancelled(answer_instance: Answer) -> None:
# Verify LLM calls
mock_llm.stream.assert_called_once()
@pytest.mark.parametrize(
    "gpu_enabled,is_local_model",
    [
        (True, False),
        (False, True),
        (True, True),
        (False, False),
    ],
)
def test_no_slow_reranking(
    gpu_enabled: bool,
    is_local_model: bool,
    mock_llm: LLM,
    answer_style_config: AnswerStyleConfig,
    prompt_config: PromptConfig,
    mocker: MockerFixture,
) -> None:
    """Agent reranking is allowed iff a GPU is available or a cloud provider is set.

    `Answer.__init__` computes `allow_agent_reranking = gpu_status_request() or
    using_cloud_reranking`, so a local rerank model (no provider type) without a
    GPU must disable agent reranking, while a cloud provider enables it
    regardless of GPU status.
    """
    mocker.patch(
        "onyx.chat.answer.gpu_status_request",
        return_value=gpu_enabled,
    )
    # is_local_model=True models "no rerank settings at all" (local default);
    # otherwise configure a cloud (Cohere) provider, which should always
    # permit agent reranking.
    rerank_settings = (
        None
        if is_local_model
        else RerankingDetails(
            rerank_model_name="test_model",
            rerank_api_url="test_url",
            rerank_api_key="test_key",
            num_rerank=10,
            rerank_provider_type=RerankerProvider.COHERE,
        )
    )
    answer_instance = _answer_fixture_impl(
        mock_llm, answer_style_config, prompt_config, rerank_settings=rerank_settings
    )
    # The settings passed in must be carried through to the graph config.
    assert (
        answer_instance.graph_config.inputs.search_request.rerank_settings
        == rerank_settings
    )
    # BUGFIX: the original wrote `... == gpu_enabled or not is_local_model`,
    # which parses as `(allow == gpu_enabled) or (not is_local_model)` and is
    # therefore vacuously true for both cloud-provider cases. Parenthesize so
    # the expected value is actually asserted for all four parametrizations.
    assert answer_instance.graph_config.behavior.allow_agent_reranking == (
        gpu_enabled or not is_local_model
    )

View File

@ -36,7 +36,12 @@ def test_skip_gen_ai_answer_generation_flag(
mock_search_tool: SearchTool,
answer_style_config: AnswerStyleConfig,
prompt_config: PromptConfig,
mocker: MockerFixture,
) -> None:
mocker.patch(
"onyx.chat.answer.gpu_status_request",
return_value=True,
)
question = config["question"]
skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"]