updated rerank function arguments (#3988)

This commit is contained in:
joachim-danswer 2025-02-13 14:13:14 -08:00 committed by GitHub
parent 29c84d7707
commit 667b9e04c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 49 additions and 49 deletions

View File

@ -21,10 +21,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RERANKING_STATS from onyx.configs.agent_configs import AGENT_RERANKING_STATS
from onyx.context.search.models import InferenceSection from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SearchRequest from onyx.context.search.models import RerankingDetails
from onyx.context.search.pipeline import retrieval_preprocessing
from onyx.context.search.postprocessing.postprocessing import rerank_sections from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.context.search.postprocessing.postprocessing import should_rerank
from onyx.db.engine import get_session_context_manager from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings
def rerank_documents( def rerank_documents(
@ -39,6 +40,8 @@ def rerank_documents(
# Rerank post retrieval and verification. First, create a search query # Rerank post retrieval and verification. First, create a search query
# then create the list of reranked sections # then create the list of reranked sections
# If no question defined/question is None in the state, use the original
# question from the search request as query
graph_config = cast(GraphConfig, config["metadata"]["config"]) graph_config = cast(GraphConfig, config["metadata"]["config"])
question = ( question = (
@ -47,39 +50,28 @@ def rerank_documents(
assert ( assert (
graph_config.tooling.search_tool graph_config.tooling.search_tool
), "search_tool must be provided for agentic search" ), "search_tool must be provided for agentic search"
with get_session_context_manager() as db_session:
# we ignore some of the user specified fields since this search is
# internal to agentic search, but we still want to pass through
# persona (for stuff like document sets) and rerank settings
# (to not make an unnecessary db call).
search_request = SearchRequest(
query=question,
persona=graph_config.inputs.search_request.persona,
rerank_settings=graph_config.inputs.search_request.rerank_settings,
)
_search_query = retrieval_preprocessing(
search_request=search_request,
user=graph_config.tooling.search_tool.user, # bit of a hack
llm=graph_config.tooling.fast_llm,
db_session=db_session,
)
# skip section filtering # Note that these are passed in values from the API and are overrides which are typically None
rerank_settings = graph_config.inputs.search_request.rerank_settings
if ( if rerank_settings is None:
_search_query.rerank_settings with get_session_context_manager() as db_session:
and _search_query.rerank_settings.rerank_model_name search_settings = get_current_search_settings(db_session)
and _search_query.rerank_settings.num_rerank > 0 if not search_settings.disable_rerank_for_streaming:
and len(verified_documents) > 0 rerank_settings = RerankingDetails.from_db_model(search_settings)
):
if should_rerank(rerank_settings) and len(verified_documents) > 0:
if len(verified_documents) > 1: if len(verified_documents) > 1:
reranked_documents = rerank_sections( reranked_documents = rerank_sections(
_search_query, query_str=question,
verified_documents, # if runnable, then rerank_settings is not None
rerank_settings=cast(RerankingDetails, rerank_settings),
sections_to_rerank=verified_documents,
) )
else: else:
num = "No" if len(verified_documents) == 0 else "One" logger.warning(
logger.warning(f"{num} verified document(s) found, skipping reranking") f"{len(verified_documents)} verified document(s) found, skipping reranking"
)
reranked_documents = verified_documents reranked_documents = verified_documents
else: else:
logger.warning("No reranking settings found, using unranked documents") logger.warning("No reranking settings found, using unranked documents")

View File

@ -61,6 +61,8 @@ class SearchPipeline:
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
prompt_config: PromptConfig | None = None, prompt_config: PromptConfig | None = None,
): ):
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
self.search_request = search_request self.search_request = search_request
self.user = user self.user = user
self.llm = llm self.llm = llm

View File

@ -15,6 +15,7 @@ from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceChunkUncleaned from onyx.context.search.models import InferenceChunkUncleaned
from onyx.context.search.models import InferenceSection from onyx.context.search.models import InferenceSection
from onyx.context.search.models import MAX_METRICS_CONTENT from onyx.context.search.models import MAX_METRICS_CONTENT
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RerankMetricsContainer from onyx.context.search.models import RerankMetricsContainer
from onyx.context.search.models import SearchQuery from onyx.context.search.models import SearchQuery
from onyx.document_index.document_index_utils import ( from onyx.document_index.document_index_utils import (
@ -77,7 +78,8 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
@log_function_time(print_only=True) @log_function_time(print_only=True)
def semantic_reranking( def semantic_reranking(
query: SearchQuery, query_str: str,
rerank_settings: RerankingDetails,
chunks: list[InferenceChunk], chunks: list[InferenceChunk],
model_min: int = CROSS_ENCODER_RANGE_MIN, model_min: int = CROSS_ENCODER_RANGE_MIN,
model_max: int = CROSS_ENCODER_RANGE_MAX, model_max: int = CROSS_ENCODER_RANGE_MAX,
@ -88,11 +90,9 @@ def semantic_reranking(
Note: this updates the chunks in place, it updates the chunk scores which came from retrieval Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
""" """
rerank_settings = query.rerank_settings assert (
rerank_settings.rerank_model_name
if not rerank_settings or not rerank_settings.rerank_model_name: ), "Reranking flow cannot run without a specific model"
# Should never reach this part of the flow without reranking settings
raise RuntimeError("Reranking flow should not be running")
chunks_to_rerank = chunks[: rerank_settings.num_rerank] chunks_to_rerank = chunks[: rerank_settings.num_rerank]
@ -107,7 +107,7 @@ def semantic_reranking(
f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}" f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
for chunk in chunks_to_rerank for chunk in chunks_to_rerank
] ]
sim_scores_floats = cross_encoder.predict(query=query.query, passages=passages) sim_scores_floats = cross_encoder.predict(query=query_str, passages=passages)
# Old logic to handle multiple cross-encoders preserved but not used # Old logic to handle multiple cross-encoders preserved but not used
sim_scores = [numpy.array(sim_scores_floats)] sim_scores = [numpy.array(sim_scores_floats)]
@ -165,8 +165,20 @@ def semantic_reranking(
return list(ranked_chunks), list(ranked_indices) return list(ranked_chunks), list(ranked_indices)
def should_rerank(rerank_settings: RerankingDetails | None) -> bool:
"""Based on the RerankingDetails model, only run rerank if the following conditions are met:
- rerank_model_name is not None
- num_rerank is greater than 0
"""
if not rerank_settings:
return False
return bool(rerank_settings.rerank_model_name and rerank_settings.num_rerank > 0)
def rerank_sections( def rerank_sections(
query: SearchQuery, query_str: str,
rerank_settings: RerankingDetails,
sections_to_rerank: list[InferenceSection], sections_to_rerank: list[InferenceSection],
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> list[InferenceSection]: ) -> list[InferenceSection]:
@ -181,16 +193,13 @@ def rerank_sections(
""" """
chunks_to_rerank = [section.center_chunk for section in sections_to_rerank] chunks_to_rerank = [section.center_chunk for section in sections_to_rerank]
if not query.rerank_settings:
# Should never reach this part of the flow without reranking settings
raise RuntimeError("Reranking settings not found")
ranked_chunks, _ = semantic_reranking( ranked_chunks, _ = semantic_reranking(
query=query, query_str=query_str,
rerank_settings=rerank_settings,
chunks=chunks_to_rerank, chunks=chunks_to_rerank,
rerank_metrics_callback=rerank_metrics_callback, rerank_metrics_callback=rerank_metrics_callback,
) )
lower_chunks = chunks_to_rerank[query.rerank_settings.num_rerank :] lower_chunks = chunks_to_rerank[rerank_settings.num_rerank :]
# Scores from rerank cannot be meaningfully combined with scores without rerank # Scores from rerank cannot be meaningfully combined with scores without rerank
# However the ordering is still important # However the ordering is still important
@ -260,16 +269,13 @@ def search_postprocessing(
rerank_task_id = None rerank_task_id = None
sections_yielded = False sections_yielded = False
if ( if should_rerank(search_query.rerank_settings):
search_query.rerank_settings
and search_query.rerank_settings.rerank_model_name
and search_query.rerank_settings.num_rerank > 0
):
post_processing_tasks.append( post_processing_tasks.append(
FunctionCall( FunctionCall(
rerank_sections, rerank_sections,
( (
search_query, search_query.query,
search_query.rerank_settings, # Cannot be None here
retrieved_sections, retrieved_sections,
rerank_metrics_callback, rerank_metrics_callback,
), ),