from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterator
from typing import cast

from sqlalchemy.orm import Session

from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
from onyx.chat.prune_and_merge import _merge_sections
from onyx.chat.prune_and_merge import ChunkRange
from onyx.chat.prune_and_merge import merge_chunk_intervals
from onyx.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankMetricsContainer
from onyx.context.search.models import RetrievalMetricsContainer
from onyx.context.search.models import SearchQuery
from onyx.context.search.models import SearchRequest
from onyx.context.search.postprocessing.postprocessing import cleanup_chunks
from onyx.context.search.postprocessing.postprocessing import search_postprocessing
from onyx.context.search.preprocessing.preprocessing import retrieval_preprocessing
from onyx.context.search.retrieval.search_runner import (
    retrieve_chunks,
)
from onyx.context.search.utils import inference_section_from_chunks
from onyx.context.search.utils import relevant_sections_to_indices
from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.llm.interfaces import LLM
from onyx.secondary_llm_flows.agentic_evaluation import evaluate_inference_section
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
from onyx.utils.timing import log_function_time
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop

logger = setup_logger()


class SearchPipeline:
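    """Orchestrates a single search: query preprocessing, retrieval from the
    document index, expansion of chunks into sections, reranking, and optional
    LLM-based relevance evaluation.

    The heavy steps are exposed as lazily evaluated properties: accessing a
    later stage (e.g. `final_context_sections`) transparently runs the earlier
    stages first and caches their results on the instance.
    """
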
    def __init__(
        self,
        search_request: SearchRequest,
        user: User | None,
        llm: LLM,
        fast_llm: LLM,
        skip_query_analysis: bool,
        db_session: Session,
        bypass_acl: bool = False,  # NOTE: VERY DANGEROUS, USE WITH CAUTION
        retrieval_metrics_callback: (
            Callable[[RetrievalMetricsContainer], None] | None
        ) = None,
        retrieved_sections_callback: Callable[[list[InferenceSection]], None]
        | None = None,
        rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
        prompt_config: PromptConfig | None = None,
    ):
        # NOTE: The SearchRequest contains a lot of fields that are overrides; many of them can be None
        # and typically are None. The preprocessing step fetches default values to fill in these empty overrides.
        self.search_request = search_request
        self.user = user
        self.llm = llm
        self.fast_llm = fast_llm
        self.skip_query_analysis = skip_query_analysis
        self.db_session = db_session
        self.bypass_acl = bypass_acl
        self.retrieval_metrics_callback = retrieval_metrics_callback
        self.rerank_metrics_callback = rerank_metrics_callback

        self.search_settings = get_current_search_settings(db_session)
        self.document_index = get_default_document_index(self.search_settings, None)
        self.prompt_config: PromptConfig | None = prompt_config

        # Preprocessing steps generate this
        self._search_query: SearchQuery | None = None
        self._predicted_search_type: SearchType | None = None

        # Initial document index retrieval chunks
        self._retrieved_chunks: list[InferenceChunk] | None = None
        # Another call made to the document index to get surrounding sections
        self._retrieved_sections: list[InferenceSection] | None = None

        self.retrieved_sections_callback = retrieved_sections_callback
        # Reranking and LLM section selection can be run together
        # If only LLM selection is on, the reranked chunks are yielded immediately
        self._reranked_sections: list[InferenceSection] | None = None
        self._final_context_sections: list[InferenceSection] | None = None

        self._section_relevance: list[SectionRelevancePiece] | None = None

        # Generates reranked chunks and LLM selections
        self._postprocessing_generator: (
            Iterator[list[InferenceSection] | list[SectionRelevancePiece]] | None
        ) = None

        # No longer computed but keeping around in case it's reintroduced later
        self._predicted_flow: QueryFlow | None = QueryFlow.QUESTION_ANSWER

    """Pre-processing"""

    def _run_preprocessing(self) -> None:
        final_search_query = retrieval_preprocessing(
            search_request=self.search_request,
            user=self.user,
            llm=self.llm,
            skip_query_analysis=self.skip_query_analysis,
            db_session=self.db_session,
            bypass_acl=self.bypass_acl,
        )
        self._search_query = final_search_query
        self._predicted_search_type = final_search_query.search_type

    @property
    def search_query(self) -> SearchQuery:
        if self._search_query is not None:
            return self._search_query

        self._run_preprocessing()

        return cast(SearchQuery, self._search_query)

    @property
    def predicted_search_type(self) -> SearchType:
        if self._predicted_search_type is not None:
            return self._predicted_search_type

        self._run_preprocessing()
        return cast(SearchType, self._predicted_search_type)

    @property
    def predicted_flow(self) -> QueryFlow:
        if self._predicted_flow is not None:
            return self._predicted_flow

        self._run_preprocessing()
        return cast(QueryFlow, self._predicted_flow)

"""Retrieval and Postprocessing"""
|
|
|
|
    def _get_chunks(self) -> list[InferenceChunk]:
        if self._retrieved_chunks is not None:
            return self._retrieved_chunks

        # These chunks do not include large chunks and have been deduped
        self._retrieved_chunks = retrieve_chunks(
            query=self.search_query,
            document_index=self.document_index,
            db_session=self.db_session,
            retrieval_metrics_callback=self.retrieval_metrics_callback,
        )

        return cast(list[InferenceChunk], self._retrieved_chunks)

    @log_function_time(print_only=True)
    def _get_sections(self) -> list[InferenceSection]:
        """Returns an expanded section for each of the retrieved chunks.

        If whole docs (instead of above/below context) are requested, this returns each whole doc
        that has a corresponding retrieved chunk.

        This step should be fast for any document index implementation.

        Current implementation timing breaks down approximately as:
        - 200 ms to get the embedding of the query
        - 15 ms to get chunks from the document index
        - possibly more to get additional surrounding chunks
        - possibly more for query expansion (multilingual)
        """
        if self._retrieved_sections is not None:
            return self._retrieved_sections

        # These chunks are ordered, deduped, and contain no large chunks
        retrieved_chunks = self._get_chunks()

        # If ee is enabled, censor the chunk sections based on user access
        # Otherwise, return the retrieved chunks
        censored_chunks = fetch_ee_implementation_or_noop(
            "onyx.external_permissions.post_query_censoring",
            "_post_query_chunk_censoring",
            retrieved_chunks,
        )(
            chunks=retrieved_chunks,
            user=self.user,
        )

        above = self.search_query.chunks_above
        below = self.search_query.chunks_below

        expanded_inference_sections = []
        inference_chunks: list[InferenceChunk] = []
        chunk_requests: list[VespaChunkRequest] = []

        # Full doc setting takes priority
        if self.search_query.full_doc:
            seen_document_ids = set()

            # This preserves the ordering since the chunks are retrieved in score order
            for chunk in censored_chunks:
                if chunk.document_id not in seen_document_ids:
                    seen_document_ids.add(chunk.document_id)
                    chunk_requests.append(
                        VespaChunkRequest(
                            document_id=chunk.document_id,
                        )
                    )

            inference_chunks.extend(
                cleanup_chunks(
                    self.document_index.id_based_retrieval(
                        chunk_requests=chunk_requests,
                        filters=IndexFilters(access_control_list=None),
                    )
                )
            )

            # Create a dictionary to group chunks by document_id
            grouped_inference_chunks: dict[str, list[InferenceChunk]] = {}
            for chunk in inference_chunks:
                if chunk.document_id not in grouped_inference_chunks:
                    grouped_inference_chunks[chunk.document_id] = []
                grouped_inference_chunks[chunk.document_id].append(chunk)

            for chunk_group in grouped_inference_chunks.values():
                inference_section = inference_section_from_chunks(
                    center_chunk=chunk_group[0],
                    chunks=chunk_group,
                )

                if inference_section is not None:
                    expanded_inference_sections.append(inference_section)
                else:
                    logger.warning(
                        "Skipped creation of section for full docs, no chunks found"
                    )

            self._retrieved_sections = expanded_inference_sections
            return expanded_inference_sections

        # General flow:
        # - Combine chunks into lists by document_id
        # - For each document, run merge-intervals to get combined ranges
        #   - This allows for fewer queries to the document index
        # - Fetch all of the new chunks with contents for the combined ranges
        # - Iterate over the chunks again and map them to the results above based on the chunk.
        #   This maintains the original chunk ordering. Note, we cannot simply sort by score here
        #   as the reranking flow may wipe the scores for a lot of the chunks.
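        # Illustrative example with hypothetical values: with above=1 and below=1, retrieved
        # chunks 3 and 5 of the same document yield the ranges [2, 4] and [4, 6], which
        # merge_chunk_intervals combines into a single [2, 6] range, so only one document
        # index request is needed for that document.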
        doc_chunk_ranges_map = defaultdict(list)
        for chunk in censored_chunks:
            # The list of ranges for each document is ordered by score
            doc_chunk_ranges_map[chunk.document_id].append(
                ChunkRange(
                    chunks=[chunk],
                    start=max(0, chunk.chunk_id - above),
                    # No max known ahead of time, filter will handle this anyway
                    end=chunk.chunk_id + below,
                )
            )

        # List of ranges: the outer list represents documents, the inner list represents ranges
        merged_ranges = [
            merge_chunk_intervals(ranges) for ranges in doc_chunk_ranges_map.values()
        ]

        flat_ranges: list[ChunkRange] = [r for ranges in merged_ranges for r in ranges]

        for chunk_range in flat_ranges:
            # No need to fetch chunks within the range for merging if chunks_above / chunks_below are 0
            if above == below == 0:
                inference_chunks.extend(chunk_range.chunks)

            else:
                chunk_requests.append(
                    VespaChunkRequest(
                        document_id=chunk_range.chunks[0].document_id,
                        min_chunk_ind=chunk_range.start,
                        max_chunk_ind=chunk_range.end,
                    )
                )

        if chunk_requests:
            inference_chunks.extend(
                cleanup_chunks(
                    self.document_index.id_based_retrieval(
                        chunk_requests=chunk_requests,
                        filters=IndexFilters(access_control_list=None),
                        batch_retrieval=True,
                    )
                )
            )

        doc_chunk_ind_to_chunk = {
            (chunk.document_id, chunk.chunk_id): chunk for chunk in inference_chunks
        }

        # In case of failed parallel calls to Vespa, at least we should have the initial retrieved chunks
        doc_chunk_ind_to_chunk.update(
            {(chunk.document_id, chunk.chunk_id): chunk for chunk in censored_chunks}
        )

        # Build the surroundings for all of the initially retrieved chunks
        for chunk in censored_chunks:
            start_ind = max(0, chunk.chunk_id - above)
            end_ind = chunk.chunk_id + below

            # Since the index of the max chunk is unknown, just allow it to be None and filter afterwards
            surrounding_chunks_or_none = [
                doc_chunk_ind_to_chunk.get((chunk.document_id, chunk_ind))
                for chunk_ind in range(start_ind, end_ind + 1)  # end_ind is inclusive
            ]
            # The Nones correspond to would-be "chunks" whose indices are greater than the index
            # of the document's last chunk
            surrounding_chunks = [
                chunk for chunk in surrounding_chunks_or_none if chunk is not None
            ]

            inference_section = inference_section_from_chunks(
                center_chunk=chunk,
                chunks=surrounding_chunks,
            )
            if inference_section is not None:
                expanded_inference_sections.append(inference_section)
            else:
                logger.warning("Skipped creation of section, no chunks found")

        self._retrieved_sections = expanded_inference_sections
        return expanded_inference_sections

    @property
    def retrieved_sections(self) -> list[InferenceSection]:
        if self._retrieved_sections is not None:
            return self._retrieved_sections

        self._retrieved_sections = self._get_sections()
        return self._retrieved_sections

    @property
    def reranked_sections(self) -> list[InferenceSection]:
        """Reranking is always done at the chunk level since section merging could create arbitrarily
        long sections which could be:
        1. Longer than the maximum context limit of even large rerankers
        2. Slow to calculate due to the quadratic scaling laws of Transformers

        See implementation in search_postprocessing for details
        """
        if self._reranked_sections is not None:
            return self._reranked_sections

        retrieved_sections = self.retrieved_sections
        if self.retrieved_sections_callback is not None:
            self.retrieved_sections_callback(retrieved_sections)

        self._postprocessing_generator = search_postprocessing(
            search_query=self.search_query,
            retrieved_sections=retrieved_sections,
            llm=self.fast_llm,
            rerank_metrics_callback=self.rerank_metrics_callback,
        )

        self._reranked_sections = cast(
            list[InferenceSection], next(self._postprocessing_generator)
        )

        return self._reranked_sections

    @property
    def final_context_sections(self) -> list[InferenceSection]:
        if self._final_context_sections is not None:
            return self._final_context_sections

        self._final_context_sections = _merge_sections(sections=self.reranked_sections)
        return self._final_context_sections

    @property
    def section_relevance(self) -> list[SectionRelevancePiece] | None:
        if self._section_relevance is not None:
            return self._section_relevance

        if (
            self.search_query.evaluation_type == LLMEvaluationType.SKIP
            or DISABLE_LLM_DOC_RELEVANCE
        ):
            return None

        if self.search_query.evaluation_type == LLMEvaluationType.UNSPECIFIED:
            raise ValueError(
                "Attempted to access section relevance scores on search query with evaluation type `UNSPECIFIED`."
                + " The search query evaluation type should have been specified."
            )

        if self.search_query.evaluation_type == LLMEvaluationType.AGENTIC:
            sections = self.final_context_sections
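            # Each section is evaluated for relevance against the query by the primary LLM;
            # the calls are independent of one another, so they are dispatched concurrently
            # via run_functions_in_parallel rather than one at a time.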
            functions = [
                FunctionCall(
                    evaluate_inference_section,
                    (section, self.search_query.query, self.llm),
                )
                for section in sections
            ]
            try:
                results = run_functions_in_parallel(function_calls=functions)
                self._section_relevance = list(results.values())
            except Exception as e:
                raise ValueError(
                    "An issue occurred during the agentic evaluation process."
                ) from e

        elif self.search_query.evaluation_type == LLMEvaluationType.BASIC:
            if DISABLE_LLM_DOC_RELEVANCE:
                raise ValueError(
                    "Basic search evaluation operation called while DISABLE_LLM_DOC_RELEVANCE is enabled."
                )
            self._section_relevance = next(
                cast(
                    Iterator[list[SectionRelevancePiece]],
                    self._postprocessing_generator,
                )
            )

        else:
            # All other cases should have been handled above
            raise ValueError(
                f"Unexpected evaluation type: {self.search_query.evaluation_type}"
            )

        return self._section_relevance

    @property
    def section_relevance_list(self) -> list[bool]:
        return section_relevance_list_impl(
            section_relevance=self.section_relevance,
            final_context_sections=self.final_context_sections,
        )


def section_relevance_list_impl(
    section_relevance: list[SectionRelevancePiece] | None,
    final_context_sections: list[InferenceSection],
) -> list[bool]:
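    """Converts the per-section relevance results into a boolean list aligned with
    final_context_sections: entry i is True when section i is among the indices returned
    by relevant_sections_to_indices."""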
    llm_indices = relevant_sections_to_indices(
        relevance_sections=section_relevance,
        items=final_context_sections,
    )
    return [ind in llm_indices for ind in range(len(final_context_sections))]
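

# Example usage (illustrative sketch only, not part of the pipeline; assumes an active
# SQLAlchemy session plus configured primary/fast LLM instances, and uses a hypothetical
# query string):
#
#     pipeline = SearchPipeline(
#         search_request=SearchRequest(query="how do I reset my password?"),
#         user=None,
#         llm=llm,
#         fast_llm=fast_llm,
#         skip_query_analysis=False,
#         db_session=db_session,
#     )
#     sections = pipeline.final_context_sections
#     relevance = pipeline.section_relevance_list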