updated rerank function arguments ()

This commit is contained in:
joachim-danswer 2025-02-13 14:13:14 -08:00 committed by GitHub
parent 29c84d7707
commit 667b9e04c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 49 additions and 49 deletions
backend/onyx
agents/agent_search/deep_search/shared/expanded_retrieval/nodes
context/search

@ -21,10 +21,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RERANKING_STATS
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import retrieval_preprocessing
from onyx.context.search.models import RerankingDetails
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.context.search.postprocessing.postprocessing import should_rerank
from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings
def rerank_documents(
@ -39,6 +40,8 @@ def rerank_documents(
# Rerank post retrieval and verification. First, create a search query
# then create the list of reranked sections
# If no question is defined in the state (i.e. it is None), fall back to the
# original question from the search request and use that as the query
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = (
@ -47,39 +50,28 @@ def rerank_documents(
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
with get_session_context_manager() as db_session:
# we ignore some of the user specified fields since this search is
# internal to agentic search, but we still want to pass through
# persona (for stuff like document sets) and rerank settings
# (to not make an unnecessary db call).
search_request = SearchRequest(
query=question,
persona=graph_config.inputs.search_request.persona,
rerank_settings=graph_config.inputs.search_request.rerank_settings,
)
_search_query = retrieval_preprocessing(
search_request=search_request,
user=graph_config.tooling.search_tool.user, # bit of a hack
llm=graph_config.tooling.fast_llm,
db_session=db_session,
)
# skip section filtering
# Note: these values are passed in from the API; they are overrides and are typically None
rerank_settings = graph_config.inputs.search_request.rerank_settings
if (
_search_query.rerank_settings
and _search_query.rerank_settings.rerank_model_name
and _search_query.rerank_settings.num_rerank > 0
and len(verified_documents) > 0
):
if rerank_settings is None:
with get_session_context_manager() as db_session:
search_settings = get_current_search_settings(db_session)
if not search_settings.disable_rerank_for_streaming:
rerank_settings = RerankingDetails.from_db_model(search_settings)
if should_rerank(rerank_settings) and len(verified_documents) > 0:
if len(verified_documents) > 1:
reranked_documents = rerank_sections(
_search_query,
verified_documents,
query_str=question,
# if runnable, then rerank_settings is not None
rerank_settings=cast(RerankingDetails, rerank_settings),
sections_to_rerank=verified_documents,
)
else:
num = "No" if len(verified_documents) == 0 else "One"
logger.warning(f"{num} verified document(s) found, skipping reranking")
logger.warning(
f"{len(verified_documents)} verified document(s) found, skipping reranking"
)
reranked_documents = verified_documents
else:
logger.warning("No reranking settings found, using unranked documents")

@ -61,6 +61,8 @@ class SearchPipeline:
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
prompt_config: PromptConfig | None = None,
):
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
self.search_request = search_request
self.user = user
self.llm = llm

@ -15,6 +15,7 @@ from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import MAX_METRICS_CONTENT
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RerankMetricsContainer
from onyx.context.search.models import SearchQuery
from onyx.document_index.document_index_utils import (
@ -77,7 +78,8 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
@log_function_time(print_only=True)
def semantic_reranking(
query: SearchQuery,
query_str: str,
rerank_settings: RerankingDetails,
chunks: list[InferenceChunk],
model_min: int = CROSS_ENCODER_RANGE_MIN,
model_max: int = CROSS_ENCODER_RANGE_MAX,
@ -88,11 +90,9 @@ def semantic_reranking(
Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
"""
rerank_settings = query.rerank_settings
if not rerank_settings or not rerank_settings.rerank_model_name:
# Should never reach this part of the flow without reranking settings
raise RuntimeError("Reranking flow should not be running")
assert (
rerank_settings.rerank_model_name
), "Reranking flow cannot run without a specific model"
chunks_to_rerank = chunks[: rerank_settings.num_rerank]
@ -107,7 +107,7 @@ def semantic_reranking(
f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
for chunk in chunks_to_rerank
]
sim_scores_floats = cross_encoder.predict(query=query.query, passages=passages)
sim_scores_floats = cross_encoder.predict(query=query_str, passages=passages)
# Old logic to handle multiple cross-encoders preserved but not used
sim_scores = [numpy.array(sim_scores_floats)]
@ -165,8 +165,20 @@ def semantic_reranking(
return list(ranked_chunks), list(ranked_indices)
def should_rerank(rerank_settings: RerankingDetails | None) -> bool:
    """Decide whether the reranking step should run for the given settings.

    Reranking is enabled only when all of the following hold:
    - rerank settings were supplied (the argument is truthy, i.e. not None)
    - rerank_model_name is set to a non-empty value
    - num_rerank is greater than 0
    """
    if rerank_settings:
        # Without a model name there is nothing to rerank with; num_rerank
        # is deliberately not consulted in that case.
        if not rerank_settings.rerank_model_name:
            return False
        return bool(rerank_settings.num_rerank > 0)
    return False
def rerank_sections(
query: SearchQuery,
query_str: str,
rerank_settings: RerankingDetails,
sections_to_rerank: list[InferenceSection],
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> list[InferenceSection]:
@ -181,16 +193,13 @@ def rerank_sections(
"""
chunks_to_rerank = [section.center_chunk for section in sections_to_rerank]
if not query.rerank_settings:
# Should never reach this part of the flow without reranking settings
raise RuntimeError("Reranking settings not found")
ranked_chunks, _ = semantic_reranking(
query=query,
query_str=query_str,
rerank_settings=rerank_settings,
chunks=chunks_to_rerank,
rerank_metrics_callback=rerank_metrics_callback,
)
lower_chunks = chunks_to_rerank[query.rerank_settings.num_rerank :]
lower_chunks = chunks_to_rerank[rerank_settings.num_rerank :]
# Scores from rerank cannot be meaningfully combined with scores without rerank
# However the ordering is still important
@ -260,16 +269,13 @@ def search_postprocessing(
rerank_task_id = None
sections_yielded = False
if (
search_query.rerank_settings
and search_query.rerank_settings.rerank_model_name
and search_query.rerank_settings.num_rerank > 0
):
if should_rerank(search_query.rerank_settings):
post_processing_tasks.append(
FunctionCall(
rerank_sections,
(
search_query,
search_query.query,
search_query.rerank_settings, # Cannot be None here
retrieved_sections,
rerank_metrics_callback,
),