Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-05-23 18:20:11 +02:00)

updated rerank function arguments (#3988)

parent: 29c84d7707
commit: 667b9e04c5
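The change spans the agentic rerank_documents node, the SearchPipeline constructor, and the search postprocessing module. rerank_sections and semantic_reranking now take the query string (query_str) and a RerankingDetails object directly instead of a full SearchQuery, a new should_rerank helper centralizes the check for whether reranking is configured, and the agentic node reads rerank settings from the incoming request, falling back to the current search settings in the database only when none are provided.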
@@ -21,10 +21,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
 from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
 from onyx.configs.agent_configs import AGENT_RERANKING_STATS
 from onyx.context.search.models import InferenceSection
-from onyx.context.search.models import SearchRequest
-from onyx.context.search.pipeline import retrieval_preprocessing
+from onyx.context.search.models import RerankingDetails
 from onyx.context.search.postprocessing.postprocessing import rerank_sections
+from onyx.context.search.postprocessing.postprocessing import should_rerank
 from onyx.db.engine import get_session_context_manager
+from onyx.db.search_settings import get_current_search_settings


 def rerank_documents(
@@ -39,6 +40,8 @@ def rerank_documents(

     # Rerank post retrieval and verification. First, create a search query
     # then create the list of reranked sections
+    # If no question defined/question is None in the state, use the original
+    # question from the search request as query

     graph_config = cast(GraphConfig, config["metadata"]["config"])
     question = (
@@ -47,39 +50,28 @@ def rerank_documents(
     assert (
         graph_config.tooling.search_tool
     ), "search_tool must be provided for agentic search"
-    with get_session_context_manager() as db_session:
-        # we ignore some of the user specified fields since this search is
-        # internal to agentic search, but we still want to pass through
-        # persona (for stuff like document sets) and rerank settings
-        # (to not make an unnecessary db call).
-        search_request = SearchRequest(
-            query=question,
-            persona=graph_config.inputs.search_request.persona,
-            rerank_settings=graph_config.inputs.search_request.rerank_settings,
-        )
-        _search_query = retrieval_preprocessing(
-            search_request=search_request,
-            user=graph_config.tooling.search_tool.user,  # bit of a hack
-            llm=graph_config.tooling.fast_llm,
-            db_session=db_session,
-        )

-    # skip section filtering
+    # Note that these are passed in values from the API and are overrides which are typically None
+    rerank_settings = graph_config.inputs.search_request.rerank_settings

-    if (
-        _search_query.rerank_settings
-        and _search_query.rerank_settings.rerank_model_name
-        and _search_query.rerank_settings.num_rerank > 0
-        and len(verified_documents) > 0
-    ):
+    if rerank_settings is None:
+        with get_session_context_manager() as db_session:
+            search_settings = get_current_search_settings(db_session)
+            if not search_settings.disable_rerank_for_streaming:
+                rerank_settings = RerankingDetails.from_db_model(search_settings)
+
+    if should_rerank(rerank_settings) and len(verified_documents) > 0:
         if len(verified_documents) > 1:
             reranked_documents = rerank_sections(
-                _search_query,
-                verified_documents,
+                query_str=question,
+                # if runnable, then rerank_settings is not None
+                rerank_settings=cast(RerankingDetails, rerank_settings),
+                sections_to_rerank=verified_documents,
             )
         else:
-            num = "No" if len(verified_documents) == 0 else "One"
-            logger.warning(f"{num} verified document(s) found, skipping reranking")
+            logger.warning(
+                f"{len(verified_documents)} verified document(s) found, skipping reranking"
+            )
             reranked_documents = verified_documents
     else:
         logger.warning("No reranking settings found, using unranked documents")
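The node now resolves its rerank settings in two steps: use the request's override when present, otherwise load the current search settings from the database (skipped when disable_rerank_for_streaming is set), and only then gate on should_rerank plus the verified-document count. Below is a minimal, self-contained sketch of that fallback pattern; RerankingDetailsStub and resolve_rerank_settings are illustrative stand-ins, not onyx code.

from dataclasses import dataclass


@dataclass
class RerankingDetailsStub:
    """Illustrative stand-in for onyx.context.search.models.RerankingDetails."""

    rerank_model_name: str | None
    num_rerank: int


def resolve_rerank_settings(
    request_override: RerankingDetailsStub | None,
    current_db_settings: RerankingDetailsStub | None,
    disable_rerank_for_streaming: bool,
) -> RerankingDetailsStub | None:
    # Prefer the settings passed in with the request (typically None);
    # otherwise fall back to the current search settings, unless reranking
    # is disabled for streaming flows.
    if request_override is not None:
        return request_override
    if disable_rerank_for_streaming:
        return None
    return current_db_settings


# With no override, the defaults from the database are used.
resolved = resolve_rerank_settings(
    request_override=None,
    current_db_settings=RerankingDetailsStub("example-cross-encoder", 20),
    disable_rerank_for_streaming=False,
)
print(resolved)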
@@ -61,6 +61,8 @@ class SearchPipeline:
         rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
         prompt_config: PromptConfig | None = None,
     ):
+        # NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
+        # and typically are None. The preprocessing will fetch default values to replace these empty overrides.
         self.search_request = search_request
         self.user = user
         self.llm = llm
@@ -15,6 +15,7 @@ from onyx.context.search.models import InferenceChunk
 from onyx.context.search.models import InferenceChunkUncleaned
 from onyx.context.search.models import InferenceSection
 from onyx.context.search.models import MAX_METRICS_CONTENT
+from onyx.context.search.models import RerankingDetails
 from onyx.context.search.models import RerankMetricsContainer
 from onyx.context.search.models import SearchQuery
 from onyx.document_index.document_index_utils import (
@@ -77,7 +78,8 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk

 @log_function_time(print_only=True)
 def semantic_reranking(
-    query: SearchQuery,
+    query_str: str,
+    rerank_settings: RerankingDetails,
     chunks: list[InferenceChunk],
     model_min: int = CROSS_ENCODER_RANGE_MIN,
     model_max: int = CROSS_ENCODER_RANGE_MAX,
@@ -88,11 +90,9 @@ def semantic_reranking(

     Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
     """
-    rerank_settings = query.rerank_settings
-
-    if not rerank_settings or not rerank_settings.rerank_model_name:
-        # Should never reach this part of the flow without reranking settings
-        raise RuntimeError("Reranking flow should not be running")
+    assert (
+        rerank_settings.rerank_model_name
+    ), "Reranking flow cannot run without a specific model"

     chunks_to_rerank = chunks[: rerank_settings.num_rerank]

@@ -107,7 +107,7 @@ def semantic_reranking(
         f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
         for chunk in chunks_to_rerank
     ]
-    sim_scores_floats = cross_encoder.predict(query=query.query, passages=passages)
+    sim_scores_floats = cross_encoder.predict(query=query_str, passages=passages)

     # Old logic to handle multiple cross-encoders preserved but not used
     sim_scores = [numpy.array(sim_scores_floats)]
@@ -165,8 +165,20 @@ def semantic_reranking(
     return list(ranked_chunks), list(ranked_indices)


+def should_rerank(rerank_settings: RerankingDetails | None) -> bool:
+    """Based on the RerankingDetails model, only run rerank if the following conditions are met:
+    - rerank_model_name is not None
+    - num_rerank is greater than 0
+    """
+    if not rerank_settings:
+        return False
+
+    return bool(rerank_settings.rerank_model_name and rerank_settings.num_rerank > 0)
+
+
 def rerank_sections(
-    query: SearchQuery,
+    query_str: str,
+    rerank_settings: RerankingDetails,
     sections_to_rerank: list[InferenceSection],
     rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
 ) -> list[InferenceSection]:
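The new helper can be exercised on its own. The sketch below copies its logic and drives it with SimpleNamespace objects standing in for RerankingDetails instances (hypothetical values, not real model names):

from types import SimpleNamespace


def should_rerank(rerank_settings) -> bool:
    """Copy of the new helper's logic, for illustration only."""
    if not rerank_settings:
        return False
    return bool(rerank_settings.rerank_model_name and rerank_settings.num_rerank > 0)


# Hypothetical settings objects standing in for RerankingDetails instances.
assert should_rerank(SimpleNamespace(rerank_model_name="stub-cross-encoder", num_rerank=20))
assert not should_rerank(None)  # no settings at all
assert not should_rerank(SimpleNamespace(rerank_model_name=None, num_rerank=20))  # no model configured
assert not should_rerank(SimpleNamespace(rerank_model_name="stub-cross-encoder", num_rerank=0))  # nothing to rerank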
@@ -181,16 +193,13 @@ def rerank_sections(
     """
     chunks_to_rerank = [section.center_chunk for section in sections_to_rerank]

-    if not query.rerank_settings:
-        # Should never reach this part of the flow without reranking settings
-        raise RuntimeError("Reranking settings not found")
-
     ranked_chunks, _ = semantic_reranking(
-        query=query,
+        query_str=query_str,
+        rerank_settings=rerank_settings,
         chunks=chunks_to_rerank,
         rerank_metrics_callback=rerank_metrics_callback,
     )
-    lower_chunks = chunks_to_rerank[query.rerank_settings.num_rerank :]
+    lower_chunks = chunks_to_rerank[rerank_settings.num_rerank :]

     # Scores from rerank cannot be meaningfully combined with scores without rerank
     # However the ordering is still important
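With the settings passed explicitly, the slice boundary is simply rerank_settings.num_rerank: the first num_rerank chunks go through the cross-encoder and the remainder keep their retrieval order. A standalone illustration of that split, using plain strings in place of InferenceChunk objects:

def split_for_rerank(chunks: list[str], num_rerank: int) -> tuple[list[str], list[str]]:
    # Head goes to the cross-encoder, tail keeps its retrieval ordering.
    return chunks[:num_rerank], chunks[num_rerank:]


retrieved = ["chunk_a", "chunk_b", "chunk_c", "chunk_d", "chunk_e"]
to_rerank, lower_chunks = split_for_rerank(retrieved, num_rerank=3)
print(to_rerank)     # ['chunk_a', 'chunk_b', 'chunk_c']
print(lower_chunks)  # ['chunk_d', 'chunk_e']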
@@ -260,16 +269,13 @@ def search_postprocessing(

     rerank_task_id = None
     sections_yielded = False
-    if (
-        search_query.rerank_settings
-        and search_query.rerank_settings.rerank_model_name
-        and search_query.rerank_settings.num_rerank > 0
-    ):
+    if should_rerank(search_query.rerank_settings):
         post_processing_tasks.append(
             FunctionCall(
                 rerank_sections,
                 (
-                    search_query,
+                    search_query.query,
+                    search_query.rerank_settings,  # Cannot be None here
                     retrieved_sections,
                     rerank_metrics_callback,
                 ),
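Since rerank_sections now expects the query string and the settings as its first two parameters, the positional tuple handed to FunctionCall gains an element and must match that order. A self-contained sketch of the same deferred-call pattern, with FunctionCallStub and rerank_sections_stub as stand-ins for the onyx FunctionCall utility and the real function:

from typing import Any, Callable


class FunctionCallStub:
    """Stand-in for onyx's FunctionCall: hold a callable and its args for deferred execution."""

    def __init__(self, func: Callable[..., Any], args: tuple[Any, ...]) -> None:
        self.func = func
        self.args = args

    def execute(self) -> Any:
        return self.func(*self.args)


def rerank_sections_stub(query_str, rerank_settings, sections_to_rerank, rerank_metrics_callback=None):
    # Placeholder mirroring the new parameter order of rerank_sections.
    return sections_to_rerank[: rerank_settings["num_rerank"]]


task = FunctionCallStub(
    rerank_sections_stub,
    ("example query", {"rerank_model_name": "stub-model", "num_rerank": 2}, ["s1", "s2", "s3"], None),
)
print(task.execute())  # ['s1', 's2']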