mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-02 08:58:11 +02:00
no reranking if local model w/o GPU for Agent Search (#4011)
* no reranking if local model w/o GPU * more efficient gpu status calling * fix unit tests --------- Co-authored-by: Evan Lohn <evan@danswer.ai>
This commit is contained in:
parent
9324f426c0
commit
86bd121806
@ -55,6 +55,7 @@ def rerank_documents(
|
||||
|
||||
# Note that these are passed in values from the API and are overrides which are typically None
|
||||
rerank_settings = graph_config.inputs.search_request.rerank_settings
|
||||
allow_agent_reranking = graph_config.behavior.allow_agent_reranking
|
||||
|
||||
if rerank_settings is None:
|
||||
with get_session_context_manager() as db_session:
|
||||
@ -62,23 +63,31 @@ def rerank_documents(
|
||||
if not search_settings.disable_rerank_for_streaming:
|
||||
rerank_settings = RerankingDetails.from_db_model(search_settings)
|
||||
|
||||
# Initial default: no reranking. Will be overwritten below if reranking is warranted
|
||||
reranked_documents = verified_documents
|
||||
|
||||
if should_rerank(rerank_settings) and len(verified_documents) > 0:
|
||||
if len(verified_documents) > 1:
|
||||
reranked_documents = rerank_sections(
|
||||
query_str=question,
|
||||
# if runnable, then rerank_settings is not None
|
||||
rerank_settings=cast(RerankingDetails, rerank_settings),
|
||||
sections_to_rerank=verified_documents,
|
||||
)
|
||||
if not allow_agent_reranking:
|
||||
logger.info("Use of local rerank model without GPU, skipping reranking")
|
||||
# No reranking, stay with verified_documents as default
|
||||
|
||||
else:
|
||||
# Reranking is warranted, use the rerank_sections function
|
||||
reranked_documents = rerank_sections(
|
||||
query_str=question,
|
||||
# if runnable, then rerank_settings is not None
|
||||
rerank_settings=cast(RerankingDetails, rerank_settings),
|
||||
sections_to_rerank=verified_documents,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"{len(verified_documents)} verified document(s) found, skipping reranking"
|
||||
)
|
||||
reranked_documents = verified_documents
|
||||
# No reranking, stay with verified_documents as default
|
||||
else:
|
||||
logger.warning("No reranking settings found, using unranked documents")
|
||||
reranked_documents = verified_documents
|
||||
|
||||
# No reranking, stay with verified_documents as default
|
||||
if AGENT_RERANKING_STATS:
|
||||
fit_scores = get_fit_scores(verified_documents, reranked_documents)
|
||||
else:
|
||||
|
@ -67,6 +67,7 @@ class GraphSearchConfig(BaseModel):
|
||||
# Whether to allow creation of refinement questions (and entity extraction, etc.)
|
||||
allow_refinement: bool = True
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
allow_agent_reranking: bool = False
|
||||
|
||||
|
||||
class GraphConfig(BaseModel):
|
||||
|
@ -29,6 +29,7 @@ from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.gpu_utils import gpu_status_request
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@ -80,6 +81,14 @@ class Answer:
|
||||
and not skip_explicit_tool_calling
|
||||
)
|
||||
|
||||
rerank_settings = search_request.rerank_settings
|
||||
|
||||
using_cloud_reranking = (
|
||||
rerank_settings is not None
|
||||
and rerank_settings.rerank_provider_type is not None
|
||||
)
|
||||
allow_agent_reranking = gpu_status_request() or using_cloud_reranking
|
||||
|
||||
self.graph_inputs = GraphInputs(
|
||||
search_request=search_request,
|
||||
prompt_builder=prompt_builder,
|
||||
@ -94,7 +103,6 @@ class Answer:
|
||||
force_use_tool=force_use_tool,
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
)
|
||||
assert db_session, "db_session must be provided for agentic persistence"
|
||||
self.graph_persistence = GraphPersistence(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
@ -104,6 +112,7 @@ class Answer:
|
||||
use_agentic_search=use_agentic_search,
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
allow_refinement=True,
|
||||
allow_agent_reranking=allow_agent_reranking,
|
||||
)
|
||||
self.graph_config = GraphConfig(
|
||||
inputs=self.graph_inputs,
|
||||
|
@ -11,6 +11,7 @@ from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.messages import ToolCallChunk
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.answer import Answer
|
||||
@ -25,6 +26,7 @@ from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
@ -35,6 +37,7 @@ from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTEN
|
||||
from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from tests.unit.onyx.chat.conftest import DEFAULT_SEARCH_ARGS
|
||||
from tests.unit.onyx.chat.conftest import QUERY
|
||||
|
||||
@ -44,6 +47,20 @@ def answer_instance(
|
||||
mock_llm: LLM,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
prompt_config: PromptConfig,
|
||||
mocker: MockerFixture,
|
||||
) -> Answer:
|
||||
mocker.patch(
|
||||
"onyx.chat.answer.gpu_status_request",
|
||||
return_value=True,
|
||||
)
|
||||
return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)
|
||||
|
||||
|
||||
def _answer_fixture_impl(
|
||||
mock_llm: LLM,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
prompt_config: PromptConfig,
|
||||
rerank_settings: RerankingDetails | None = None,
|
||||
) -> Answer:
|
||||
return Answer(
|
||||
prompt_builder=AnswerPromptBuilder(
|
||||
@ -64,13 +81,13 @@ def answer_instance(
|
||||
llm=mock_llm,
|
||||
fast_llm=mock_llm,
|
||||
force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
|
||||
search_request=SearchRequest(query=QUERY),
|
||||
search_request=SearchRequest(query=QUERY, rerank_settings=rerank_settings),
|
||||
chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
|
||||
current_agent_message_id=0,
|
||||
)
|
||||
|
||||
|
||||
def test_basic_answer(answer_instance: Answer) -> None:
|
||||
def test_basic_answer(answer_instance: Answer, mocker: MockerFixture) -> None:
|
||||
mock_llm = cast(Mock, answer_instance.graph_config.tooling.primary_llm)
|
||||
mock_llm.stream.return_value = [
|
||||
AIMessageChunk(content="This is a "),
|
||||
@ -363,3 +380,49 @@ def test_is_cancelled(answer_instance: Answer) -> None:
|
||||
|
||||
# Verify LLM calls
|
||||
mock_llm.stream.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "gpu_enabled,is_local_model",
    [
        (True, False),
        (False, True),
        (True, True),
        (False, False),
    ],
)
def test_no_slow_reranking(
    gpu_enabled: bool,
    is_local_model: bool,
    mock_llm: LLM,
    answer_style_config: AnswerStyleConfig,
    prompt_config: PromptConfig,
    mocker: MockerFixture,
) -> None:
    """Check that agent reranking is only allowed when it will not be slow.

    The GPU status lookup used by ``onyx.chat.answer`` is patched to return
    ``gpu_enabled``. A "local model" case is represented by passing no
    rerank settings; the non-local case passes ``RerankingDetails`` with a
    provider type set (Cohere).

    Expected outcome: ``allow_agent_reranking`` is True when a GPU is
    reported or when a reranking provider is configured, and False only for
    a local model without a GPU.
    """
    mocker.patch(
        "onyx.chat.answer.gpu_status_request",
        return_value=gpu_enabled,
    )
    rerank_settings = (
        None
        if is_local_model
        else RerankingDetails(
            rerank_model_name="test_model",
            rerank_api_url="test_url",
            rerank_api_key="test_key",
            num_rerank=10,
            rerank_provider_type=RerankerProvider.COHERE,
        )
    )
    answer_instance = _answer_fixture_impl(
        mock_llm, answer_style_config, prompt_config, rerank_settings=rerank_settings
    )

    # The rerank settings must be passed through to the graph inputs untouched.
    assert (
        answer_instance.graph_config.inputs.search_request.rerank_settings
        == rerank_settings
    )

    # Compute the expected flag first: `==` binds tighter than `or`, so writing
    # `flag == gpu_enabled or not is_local_model` would pass vacuously for
    # every non-local case regardless of the flag's actual value.
    expected_allow_agent_reranking = gpu_enabled or not is_local_model
    assert (
        answer_instance.graph_config.behavior.allow_agent_reranking
        == expected_allow_agent_reranking
    )
|
||||
|
@ -36,7 +36,12 @@ def test_skip_gen_ai_answer_generation_flag(
|
||||
mock_search_tool: SearchTool,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
prompt_config: PromptConfig,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
mocker.patch(
|
||||
"onyx.chat.answer.gpu_status_request",
|
||||
return_value=True,
|
||||
)
|
||||
question = config["question"]
|
||||
skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"]
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user