diff --git a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/rerank_documents.py b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/rerank_documents.py index 63bdbd015..f598a65fd 100644 --- a/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/rerank_documents.py +++ b/backend/onyx/agents/agent_search/deep_search/shared/expanded_retrieval/nodes/rerank_documents.py @@ -55,6 +55,7 @@ def rerank_documents( # Note that these are passed in values from the API and are overrides which are typically None rerank_settings = graph_config.inputs.search_request.rerank_settings + allow_agent_reranking = graph_config.behavior.allow_agent_reranking if rerank_settings is None: with get_session_context_manager() as db_session: @@ -62,23 +63,31 @@ def rerank_documents( if not search_settings.disable_rerank_for_streaming: rerank_settings = RerankingDetails.from_db_model(search_settings) + # Initial default: no reranking. Will be overwritten below if reranking is warranted + reranked_documents = verified_documents + if should_rerank(rerank_settings) and len(verified_documents) > 0: if len(verified_documents) > 1: - reranked_documents = rerank_sections( - query_str=question, - # if runnable, then rerank_settings is not None - rerank_settings=cast(RerankingDetails, rerank_settings), - sections_to_rerank=verified_documents, - ) + if not allow_agent_reranking: + logger.info("Use of local rerank model without GPU, skipping reranking") + # No reranking, stay with verified_documents as default + + else: + # Reranking is warranted, use the rerank_sections function + reranked_documents = rerank_sections( + query_str=question, + # if runnable, then rerank_settings is not None + rerank_settings=cast(RerankingDetails, rerank_settings), + sections_to_rerank=verified_documents, + ) else: logger.warning( f"{len(verified_documents)} verified document(s) found, skipping reranking" ) - reranked_documents = 
verified_documents + # No reranking, stay with verified_documents as default else: logger.warning("No reranking settings found, using unranked documents") - reranked_documents = verified_documents - + # No reranking, stay with verified_documents as default if AGENT_RERANKING_STATS: fit_scores = get_fit_scores(verified_documents, reranked_documents) else: diff --git a/backend/onyx/agents/agent_search/models.py b/backend/onyx/agents/agent_search/models.py index 3bfaac66c..1904ae7ea 100644 --- a/backend/onyx/agents/agent_search/models.py +++ b/backend/onyx/agents/agent_search/models.py @@ -67,6 +67,7 @@ class GraphSearchConfig(BaseModel): # Whether to allow creation of refinement questions (and entity extraction, etc.) allow_refinement: bool = True skip_gen_ai_answer_generation: bool = False + allow_agent_reranking: bool = False class GraphConfig(BaseModel): diff --git a/backend/onyx/chat/answer.py b/backend/onyx/chat/answer.py index 8addfe5ba..118c7aaf5 100644 --- a/backend/onyx/chat/answer.py +++ b/backend/onyx/chat/answer.py @@ -29,6 +29,7 @@ from onyx.tools.force import ForceUseTool from onyx.tools.tool import Tool from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.utils import explicit_tool_calling_supported +from onyx.utils.gpu_utils import gpu_status_request from onyx.utils.logger import setup_logger logger = setup_logger() @@ -80,6 +81,14 @@ class Answer: and not skip_explicit_tool_calling ) + rerank_settings = search_request.rerank_settings + + using_cloud_reranking = ( + rerank_settings is not None + and rerank_settings.rerank_provider_type is not None + ) + allow_agent_reranking = gpu_status_request() or using_cloud_reranking + self.graph_inputs = GraphInputs( search_request=search_request, prompt_builder=prompt_builder, @@ -94,7 +103,6 @@ class Answer: force_use_tool=force_use_tool, using_tool_calling_llm=using_tool_calling_llm, ) - assert db_session, "db_session must be provided for agentic persistence" 
self.graph_persistence = GraphPersistence( db_session=db_session, chat_session_id=chat_session_id, @@ -104,6 +112,7 @@ class Answer: use_agentic_search=use_agentic_search, skip_gen_ai_answer_generation=skip_gen_ai_answer_generation, allow_refinement=True, + allow_agent_reranking=allow_agent_reranking, ) self.graph_config = GraphConfig( inputs=self.graph_inputs, diff --git a/backend/tests/unit/onyx/chat/test_answer.py b/backend/tests/unit/onyx/chat/test_answer.py index 763a3b6f6..175a3d58a 100644 --- a/backend/tests/unit/onyx/chat/test_answer.py +++ b/backend/tests/unit/onyx/chat/test_answer.py @@ -11,6 +11,7 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import SystemMessage from langchain_core.messages import ToolCall from langchain_core.messages import ToolCallChunk +from pytest_mock import MockerFixture from sqlalchemy.orm import Session from onyx.chat.answer import Answer @@ -25,6 +26,7 @@ from onyx.chat.models import StreamStopReason from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message +from onyx.context.search.models import RerankingDetails from onyx.context.search.models import SearchRequest from onyx.llm.interfaces import LLM from onyx.tools.force import ForceUseTool @@ -35,6 +37,7 @@ from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTEN from onyx.tools.tool_implementations.search_like_tool_utils import ( FINAL_CONTEXT_DOCUMENTS_ID, ) +from shared_configs.enums import RerankerProvider from tests.unit.onyx.chat.conftest import DEFAULT_SEARCH_ARGS from tests.unit.onyx.chat.conftest import QUERY @@ -44,6 +47,20 @@ def answer_instance( mock_llm: LLM, answer_style_config: AnswerStyleConfig, prompt_config: PromptConfig, + mocker: MockerFixture, +) -> Answer: + mocker.patch( + 
"onyx.chat.answer.gpu_status_request", + return_value=True, + ) + return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config) + + +def _answer_fixture_impl( + mock_llm: LLM, + answer_style_config: AnswerStyleConfig, + prompt_config: PromptConfig, + rerank_settings: RerankingDetails | None = None, ) -> Answer: return Answer( prompt_builder=AnswerPromptBuilder( @@ -64,13 +81,13 @@ def answer_instance( llm=mock_llm, fast_llm=mock_llm, force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None), - search_request=SearchRequest(query=QUERY), + search_request=SearchRequest(query=QUERY, rerank_settings=rerank_settings), chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"), current_agent_message_id=0, ) -def test_basic_answer(answer_instance: Answer) -> None: +def test_basic_answer(answer_instance: Answer, mocker: MockerFixture) -> None: mock_llm = cast(Mock, answer_instance.graph_config.tooling.primary_llm) mock_llm.stream.return_value = [ AIMessageChunk(content="This is a "), @@ -363,3 +380,49 @@ def test_is_cancelled(answer_instance: Answer) -> None: # Verify LLM calls mock_llm.stream.assert_called_once() + + +@pytest.mark.parametrize( + "gpu_enabled,is_local_model", + [ + (True, False), + (False, True), + (True, True), + (False, False), + ], +) +def test_no_slow_reranking( + gpu_enabled: bool, + is_local_model: bool, + mock_llm: LLM, + answer_style_config: AnswerStyleConfig, + prompt_config: PromptConfig, + mocker: MockerFixture, +) -> None: + mocker.patch( + "onyx.chat.answer.gpu_status_request", + return_value=gpu_enabled, + ) + rerank_settings = ( + None + if is_local_model + else RerankingDetails( + rerank_model_name="test_model", + rerank_api_url="test_url", + rerank_api_key="test_key", + num_rerank=10, + rerank_provider_type=RerankerProvider.COHERE, + ) + ) + answer_instance = _answer_fixture_impl( + mock_llm, answer_style_config, prompt_config, rerank_settings=rerank_settings + ) + + assert ( + 
answer_instance.graph_config.inputs.search_request.rerank_settings + == rerank_settings + ) + assert ( + answer_instance.graph_config.behavior.allow_agent_reranking == gpu_enabled + or not is_local_model + ) diff --git a/backend/tests/unit/onyx/chat/test_skip_gen_ai.py b/backend/tests/unit/onyx/chat/test_skip_gen_ai.py index 00f77279a..146a08f60 100644 --- a/backend/tests/unit/onyx/chat/test_skip_gen_ai.py +++ b/backend/tests/unit/onyx/chat/test_skip_gen_ai.py @@ -36,7 +36,12 @@ def test_skip_gen_ai_answer_generation_flag( mock_search_tool: SearchTool, answer_style_config: AnswerStyleConfig, prompt_config: PromptConfig, + mocker: MockerFixture, ) -> None: + mocker.patch( + "onyx.chat.answer.gpu_status_request", + return_value=True, + ) question = config["question"] skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"]