Refactor search pipeline

Weves 2024-03-23 20:12:23 -07:00 committed by Chris Weaver
parent 7a861ecec4
commit 1ba74ee4df
24 changed files with 827 additions and 869 deletions

View File

@ -9,7 +9,7 @@ from alembic import op
import sqlalchemy as sa
from danswer.db.models import IndexModelStatus
from danswer.search.models import RecencyBiasSetting
from danswer.search.enums import RecencyBiasSetting
from danswer.search.models import SearchType
# revision identifiers, used by Alembic.

View File

@ -13,7 +13,7 @@ from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Prompt as PromptDBModel
from danswer.search.models import RecencyBiasSetting
from danswer.search.enums import RecencyBiasSetting
def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:

View File

@ -5,10 +5,10 @@ from typing import Any
from pydantic import BaseModel
from danswer.configs.constants import DocumentSource
from danswer.search.models import QueryFlow
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.search.models import SearchType
class LlmDoc(BaseModel):

View File

@ -53,11 +53,10 @@ from danswer.llm.utils import tokenizer_trim_content
from danswer.llm.utils import translate_history_to_basemessages
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.search.request_preprocessing import retrieval_preprocessing
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import full_chunk_search_generator
from danswer.search.search_runner import inference_documents_from_ids
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
from danswer.search.retrieval.search_runner import inference_documents_from_ids
from danswer.search.utils import chunks_to_search_docs
from danswer.secondary_llm_flows.choose_search import check_if_need_search
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
@ -377,37 +376,25 @@ def stream_chat_message_objects(
else query_override
)
(
retrieval_request,
predicted_search_type,
predicted_flow,
) = retrieval_preprocessing(
query=rephrased_query,
retrieval_details=cast(RetrievalDetails, retrieval_options),
persona=persona,
search_pipeline = SearchPipeline(
search_request=SearchRequest(
query=rephrased_query,
human_selected_filters=retrieval_options.filters
if retrieval_options
else None,
persona=persona,
offset=retrieval_options.offset if retrieval_options else None,
limit=retrieval_options.limit if retrieval_options else None,
),
user=user,
db_session=db_session,
)
documents_generator = full_chunk_search_generator(
search_query=retrieval_request,
document_index=document_index,
db_session=db_session,
)
time_cutoff = retrieval_request.filters.time_cutoff
recency_bias_multiplier = retrieval_request.recency_bias_multiplier
run_llm_chunk_filter = not retrieval_request.skip_llm_chunk_filter
# First fetch and return the top chunks to the UI so the user can
# immediately see some results
top_chunks = cast(list[InferenceChunk], next(documents_generator))
top_chunks = search_pipeline.reranked_docs
top_docs = chunks_to_search_docs(top_chunks)
# Get ranking of the documents for citation purposes later
doc_id_to_rank_map = map_document_id_order(
cast(list[InferenceChunk | LlmDoc], top_chunks)
)
top_docs = chunks_to_search_docs(top_chunks)
doc_id_to_rank_map = map_document_id_order(top_chunks)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
@ -422,24 +409,17 @@ def stream_chat_message_objects(
initial_response = QADocsResponse(
rephrased_query=rephrased_query,
top_documents=response_docs,
predicted_flow=predicted_flow,
predicted_search=predicted_search_type,
applied_source_filters=retrieval_request.filters.source_type,
applied_time_cutoff=time_cutoff,
recency_bias_multiplier=recency_bias_multiplier,
predicted_flow=search_pipeline.predicted_flow,
predicted_search=search_pipeline.predicted_search_type,
applied_source_filters=search_pipeline.search_query.filters.source_type,
applied_time_cutoff=search_pipeline.search_query.filters.time_cutoff,
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
)
yield initial_response
# Get the final ordering of chunks for the LLM call
llm_chunk_selection = cast(list[bool], next(documents_generator))
# Yield the list of LLM selected chunks for showing the LLM selected icons in the UI
llm_relevance_filtering_response = LLMRelevanceFilterResponse(
relevant_chunk_indices=[
index for index, value in enumerate(llm_chunk_selection) if value
]
if run_llm_chunk_filter
else []
relevant_chunk_indices=search_pipeline.relevant_chunk_indicies
)
yield llm_relevance_filtering_response
@ -467,7 +447,7 @@ def stream_chat_message_objects(
)
llm_chunks_indices = get_chunks_for_qa(
chunks=top_chunks,
llm_chunk_selection=llm_chunk_selection,
llm_chunk_selection=search_pipeline.chunk_relevance_list,
token_limit=chunk_token_limit,
llm_tokenizer=llm_tokenizer,
)
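
For reference, the old flow above pulled two values out of `full_chunk_search_generator` (the reranked chunks, then a per-chunk boolean relevance list), while the new flow reads the equivalent data from `SearchPipeline` properties. A rough sketch of the correspondence, assuming a `pipeline` constructed exactly as in the hunk above (the helper name is hypothetical, not part of this commit):

```python
from danswer.indexing.models import InferenceChunk
from danswer.search.pipeline import SearchPipeline


def summarize_pipeline_outputs(pipeline: SearchPipeline) -> None:
    top_chunks: list[InferenceChunk] = pipeline.reranked_docs
    relevant_indices: list[int] = pipeline.relevant_chunk_indicies
    relevance_flags: list[bool] = pipeline.chunk_relevance_list

    # The index list and the boolean list encode the same LLM relevance
    # selection that the old generator yielded as `llm_chunk_selection`.
    assert relevant_indices == [i for i, keep in enumerate(relevance_flags) if keep]
    assert len(relevance_flags) == len(top_chunks)
```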

View File

@ -27,7 +27,7 @@ from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import StarterMessage
from danswer.db.models import User__UserGroup
from danswer.search.models import RecencyBiasSetting
from danswer.search.enums import RecencyBiasSetting
from danswer.search.models import RetrievalDocs
from danswer.search.models import SavedSearchDoc
from danswer.search.models import SearchDoc as ServerSearchDoc

View File

@ -36,8 +36,8 @@ from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
from danswer.connectors.models import InputType
from danswer.dynamic_configs.interface import JSON_ro
from danswer.search.models import RecencyBiasSetting
from danswer.search.models import SearchType
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
class IndexingStatus(str, PyEnum):

View File

@ -12,7 +12,7 @@ from danswer.db.models import Persona
from danswer.db.models import Persona__DocumentSet
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.search.models import RecencyBiasSetting
from danswer.search.enums import RecencyBiasSetting
def _build_persona_name(channel_names: list[str]) -> str:

View File

@ -64,8 +64,8 @@ from danswer.document_index.vespa.utils import remove_invalid_unicode_chars
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.indexing.models import InferenceChunk
from danswer.search.models import IndexFilters
from danswer.search.search_runner import query_processing
from danswer.search.search_runner import remove_stop_words_and_punctuation
from danswer.search.retrieval.search_runner import query_processing
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

View File

@ -1,7 +1,6 @@
import itertools
from collections.abc import Callable
from collections.abc import Iterator
from typing import cast
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
@ -33,11 +32,9 @@ from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import get_persona_by_id
from danswer.db.chat import get_prompt_by_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.db.models import Prompt
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import get_default_llm_token_encode
@ -55,9 +52,9 @@ from danswer.prompts.prompt_utils import build_task_prompt_reminders
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SavedSearchDoc
from danswer.search.request_preprocessing import retrieval_preprocessing
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import full_chunk_search_generator
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
from danswer.search.utils import chunks_to_search_docs
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
@ -221,12 +218,6 @@ def stream_answer_objects(
llm_tokenizer = get_default_llm_token_encode()
embedding_model = get_current_db_embedding_model(db_session)
document_index = get_default_document_index(
primary_index_name=embedding_model.index_name, secondary_index_name=None
)
# Create a chat session which will just store the root message, the query, and the AI response
root_message = get_or_create_root_message(
chat_session_id=chat_session.id, db_session=db_session
@ -244,33 +235,23 @@ def stream_answer_objects(
# In chat flow it's given back along with the documents
yield QueryRephrase(rephrased_query=rephrased_query)
(
retrieval_request,
predicted_search_type,
predicted_flow,
) = retrieval_preprocessing(
query=rephrased_query,
retrieval_details=query_req.retrieval_options,
persona=chat_session.persona,
search_pipeline = SearchPipeline(
search_request=SearchRequest(
query=rephrased_query,
human_selected_filters=query_req.retrieval_options.filters,
persona=chat_session.persona,
offset=query_req.retrieval_options.offset,
limit=query_req.retrieval_options.limit,
),
user=user,
db_session=db_session,
bypass_acl=bypass_acl,
)
documents_generator = full_chunk_search_generator(
search_query=retrieval_request,
document_index=document_index,
db_session=db_session,
retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback,
)
applied_time_cutoff = retrieval_request.filters.time_cutoff
recency_bias_multiplier = retrieval_request.recency_bias_multiplier
run_llm_chunk_filter = not retrieval_request.skip_llm_chunk_filter
# First fetch and return the top chunks so the user can immediately see some results
top_chunks = cast(list[InferenceChunk], next(documents_generator))
top_chunks = search_pipeline.reranked_docs
top_docs = chunks_to_search_docs(top_chunks)
fake_saved_docs = [SavedSearchDoc.from_search_doc(doc) for doc in top_docs]
@ -278,24 +259,17 @@ def stream_answer_objects(
initial_response = QADocsResponse(
rephrased_query=rephrased_query,
top_documents=fake_saved_docs,
predicted_flow=predicted_flow,
predicted_search=predicted_search_type,
applied_source_filters=retrieval_request.filters.source_type,
applied_time_cutoff=applied_time_cutoff,
recency_bias_multiplier=recency_bias_multiplier,
predicted_flow=search_pipeline.predicted_flow,
predicted_search=search_pipeline.predicted_search_type,
applied_source_filters=search_pipeline.search_query.filters.source_type,
applied_time_cutoff=search_pipeline.search_query.filters.time_cutoff,
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
)
yield initial_response
# Get the final ordering of chunks for the LLM call
llm_chunk_selection = cast(list[bool], next(documents_generator))
# Yield the list of LLM selected chunks for showing the LLM selected icons in the UI
llm_relevance_filtering_response = LLMRelevanceFilterResponse(
relevant_chunk_indices=[
index for index, value in enumerate(llm_chunk_selection) if value
]
if run_llm_chunk_filter
else []
relevant_chunk_indices=search_pipeline.relevant_chunk_indicies
)
yield llm_relevance_filtering_response
@ -317,7 +291,7 @@ def stream_answer_objects(
llm_chunks_indices = get_chunks_for_qa(
chunks=top_chunks,
llm_chunk_selection=llm_chunk_selection,
llm_chunk_selection=search_pipeline.chunk_relevance_list,
token_limit=chunk_token_limit,
)
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
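
The eval/one-shot flow previously forwarded `retrieval_metrics_callback` and `rerank_metrics_callback` to `full_chunk_search_generator`; the `SearchPipeline` constructor (see `pipeline.py` below) accepts the same callbacks. A hedged sketch of wiring them up, with the list-collecting callbacks and helper name purely illustrative:

```python
from sqlalchemy.orm import Session

from danswer.db.models import User
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline


def build_eval_pipeline(
    query: str, user: User | None, db_session: Session
) -> SearchPipeline:
    retrieval_metrics: list[RetrievalMetricsContainer] = []
    rerank_metrics: list[RerankMetricsContainer] = []
    # The pipeline constructor accepts the same metrics callbacks that used to
    # be passed to full_chunk_search_generator, so eval flows keep their hooks.
    return SearchPipeline(
        search_request=SearchRequest(query=query),
        user=user,
        db_session=db_session,
        retrieval_metrics_callback=retrieval_metrics.append,
        rerank_metrics_callback=rerank_metrics.append,
    )
```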

View File

@ -0,0 +1,30 @@
"""NOTE: this needs to be separate from models.py because of circular imports.
Both search/models.py and db/models.py import enums from this file AND
search/models.py imports from db/models.py."""
from enum import Enum
class OptionalSearchSetting(str, Enum):
ALWAYS = "always"
NEVER = "never"
# Determine whether to run search based on history and latest query
AUTO = "auto"
class RecencyBiasSetting(str, Enum):
FAVOR_RECENT = "favor_recent" # 2x decay rate
BASE_DECAY = "base_decay"
NO_DECAY = "no_decay"
# Determine based on query if to use base_decay or favor_recent
AUTO = "auto"
class SearchType(str, Enum):
KEYWORD = "keyword"
SEMANTIC = "semantic"
HYBRID = "hybrid"
class QueryFlow(str, Enum):
SEARCH = "search"
QUESTION_ANSWER = "question-answer"
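
A small illustration of why plain `str` enums are convenient here (names and values from the file above; the asserts are just an example, not part of the commit):

```python
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType

# str-valued enums compare and serialize as plain strings, which lets both the
# Pydantic search models and the SQLAlchemy db models share them without
# importing each other.
assert SearchType.HYBRID == "hybrid"
assert RecencyBiasSetting("favor_recent") is RecencyBiasSetting.FAVOR_RECENT
```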

View File

@ -1,46 +1,24 @@
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel
from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.db.models import Persona
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import SearchType
MAX_METRICS_CONTENT = (
200 # Just need enough characters to identify where in the doc the chunk is
)
class OptionalSearchSetting(str, Enum):
ALWAYS = "always"
NEVER = "never"
# Determine whether to run search based on history and latest query
AUTO = "auto"
class RecencyBiasSetting(str, Enum):
FAVOR_RECENT = "favor_recent" # 2x decay rate
BASE_DECAY = "base_decay"
NO_DECAY = "no_decay"
# Determine based on query if to use base_decay or favor_recent
AUTO = "auto"
class SearchType(str, Enum):
KEYWORD = "keyword"
SEMANTIC = "semantic"
HYBRID = "hybrid"
class QueryFlow(str, Enum):
SEARCH = "search"
QUESTION_ANSWER = "question-answer"
class Tag(BaseModel):
tag_key: str
tag_value: str
@ -64,6 +42,28 @@ class ChunkMetric(BaseModel):
score: float
class SearchRequest(BaseModel):
"""Input to the SearchPipeline."""
query: str
search_type: SearchType = SearchType.HYBRID
human_selected_filters: BaseFilters | None = None
enable_auto_detect_filters: bool | None = None
persona: Persona | None = None
# if None, no offset / limit
offset: int | None = None
limit: int | None = None
recency_bias_multiplier: float = 1.0
hybrid_alpha: float = HYBRID_ALPHA
skip_rerank: bool = True
class Config:
arbitrary_types_allowed = True
class SearchQuery(BaseModel):
query: str
filters: IndexFilters
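
`SearchRequest` becomes the single entry point into the pipeline, with only `query` required. A minimal construction sketch (field names and defaults as shown above; the query strings are illustrative):

```python
from danswer.search.models import SearchRequest

# Only the query is required; search_type defaults to hybrid, the filters,
# persona, offset, and limit default to None, and reranking is skipped by default.
request = SearchRequest(query="how do I rotate my api key?")

# Callers that need more control set the optional fields explicitly, as the
# chat and one-shot flows in this commit do.
scoped_request = SearchRequest(query="how do I rotate my api key?", offset=0, limit=25)
```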

View File

@ -0,0 +1,152 @@
from collections.abc import Callable
from typing import cast
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchRequest
from danswer.search.postprocessing.postprocessing import search_postprocessing
from danswer.search.preprocessing.preprocessing import retrieval_preprocessing
from danswer.search.retrieval.search_runner import retrieve_chunks
class SearchPipeline:
def __init__(
self,
search_request: SearchRequest,
user: User | None,
db_session: Session,
bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
):
self.search_request = search_request
self.user = user
self.db_session = db_session
self.bypass_acl = bypass_acl
self.retrieval_metrics_callback = retrieval_metrics_callback
self.rerank_metrics_callback = rerank_metrics_callback
self.embedding_model = get_current_db_embedding_model(db_session)
self.document_index = get_default_document_index(
primary_index_name=self.embedding_model.index_name,
secondary_index_name=None,
)
self._search_query: SearchQuery | None = None
self._predicted_search_type: SearchType | None = None
self._predicted_flow: QueryFlow | None = None
self._retrieved_docs: list[InferenceChunk] | None = None
self._reranked_docs: list[InferenceChunk] | None = None
self._relevant_chunk_indicies: list[int] | None = None
"""Pre-processing"""
def _run_preprocessing(self) -> None:
(
final_search_query,
predicted_search_type,
predicted_flow,
) = retrieval_preprocessing(
search_request=self.search_request,
user=self.user,
db_session=self.db_session,
bypass_acl=self.bypass_acl,
)
self._predicted_search_type = predicted_search_type
self._predicted_flow = predicted_flow
self._search_query = final_search_query
@property
def search_query(self) -> SearchQuery:
if self._search_query is not None:
return self._search_query
self._run_preprocessing()
return cast(SearchQuery, self._search_query)
@property
def predicted_search_type(self) -> SearchType:
if self._predicted_search_type is not None:
return self._predicted_search_type
self._run_preprocessing()
return cast(SearchType, self._predicted_search_type)
@property
def predicted_flow(self) -> QueryFlow:
if self._predicted_flow is not None:
return self._predicted_flow
self._run_preprocessing()
return cast(QueryFlow, self._predicted_flow)
"""Retrieval"""
@property
def retrieved_docs(self) -> list[InferenceChunk]:
if self._retrieved_docs is not None:
return self._retrieved_docs
self._retrieved_docs = retrieve_chunks(
query=self.search_query,
document_index=self.document_index,
db_session=self.db_session,
hybrid_alpha=self.search_request.hybrid_alpha,
multilingual_expansion_str=MULTILINGUAL_QUERY_EXPANSION,
retrieval_metrics_callback=self.retrieval_metrics_callback,
)
# self._retrieved_docs = chunks_to_search_docs(retrieved_chunks)
return cast(list[InferenceChunk], self._retrieved_docs)
"""Post-Processing"""
def _run_postprocessing(self) -> None:
postprocessing_generator = search_postprocessing(
search_query=self.search_query,
retrieved_chunks=self.retrieved_docs,
rerank_metrics_callback=self.rerank_metrics_callback,
)
self._reranked_docs = cast(list[InferenceChunk], next(postprocessing_generator))
relevant_chunk_ids = cast(list[str], next(postprocessing_generator))
self._relevant_chunk_indicies = [
ind
for ind, chunk in enumerate(self._reranked_docs)
if chunk.unique_id in relevant_chunk_ids
]
@property
def reranked_docs(self) -> list[InferenceChunk]:
if self._reranked_docs is not None:
return self._reranked_docs
self._run_postprocessing()
return cast(list[InferenceChunk], self._reranked_docs)
@property
def relevant_chunk_indicies(self) -> list[int]:
if self._relevant_chunk_indicies is not None:
return self._relevant_chunk_indicies
self._run_postprocessing()
return cast(list[int], self._relevant_chunk_indicies)
@property
def chunk_relevance_list(self) -> list[bool]:
return [
True if ind in self.relevant_chunk_indicies else False
for ind in range(len(self.reranked_docs))
]
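
Each stage runs at most once and only when something downstream asks for its output: `search_query` and the predictions trigger `_run_preprocessing`, `retrieved_docs` triggers retrieval, and `reranked_docs` / `relevant_chunk_indicies` trigger `_run_postprocessing`, with results cached on the instance. A usage sketch of that lazy behavior (the `run_search` helper is hypothetical):

```python
from danswer.search.pipeline import SearchPipeline


def run_search(pipeline: SearchPipeline) -> None:
    # First access runs preprocessing, retrieval, and postprocessing in turn.
    docs = pipeline.reranked_docs

    # Later accesses reuse the cached results; no stage runs twice.
    assert pipeline.reranked_docs is docs

    # Preprocessing outputs are available independently; here they were already
    # computed as a side effect of the access above.
    print(pipeline.predicted_flow, pipeline.predicted_search_type)
```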

View File

@ -0,0 +1,222 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import cast
import numpy
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from danswer.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
from danswer.indexing.models import InferenceChunk
from danswer.search.models import ChunkMetric
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.timing import log_function_time
logger = setup_logger()
def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
top_links = [
c.source_links[0] if c.source_links is not None else "No Link" for c in chunks
]
logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
def should_rerank(query: SearchQuery) -> bool:
# Don't re-rank for keyword search
return query.search_type != SearchType.KEYWORD and not query.skip_rerank
def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool:
return not query.skip_llm_chunk_filter
@log_function_time(print_only=True)
def semantic_reranking(
query: str,
chunks: list[InferenceChunk],
model_min: int = CROSS_ENCODER_RANGE_MIN,
model_max: int = CROSS_ENCODER_RANGE_MAX,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk], list[int]]:
"""Reranks chunks based on cross-encoder models. Additionally provides the original indices
of the chunks in their new sorted order.
Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
"""
cross_encoders = CrossEncoderEnsembleModel()
passages = [chunk.content for chunk in chunks]
sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores))
cross_models_min = numpy.min(sim_scores)
shifted_sim_scores = sum(
[enc_n_scores - cross_models_min for enc_n_scores in sim_scores]
) / len(sim_scores)
boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
recency_multiplier = [chunk.recency_bias for chunk in chunks]
boosted_sim_scores = shifted_sim_scores * boosts * recency_multiplier
normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / (
model_max - model_min
)
orig_indices = [i for i in range(len(normalized_b_s_scores))]
scored_results = list(
zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices)
)
scored_results.sort(key=lambda x: x[0], reverse=True)
ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip(
*scored_results
)
logger.debug(
f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}"
)
# Assign new chunk scores based on reranking
for ind, chunk in enumerate(ranked_chunks):
chunk.score = ranked_sim_scores[ind]
if rerank_metrics_callback is not None:
chunk_metrics = [
ChunkMetric(
document_id=chunk.document_id,
chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
first_link=chunk.source_links[0] if chunk.source_links else None,
score=chunk.score if chunk.score is not None else 0,
)
for chunk in ranked_chunks
]
rerank_metrics_callback(
RerankMetricsContainer(
metrics=chunk_metrics, raw_similarity_scores=ranked_raw_scores # type: ignore
)
)
return list(ranked_chunks), list(ranked_indices)
def rerank_chunks(
query: SearchQuery,
chunks_to_rerank: list[InferenceChunk],
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> list[InferenceChunk]:
ranked_chunks, _ = semantic_reranking(
query=query.query,
chunks=chunks_to_rerank[: query.num_rerank],
rerank_metrics_callback=rerank_metrics_callback,
)
lower_chunks = chunks_to_rerank[query.num_rerank :]
# Scores from rerank cannot be meaningfully combined with scores without rerank
for lower_chunk in lower_chunks:
lower_chunk.score = None
ranked_chunks.extend(lower_chunks)
return ranked_chunks
@log_function_time(print_only=True)
def filter_chunks(
query: SearchQuery,
chunks_to_filter: list[InferenceChunk],
) -> list[str]:
"""Filters chunks based on whether the LLM thought they were relevant to the query.
Returns a list of the unique chunk IDs that were marked as relevant"""
chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks]
llm_chunk_selection = llm_batch_eval_chunks(
query=query.query,
chunk_contents=[chunk.content for chunk in chunks_to_filter],
)
return [
chunk.unique_id
for ind, chunk in enumerate(chunks_to_filter)
if llm_chunk_selection[ind]
]
def search_postprocessing(
search_query: SearchQuery,
retrieved_chunks: list[InferenceChunk],
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> Generator[list[InferenceChunk] | list[str], None, None]:
post_processing_tasks: list[FunctionCall] = []
rerank_task_id = None
if should_rerank(search_query):
post_processing_tasks.append(
FunctionCall(
rerank_chunks,
(
search_query,
retrieved_chunks,
rerank_metrics_callback,
),
)
)
rerank_task_id = post_processing_tasks[-1].result_id
else:
final_chunks = retrieved_chunks
# NOTE: if we don't rerank, we can return the chunks immediately
# since we know this is the final order
_log_top_chunk_links(search_query.search_type.value, final_chunks)
yield final_chunks
chunks_yielded = True
llm_filter_task_id = None
if should_apply_llm_based_relevance_filter(search_query):
post_processing_tasks.append(
FunctionCall(
filter_chunks,
(search_query, retrieved_chunks[: search_query.max_llm_filter_chunks]),
)
)
llm_filter_task_id = post_processing_tasks[-1].result_id
post_processing_results = (
run_functions_in_parallel(post_processing_tasks)
if post_processing_tasks
else {}
)
reranked_chunks = cast(
list[InferenceChunk] | None,
post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None,
)
if reranked_chunks:
if chunks_yielded:
logger.error(
"Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen."
)
else:
_log_top_chunk_links(search_query.search_type.value, reranked_chunks)
yield reranked_chunks
llm_chunk_selection = cast(
list[str] | None,
post_processing_results.get(str(llm_filter_task_id))
if llm_filter_task_id
else None,
)
if llm_chunk_selection is not None:
yield [
chunk.unique_id
for chunk in reranked_chunks or retrieved_chunks
if chunk.unique_id in llm_chunk_selection
]
else:
yield []
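
`search_postprocessing` always yields exactly twice: first the final chunk ordering (immediately when reranking is skipped, otherwise after the reranker finishes), then the unique IDs of the chunks the LLM filter kept (an empty list when the filter is disabled). A consumer sketch mirroring what `SearchPipeline._run_postprocessing` does (the helper name is hypothetical):

```python
from typing import cast

from danswer.indexing.models import InferenceChunk
from danswer.search.models import SearchQuery
from danswer.search.postprocessing.postprocessing import search_postprocessing


def consume_postprocessing(
    search_query: SearchQuery, retrieved_chunks: list[InferenceChunk]
) -> tuple[list[InferenceChunk], list[int]]:
    gen = search_postprocessing(
        search_query=search_query, retrieved_chunks=retrieved_chunks
    )
    # First yield: final chunk ordering (reranked when applicable).
    final_chunks = cast(list[InferenceChunk], next(gen))
    # Second yield: unique IDs of chunks the LLM judged relevant (or []).
    relevant_ids = cast(list[str], next(gen))
    relevant_indices = [
        i for i, chunk in enumerate(final_chunks) if chunk.unique_id in relevant_ids
    ]
    return final_chunks, relevant_indices
```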

View File

@ -1,10 +1,10 @@
from typing import TYPE_CHECKING
from danswer.search.models import QueryFlow
from danswer.search.enums import QueryFlow
from danswer.search.models import SearchType
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.search.search_nlp_models import get_default_tokenizer
from danswer.search.search_nlp_models import IntentModel
from danswer.search.search_runner import remove_stop_words_and_punctuation
from danswer.server.query_and_chat.models import HelperResponse
from danswer.utils.logger import setup_logger

View File

@ -5,19 +5,16 @@ from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION
from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.danswer_helper import query_intent
from danswer.search.enums import QueryFlow
from danswer.search.enums import RecencyBiasSetting
from danswer.search.models import BaseFilters
from danswer.search.models import IndexFilters
from danswer.search.models import QueryFlow
from danswer.search.models import RecencyBiasSetting
from danswer.search.models import RetrievalDetails
from danswer.search.models import SearchQuery
from danswer.search.models import SearchRequest
from danswer.search.models import SearchType
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.search.preprocessing.danswer_helper import query_intent
from danswer.secondary_llm_flows.source_filter import extract_source_filter
from danswer.secondary_llm_flows.time_filter import extract_time_filter
from danswer.utils.logger import setup_logger
@ -31,15 +28,12 @@ logger = setup_logger()
@log_function_time(print_only=True)
def retrieval_preprocessing(
query: str,
retrieval_details: RetrievalDetails,
persona: Persona,
search_request: SearchRequest,
user: User | None,
db_session: Session,
bypass_acl: bool = False,
include_query_intent: bool = True,
skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW,
skip_rerank_non_realtime: bool = not ENABLE_RERANKING_ASYNC_FLOW,
enable_auto_detect_filters: bool = False,
disable_llm_filter_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION,
disable_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
base_recency_decay: float = BASE_RECENCY_DECAY,
@ -50,8 +44,12 @@ def retrieval_preprocessing(
Then any filters or settings as part of the query are used
Then defaults to Persona settings if not specified by the query
"""
query = search_request.query
limit = search_request.limit
offset = search_request.offset
persona = search_request.persona
preset_filters = retrieval_details.filters or BaseFilters()
preset_filters = search_request.human_selected_filters or BaseFilters()
if persona and persona.document_sets and preset_filters.document_set is None:
preset_filters.document_set = [
document_set.name for document_set in persona.document_sets
@ -65,16 +63,20 @@ def retrieval_preprocessing(
if disable_llm_filter_extraction:
auto_detect_time_filter = False
auto_detect_source_filter = False
elif retrieval_details.enable_auto_detect_filters is False:
elif enable_auto_detect_filters is False:
logger.debug("Retrieval details disables auto detect filters")
auto_detect_time_filter = False
auto_detect_source_filter = False
elif persona.llm_filter_extraction is False:
elif persona and persona.llm_filter_extraction is False:
logger.debug("Persona disables auto detect filters")
auto_detect_time_filter = False
auto_detect_source_filter = False
if time_filter is not None and persona.recency_bias != RecencyBiasSetting.AUTO:
if (
time_filter is not None
and persona
and persona.recency_bias != RecencyBiasSetting.AUTO
):
auto_detect_time_filter = False
logger.debug("Not extract time filter - already provided")
if source_filter is not None:
@ -138,24 +140,18 @@ def retrieval_preprocessing(
access_control_list=user_acl_filters,
)
# Transformer-based re-ranking to run at same time as LLM chunk relevance filter
# This one is only set globally, not via query or Persona settings
skip_reranking = (
skip_rerank_realtime
if retrieval_details.real_time
else skip_rerank_non_realtime
)
llm_chunk_filter = persona.llm_relevance_filter
llm_chunk_filter = False
if persona:
llm_chunk_filter = persona.llm_relevance_filter
if disable_llm_chunk_filter:
llm_chunk_filter = False
# Decays at 1 / (1 + (multiplier * num years))
if persona.recency_bias == RecencyBiasSetting.NO_DECAY:
if persona and persona.recency_bias == RecencyBiasSetting.NO_DECAY:
recency_bias_multiplier = 0.0
elif persona.recency_bias == RecencyBiasSetting.BASE_DECAY:
elif persona and persona.recency_bias == RecencyBiasSetting.BASE_DECAY:
recency_bias_multiplier = base_recency_decay
elif persona.recency_bias == RecencyBiasSetting.FAVOR_RECENT:
elif persona and persona.recency_bias == RecencyBiasSetting.FAVOR_RECENT:
recency_bias_multiplier = base_recency_decay * favor_recent_decay_multiplier
else:
if predicted_favor_recent:
@ -166,14 +162,12 @@ def retrieval_preprocessing(
return (
SearchQuery(
query=query,
search_type=persona.search_type,
search_type=persona.search_type if persona else SearchType.HYBRID,
filters=final_filters,
recency_bias_multiplier=recency_bias_multiplier,
num_hits=retrieval_details.limit
if retrieval_details.limit is not None
else NUM_RETURNED_HITS,
offset=retrieval_details.offset or 0,
skip_rerank=skip_reranking,
num_hits=limit if limit is not None else NUM_RETURNED_HITS,
offset=offset or 0,
skip_rerank=search_request.skip_rerank,
skip_llm_chunk_filter=not llm_chunk_filter,
),
predicted_search_type,
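
The persona-dependent branches above now tolerate `persona` being `None`, and the multiplier they produce feeds the `1 / (1 + multiplier * num_years)` decay applied at retrieval time. A standalone restatement of that branch logic; the AUTO branch is truncated in the hunk above, so its fallback here is an assumption, and the `0.5` default is only a stand-in for `BASE_RECENCY_DECAY`:

```python
from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
from danswer.db.models import Persona
from danswer.search.enums import RecencyBiasSetting


def recency_multiplier(
    persona: Persona | None,
    predicted_favor_recent: bool,
    base_recency_decay: float = 0.5,  # stand-in for BASE_RECENCY_DECAY
    favor_recent_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
) -> float:
    # Retrieval decays scores at 1 / (1 + (multiplier * num_years));
    # a multiplier of 0.0 disables time decay entirely.
    if persona and persona.recency_bias == RecencyBiasSetting.NO_DECAY:
        return 0.0
    if persona and persona.recency_bias == RecencyBiasSetting.BASE_DECAY:
        return base_recency_decay
    if persona and persona.recency_bias == RecencyBiasSetting.FAVOR_RECENT:
        return base_recency_decay * favor_recent_multiplier
    # AUTO (or no persona): assumed fallback to the base decay, boosted when
    # the query itself is predicted to favor recent documents.
    if predicted_favor_recent:
        return base_recency_decay * favor_recent_multiplier
    return base_recency_decay
```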

View File

@ -0,0 +1,256 @@
import string
from collections.abc import Callable
from nltk.corpus import stopwords # type:ignore
from nltk.stem import WordNetLemmatizer # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from sqlalchemy.orm import Session
from danswer.chat.models import LlmDoc
from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.document_index.interfaces import DocumentIndex
from danswer.indexing.models import InferenceChunk
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.search.search_nlp_models import EmbedTextType
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from danswer.utils.timing import log_function_time
logger = setup_logger()
def lemmatize_text(text: str) -> list[str]:
lemmatizer = WordNetLemmatizer()
word_tokens = word_tokenize(text)
return [lemmatizer.lemmatize(word) for word in word_tokens]
def remove_stop_words_and_punctuation(text: str) -> list[str]:
stop_words = set(stopwords.words("english"))
word_tokens = word_tokenize(text)
text_trimmed = [
word
for word in word_tokens
if (word.casefold() not in stop_words and word not in string.punctuation)
]
return text_trimmed or word_tokens
def query_processing(
query: str,
) -> str:
query = " ".join(remove_stop_words_and_punctuation(query))
query = " ".join(lemmatize_text(query))
return query
def combine_retrieval_results(
chunk_sets: list[list[InferenceChunk]],
) -> list[InferenceChunk]:
all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set]
unique_chunks: dict[tuple[str, int], InferenceChunk] = {}
for chunk in all_chunks:
key = (chunk.document_id, chunk.chunk_id)
if key not in unique_chunks:
unique_chunks[key] = chunk
continue
stored_chunk_score = unique_chunks[key].score or 0
this_chunk_score = chunk.score or 0
if stored_chunk_score < this_chunk_score:
unique_chunks[key] = chunk
sorted_chunks = sorted(
unique_chunks.values(), key=lambda x: x.score or 0, reverse=True
)
return sorted_chunks
@log_function_time(print_only=True)
def doc_index_retrieval(
query: SearchQuery,
document_index: DocumentIndex,
db_session: Session,
hybrid_alpha: float = HYBRID_ALPHA,
) -> list[InferenceChunk]:
if query.search_type == SearchType.KEYWORD:
top_chunks = document_index.keyword_retrieval(
query=query.query,
filters=query.filters,
time_decay_multiplier=query.recency_bias_multiplier,
num_to_retrieve=query.num_hits,
)
else:
db_embedding_model = get_current_db_embedding_model(db_session)
model = EmbeddingModel(
model_name=db_embedding_model.model_name,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
normalize=db_embedding_model.normalize,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]
if query.search_type == SearchType.SEMANTIC:
top_chunks = document_index.semantic_retrieval(
query=query.query,
query_embedding=query_embedding,
filters=query.filters,
time_decay_multiplier=query.recency_bias_multiplier,
num_to_retrieve=query.num_hits,
)
elif query.search_type == SearchType.HYBRID:
top_chunks = document_index.hybrid_retrieval(
query=query.query,
query_embedding=query_embedding,
filters=query.filters,
time_decay_multiplier=query.recency_bias_multiplier,
num_to_retrieve=query.num_hits,
offset=query.offset,
hybrid_alpha=hybrid_alpha,
)
else:
raise RuntimeError("Invalid Search Flow")
return top_chunks
def _simplify_text(text: str) -> str:
return "".join(
char for char in text if char not in string.punctuation and not char.isspace()
).lower()
def retrieve_chunks(
query: SearchQuery,
document_index: DocumentIndex,
db_session: Session,
hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search
multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
) -> list[InferenceChunk]:
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
# Don't do query expansion on complex queries, rephrasings likely would not work well
if not multilingual_expansion_str or "\n" in query.query or "\r" in query.query:
top_chunks = doc_index_retrieval(
query=query,
document_index=document_index,
db_session=db_session,
hybrid_alpha=hybrid_alpha,
)
else:
simplified_queries = set()
run_queries: list[tuple[Callable, tuple]] = []
# Currently only uses query expansion on multilingual use cases
query_rephrases = multilingual_query_expansion(
query.query, multilingual_expansion_str
)
# Just to be extra sure, add the original query.
query_rephrases.append(query.query)
for rephrase in set(query_rephrases):
# Sometimes the model rephrases the query in the same language with minor changes
# Avoid doing an extra search with the minor changes as this biases the results
simplified_rephrase = _simplify_text(rephrase)
if simplified_rephrase in simplified_queries:
continue
simplified_queries.add(simplified_rephrase)
q_copy = query.copy(update={"query": rephrase}, deep=True)
run_queries.append(
(
doc_index_retrieval,
(q_copy, document_index, db_session, hybrid_alpha),
)
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
top_chunks = combine_retrieval_results(parallel_search_results)
if not top_chunks:
logger.info(
f"{query.search_type.value.capitalize()} search returned no results "
f"with filters: {query.filters}"
)
return []
if retrieval_metrics_callback is not None:
chunk_metrics = [
ChunkMetric(
document_id=chunk.document_id,
chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
first_link=chunk.source_links[0] if chunk.source_links else None,
score=chunk.score if chunk.score is not None else 0,
)
for chunk in top_chunks
]
retrieval_metrics_callback(
RetrievalMetricsContainer(
search_type=query.search_type, metrics=chunk_metrics
)
)
return top_chunks
def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc:
if not inf_chunks:
raise ValueError("Cannot combine empty list of chunks")
# Use the first link of the document
first_chunk = inf_chunks[0]
chunk_texts = [chunk.content for chunk in inf_chunks]
return LlmDoc(
document_id=first_chunk.document_id,
content="\n".join(chunk_texts),
semantic_identifier=first_chunk.semantic_identifier,
source_type=first_chunk.source_type,
metadata=first_chunk.metadata,
updated_at=first_chunk.updated_at,
link=first_chunk.source_links[0] if first_chunk.source_links else None,
)
def inference_documents_from_ids(
doc_identifiers: list[tuple[str, int]],
document_index: DocumentIndex,
) -> list[LlmDoc]:
# Currently only fetches whole docs
doc_ids_set = set(doc_id for doc_id, chunk_id in doc_identifiers)
# No need for ACL here because the doc ids were validated beforehand
filters = IndexFilters(access_control_list=None)
functions_with_args: list[tuple[Callable, tuple]] = [
(document_index.id_based_retrieval, (doc_id, None, filters))
for doc_id in doc_ids_set
]
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=True
)
# Any failures to retrieve would give a None, drop the Nones and empty lists
inference_chunks_sets = [res for res in parallel_results if res]
return [combine_inference_chunks(chunk_set) for chunk_set in inference_chunks_sets]
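
`combine_retrieval_results` merges the per-rephrase result lists produced by multilingual query expansion: duplicates (same document ID and chunk ID) keep whichever copy scored higher, and the merged list is sorted by score. A restatement of that merge rule on plain tuples instead of full `InferenceChunk` objects (helper name and values are illustrative):

```python
# Restatement of the merge rule in combine_retrieval_results, on plain tuples
# of (document_id, chunk_id, score) instead of full InferenceChunk objects.
def merge_by_best_score(
    result_sets: list[list[tuple[str, int, float]]]
) -> list[tuple[str, int, float]]:
    best: dict[tuple[str, int], tuple[str, int, float]] = {}
    for doc_id, chunk_id, score in (c for s in result_sets for c in s):
        key = (doc_id, chunk_id)
        if key not in best or best[key][2] < score:
            best[key] = (doc_id, chunk_id, score)
    return sorted(best.values(), key=lambda c: c[2], reverse=True)


# Two rephrasings returning an overlapping chunk: the higher-scored copy wins.
merged = merge_by_best_score(
    [[("doc-a", 0, 0.61), ("doc-b", 2, 0.44)], [("doc-a", 0, 0.58)]]
)
assert merged == [("doc-a", 0, 0.61), ("doc-b", 2, 0.44)]
```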

View File

@ -1,645 +0,0 @@
import string
from collections.abc import Callable
from collections.abc import Iterator
from typing import cast
import numpy
from nltk.corpus import stopwords # type:ignore
from nltk.stem import WordNetLemmatizer # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from sqlalchemy.orm import Session
from danswer.chat.models import LlmDoc
from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
from danswer.document_index.interfaces import DocumentIndex
from danswer.indexing.models import InferenceChunk
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchDoc
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.search.search_nlp_models import EmbedTextType
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from danswer.utils.timing import log_function_time
logger = setup_logger()
def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
top_links = [
c.source_links[0] if c.source_links is not None else "No Link" for c in chunks
]
logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
def lemmatize_text(text: str) -> list[str]:
lemmatizer = WordNetLemmatizer()
word_tokens = word_tokenize(text)
return [lemmatizer.lemmatize(word) for word in word_tokens]
def remove_stop_words_and_punctuation(text: str) -> list[str]:
stop_words = set(stopwords.words("english"))
word_tokens = word_tokenize(text)
text_trimmed = [
word
for word in word_tokens
if (word.casefold() not in stop_words and word not in string.punctuation)
]
return text_trimmed or word_tokens
def query_processing(
query: str,
) -> str:
query = " ".join(remove_stop_words_and_punctuation(query))
query = " ".join(lemmatize_text(query))
return query
def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
search_docs = (
[
SearchDoc(
document_id=chunk.document_id,
chunk_ind=chunk.chunk_id,
semantic_identifier=chunk.semantic_identifier or "Unknown",
link=chunk.source_links.get(0) if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,
boost=chunk.boost,
hidden=chunk.hidden,
metadata=chunk.metadata,
score=chunk.score,
match_highlights=chunk.match_highlights,
updated_at=chunk.updated_at,
primary_owners=chunk.primary_owners,
secondary_owners=chunk.secondary_owners,
)
for chunk in chunks
]
if chunks
else []
)
return search_docs
def combine_retrieval_results(
chunk_sets: list[list[InferenceChunk]],
) -> list[InferenceChunk]:
all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set]
unique_chunks: dict[tuple[str, int], InferenceChunk] = {}
for chunk in all_chunks:
key = (chunk.document_id, chunk.chunk_id)
if key not in unique_chunks:
unique_chunks[key] = chunk
continue
stored_chunk_score = unique_chunks[key].score or 0
this_chunk_score = chunk.score or 0
if stored_chunk_score < this_chunk_score:
unique_chunks[key] = chunk
sorted_chunks = sorted(
unique_chunks.values(), key=lambda x: x.score or 0, reverse=True
)
return sorted_chunks
@log_function_time(print_only=True)
def doc_index_retrieval(
query: SearchQuery,
document_index: DocumentIndex,
db_session: Session,
hybrid_alpha: float = HYBRID_ALPHA,
) -> list[InferenceChunk]:
if query.search_type == SearchType.KEYWORD:
top_chunks = document_index.keyword_retrieval(
query=query.query,
filters=query.filters,
time_decay_multiplier=query.recency_bias_multiplier,
num_to_retrieve=query.num_hits,
)
else:
db_embedding_model = get_current_db_embedding_model(db_session)
model = EmbeddingModel(
model_name=db_embedding_model.model_name,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
normalize=db_embedding_model.normalize,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]
if query.search_type == SearchType.SEMANTIC:
top_chunks = document_index.semantic_retrieval(
query=query.query,
query_embedding=query_embedding,
filters=query.filters,
time_decay_multiplier=query.recency_bias_multiplier,
num_to_retrieve=query.num_hits,
)
elif query.search_type == SearchType.HYBRID:
top_chunks = document_index.hybrid_retrieval(
query=query.query,
query_embedding=query_embedding,
filters=query.filters,
time_decay_multiplier=query.recency_bias_multiplier,
num_to_retrieve=query.num_hits,
offset=query.offset,
hybrid_alpha=hybrid_alpha,
)
else:
raise RuntimeError("Invalid Search Flow")
return top_chunks
@log_function_time(print_only=True)
def semantic_reranking(
query: str,
chunks: list[InferenceChunk],
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
model_min: int = CROSS_ENCODER_RANGE_MIN,
model_max: int = CROSS_ENCODER_RANGE_MAX,
) -> tuple[list[InferenceChunk], list[int]]:
"""Reranks chunks based on cross-encoder models. Additionally provides the original indices
of the chunks in their new sorted order.
Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
"""
cross_encoders = CrossEncoderEnsembleModel()
passages = [chunk.content for chunk in chunks]
sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores))
cross_models_min = numpy.min(sim_scores)
shifted_sim_scores = sum(
[enc_n_scores - cross_models_min for enc_n_scores in sim_scores]
) / len(sim_scores)
boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
recency_multiplier = [chunk.recency_bias for chunk in chunks]
boosted_sim_scores = shifted_sim_scores * boosts * recency_multiplier
normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / (
model_max - model_min
)
orig_indices = [i for i in range(len(normalized_b_s_scores))]
scored_results = list(
zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices)
)
scored_results.sort(key=lambda x: x[0], reverse=True)
ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip(
*scored_results
)
logger.debug(
f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}"
)
# Assign new chunk scores based on reranking
for ind, chunk in enumerate(ranked_chunks):
chunk.score = ranked_sim_scores[ind]
if rerank_metrics_callback is not None:
chunk_metrics = [
ChunkMetric(
document_id=chunk.document_id,
chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
first_link=chunk.source_links[0] if chunk.source_links else None,
score=chunk.score if chunk.score is not None else 0,
)
for chunk in ranked_chunks
]
rerank_metrics_callback(
RerankMetricsContainer(
metrics=chunk_metrics, raw_similarity_scores=ranked_raw_scores # type: ignore
)
)
return list(ranked_chunks), list(ranked_indices)
def apply_boost_legacy(
chunks: list[InferenceChunk],
norm_min: float = SIM_SCORE_RANGE_LOW,
norm_max: float = SIM_SCORE_RANGE_HIGH,
) -> list[InferenceChunk]:
scores = [chunk.score or 0 for chunk in chunks]
boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
logger.debug(f"Raw similarity scores: {scores}")
score_min = min(scores)
score_max = max(scores)
score_range = score_max - score_min
if score_range != 0:
boosted_scores = [
((score - score_min) / score_range) * boost
for score, boost in zip(scores, boosts)
]
unnormed_boosted_scores = [
score * score_range + score_min for score in boosted_scores
]
else:
unnormed_boosted_scores = [
score * boost for score, boost in zip(scores, boosts)
]
norm_min = min(norm_min, min(scores))
norm_max = max(norm_max, max(scores))
# This should never be 0 unless user has done some weird/wrong settings
norm_range = norm_max - norm_min
# For score display purposes
if norm_range != 0:
re_normed_scores = [
((score - norm_min) / norm_range) for score in unnormed_boosted_scores
]
else:
re_normed_scores = unnormed_boosted_scores
rescored_chunks = list(zip(re_normed_scores, chunks))
rescored_chunks.sort(key=lambda x: x[0], reverse=True)
sorted_boosted_scores, boost_sorted_chunks = zip(*rescored_chunks)
final_chunks = list(boost_sorted_chunks)
final_scores = list(sorted_boosted_scores)
for ind, chunk in enumerate(final_chunks):
chunk.score = final_scores[ind]
logger.debug(f"Boost sorted similary scores: {list(final_scores)}")
return final_chunks
def apply_boost(
chunks: list[InferenceChunk],
# Need the range of values to not be too spread out for applying boost
# therefore norm across only the top few results
norm_cutoff: int = NUM_RERANKED_RESULTS,
norm_min: float = SIM_SCORE_RANGE_LOW,
norm_max: float = SIM_SCORE_RANGE_HIGH,
) -> list[InferenceChunk]:
scores = [chunk.score or 0.0 for chunk in chunks]
logger.debug(f"Raw similarity scores: {scores}")
boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
recency_multiplier = [chunk.recency_bias for chunk in chunks]
norm_min = min(norm_min, min(scores[:norm_cutoff]))
norm_max = max(norm_max, max(scores[:norm_cutoff]))
# This should never be 0 unless user has done some weird/wrong settings
norm_range = norm_max - norm_min
boosted_scores = [
max(0, (score - norm_min) * boost * recency / norm_range)
for score, boost, recency in zip(scores, boosts, recency_multiplier)
]
rescored_chunks = list(zip(boosted_scores, chunks))
rescored_chunks.sort(key=lambda x: x[0], reverse=True)
sorted_boosted_scores, boost_sorted_chunks = zip(*rescored_chunks)
final_chunks = list(boost_sorted_chunks)
final_scores = list(sorted_boosted_scores)
for ind, chunk in enumerate(final_chunks):
chunk.score = final_scores[ind]
logger.debug(
f"Boosted + Time Weighted sorted similarity scores: {list(final_scores)}"
)
return final_chunks
def _simplify_text(text: str) -> str:
return "".join(
char for char in text if char not in string.punctuation and not char.isspace()
).lower()
def retrieve_chunks(
query: SearchQuery,
document_index: DocumentIndex,
db_session: Session,
hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search
multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
) -> list[InferenceChunk]:
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
# Don't do query expansion on complex queries, rephrasings likely would not work well
if not multilingual_expansion_str or "\n" in query.query or "\r" in query.query:
top_chunks = doc_index_retrieval(
query=query,
document_index=document_index,
db_session=db_session,
hybrid_alpha=hybrid_alpha,
)
else:
simplified_queries = set()
run_queries: list[tuple[Callable, tuple]] = []
# Currently only uses query expansion on multilingual use cases
query_rephrases = multilingual_query_expansion(
query.query, multilingual_expansion_str
)
# Just to be extra sure, add the original query.
query_rephrases.append(query.query)
for rephrase in set(query_rephrases):
# Sometimes the model rephrases the query in the same language with minor changes
# Avoid doing an extra search with the minor changes as this biases the results
simplified_rephrase = _simplify_text(rephrase)
if simplified_rephrase in simplified_queries:
continue
simplified_queries.add(simplified_rephrase)
q_copy = query.copy(update={"query": rephrase}, deep=True)
run_queries.append(
(
doc_index_retrieval,
(q_copy, document_index, db_session, hybrid_alpha),
)
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
top_chunks = combine_retrieval_results(parallel_search_results)
if not top_chunks:
logger.info(
f"{query.search_type.value.capitalize()} search returned no results "
f"with filters: {query.filters}"
)
return []
if retrieval_metrics_callback is not None:
chunk_metrics = [
ChunkMetric(
document_id=chunk.document_id,
chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
first_link=chunk.source_links[0] if chunk.source_links else None,
score=chunk.score if chunk.score is not None else 0,
)
for chunk in top_chunks
]
retrieval_metrics_callback(
RetrievalMetricsContainer(
search_type=query.search_type, metrics=chunk_metrics
)
)
return top_chunks
def should_rerank(query: SearchQuery) -> bool:
# Don't re-rank for keyword search
return query.search_type != SearchType.KEYWORD and not query.skip_rerank
def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool:
return not query.skip_llm_chunk_filter
def rerank_chunks(
query: SearchQuery,
chunks_to_rerank: list[InferenceChunk],
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> list[InferenceChunk]:
ranked_chunks, _ = semantic_reranking(
query=query.query,
chunks=chunks_to_rerank[: query.num_rerank],
rerank_metrics_callback=rerank_metrics_callback,
)
lower_chunks = chunks_to_rerank[query.num_rerank :]
# Scores from rerank cannot be meaningfully combined with scores without rerank
for lower_chunk in lower_chunks:
lower_chunk.score = None
ranked_chunks.extend(lower_chunks)
return ranked_chunks
@log_function_time(print_only=True)
def filter_chunks(
query: SearchQuery,
chunks_to_filter: list[InferenceChunk],
) -> list[str]:
"""Filters chunks based on whether the LLM thought they were relevant to the query.
Returns a list of the unique chunk IDs that were marked as relevant"""
chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks]
llm_chunk_selection = llm_batch_eval_chunks(
query=query.query,
chunk_contents=[chunk.content for chunk in chunks_to_filter],
)
return [
chunk.unique_id
for ind, chunk in enumerate(chunks_to_filter)
if llm_chunk_selection[ind]
]


def full_chunk_search(
    query: SearchQuery,
    document_index: DocumentIndex,
    db_session: Session,
    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
    multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk], list[bool]]:
    """A utility which provides an easier interface than `full_chunk_search_generator`.
    Rather than returning the chunks and the LLM relevance filter results in two
    separate yields, it returns them both at once."""
    search_generator = full_chunk_search_generator(
        search_query=query,
        document_index=document_index,
        db_session=db_session,
        hybrid_alpha=hybrid_alpha,
        multilingual_expansion_str=multilingual_expansion_str,
        retrieval_metrics_callback=retrieval_metrics_callback,
        rerank_metrics_callback=rerank_metrics_callback,
    )
    top_chunks = cast(list[InferenceChunk], next(search_generator))
    llm_chunk_selection = cast(list[bool], next(search_generator))
    return top_chunks, llm_chunk_selection
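

# Illustrative usage sketch (not part of this commit): non-streaming callers can combine
# the two return values directly. This helper only uses names already imported by this
# module; the caller supplies the query, index, and session objects.
def _example_search_and_filter(
    search_query: SearchQuery,
    document_index: DocumentIndex,
    db_session: Session,
) -> list[InferenceChunk]:
    # Run retrieval, reranking, and the LLM relevance filter in one blocking call
    top_chunks, llm_chunk_selection = full_chunk_search(
        query=search_query,
        document_index=document_index,
        db_session=db_session,
    )
    # Keep only the chunks the LLM judged relevant
    return [chunk for chunk, keep in zip(top_chunks, llm_chunk_selection) if keep]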


def empty_search_generator() -> Iterator[list[InferenceChunk] | list[bool]]:
    yield cast(list[InferenceChunk], [])
    yield cast(list[bool], [])


def full_chunk_search_generator(
    search_query: SearchQuery,
    document_index: DocumentIndex,
    db_session: Session,
    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
    multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> Iterator[list[InferenceChunk] | list[bool]]:
    """Always yields exactly twice: first the selected chunks, then the LLM relevance
    filter results. If the LLM filter is turned off, the second yield is a list of
    False values (one per chunk)."""
    chunks_yielded = False

    retrieved_chunks = retrieve_chunks(
        query=search_query,
        document_index=document_index,
        db_session=db_session,
        hybrid_alpha=hybrid_alpha,
        multilingual_expansion_str=multilingual_expansion_str,
        retrieval_metrics_callback=retrieval_metrics_callback,
    )

    if not retrieved_chunks:
        yield cast(list[InferenceChunk], [])
        yield cast(list[bool], [])
        return

    post_processing_tasks: list[FunctionCall] = []

    rerank_task_id = None
    if should_rerank(search_query):
        post_processing_tasks.append(
            FunctionCall(
                rerank_chunks,
                (
                    search_query,
                    retrieved_chunks,
                    rerank_metrics_callback,
                ),
            )
        )
        rerank_task_id = post_processing_tasks[-1].result_id
    else:
        final_chunks = retrieved_chunks
        # NOTE: if we don't rerank, we can return the chunks immediately
        # since we know this is the final order
        _log_top_chunk_links(search_query.search_type.value, final_chunks)
        yield final_chunks
        chunks_yielded = True

    llm_filter_task_id = None
    if should_apply_llm_based_relevance_filter(search_query):
        post_processing_tasks.append(
            FunctionCall(
                filter_chunks,
                (search_query, retrieved_chunks[: search_query.max_llm_filter_chunks]),
            )
        )
        llm_filter_task_id = post_processing_tasks[-1].result_id

    post_processing_results = (
        run_functions_in_parallel(post_processing_tasks)
        if post_processing_tasks
        else {}
    )
    reranked_chunks = cast(
        list[InferenceChunk] | None,
        post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None,
    )
    if reranked_chunks:
        if chunks_yielded:
            logger.error(
                "Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen."
            )
        else:
            _log_top_chunk_links(search_query.search_type.value, reranked_chunks)
            yield reranked_chunks

    llm_chunk_selection = cast(
        list[str] | None,
        post_processing_results.get(str(llm_filter_task_id))
        if llm_filter_task_id
        else None,
    )
    if llm_chunk_selection is not None:
        yield [
            chunk.unique_id in llm_chunk_selection
            for chunk in reranked_chunks or retrieved_chunks
        ]
    else:
        yield [False for _ in reranked_chunks or retrieved_chunks]
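

# Illustrative usage sketch (not part of this commit): the two-yield contract exists so
# callers can surface the retrieved chunks immediately (e.g. to a UI) while the LLM
# relevance filter is still running. The `on_chunks_ready` callback is a placeholder
# for whatever the caller does with the early results.
def _example_stream_search(
    search_query: SearchQuery,
    document_index: DocumentIndex,
    db_session: Session,
    on_chunks_ready: Callable[[list[InferenceChunk]], None],
) -> list[bool]:
    search_generator = full_chunk_search_generator(
        search_query=search_query,
        document_index=document_index,
        db_session=db_session,
    )
    # First yield: the (possibly reranked) chunks, available before LLM filtering finishes
    top_chunks = cast(list[InferenceChunk], next(search_generator))
    on_chunks_ready(top_chunks)
    # Second yield: one boolean per chunk (all False if the LLM filter is disabled)
    return cast(list[bool], next(search_generator))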


def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc:
    if not inf_chunks:
        raise ValueError("Cannot combine empty list of chunks")

    # Use the first link of the document
    first_chunk = inf_chunks[0]
    chunk_texts = [chunk.content for chunk in inf_chunks]
    return LlmDoc(
        document_id=first_chunk.document_id,
        content="\n".join(chunk_texts),
        semantic_identifier=first_chunk.semantic_identifier,
        source_type=first_chunk.source_type,
        metadata=first_chunk.metadata,
        updated_at=first_chunk.updated_at,
        link=first_chunk.source_links[0] if first_chunk.source_links else None,
    )


def inference_documents_from_ids(
    doc_identifiers: list[tuple[str, int]],
    document_index: DocumentIndex,
) -> list[LlmDoc]:
    # Currently only fetches whole docs
    doc_ids_set = set(doc_id for doc_id, chunk_id in doc_identifiers)

    # No need for ACL here because the doc ids were validated beforehand
    filters = IndexFilters(access_control_list=None)

    functions_with_args: list[tuple[Callable, tuple]] = [
        (document_index.id_based_retrieval, (doc_id, None, filters))
        for doc_id in doc_ids_set
    ]

    parallel_results = run_functions_tuples_in_parallel(
        functions_with_args, allow_failures=True
    )

    # Any failed retrieval yields None; drop the Nones and the empty lists
    inference_chunks_sets = [res for res in parallel_results if res]

    return [combine_inference_chunks(chunk_set) for chunk_set in inference_chunks_sets]

View File

@ -0,0 +1,29 @@
from danswer.indexing.models import InferenceChunk
from danswer.search.models import SearchDoc


def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
    search_docs = (
        [
            SearchDoc(
                document_id=chunk.document_id,
                chunk_ind=chunk.chunk_id,
                semantic_identifier=chunk.semantic_identifier or "Unknown",
                link=chunk.source_links.get(0) if chunk.source_links else None,
                blurb=chunk.blurb,
                source_type=chunk.source_type,
                boost=chunk.boost,
                hidden=chunk.hidden,
                metadata=chunk.metadata,
                score=chunk.score,
                match_highlights=chunk.match_highlights,
                updated_at=chunk.updated_at,
                primary_owners=chunk.primary_owners,
                secondary_owners=chunk.secondary_owners,
            )
            for chunk in chunks
        ]
        if chunks
        else []
    )
    return search_docs

View File

@ -11,8 +11,8 @@ from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.llm.utils import get_default_llm_token_encode
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.models import IndexFilters
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.server.documents.models import ChunkInfo
from danswer.server.documents.models import DocumentInfo

View File

@ -4,7 +4,7 @@ from pydantic import BaseModel
from danswer.db.models import Persona
from danswer.db.models import StarterMessage
from danswer.search.models import RecencyBiasSetting
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.document_set.models import DocumentSet
from danswer.server.features.prompt.models import PromptSnapshot

View File

@ -6,13 +6,9 @@ from fastapi import Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.document_index.factory import get_default_document_index
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.models import IndexFilters
from danswer.search.models import SearchQuery
from danswer.search.search_runner import full_chunk_search
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
from danswer.server.danswer_api.ingestion import api_key_dep
from danswer.utils.logger import setup_logger
@ -70,27 +66,13 @@ def gpt_search(
    _: str | None = Depends(api_key_dep),
    db_session: Session = Depends(get_session),
) -> GptSearchResponse:
    query = search_request.query
    user_acl_filters = build_access_filters_for_user(None, db_session)
    final_filters = IndexFilters(access_control_list=user_acl_filters)
    search_query = SearchQuery(
        query=query,
        filters=final_filters,
        recency_bias_multiplier=1.0,
        skip_llm_chunk_filter=True,
    )
    embedding_model = get_current_db_embedding_model(db_session)
    document_index = get_default_document_index(
        primary_index_name=embedding_model.index_name, secondary_index_name=None
    )
    top_chunks, __ = full_chunk_search(
        query=search_query, document_index=document_index, db_session=db_session
    )
    top_chunks = SearchPipeline(
        search_request=SearchRequest(
            query=search_request.query,
        ),
        user=None,
        db_session=db_session,
    ).reranked_docs
    return GptSearchResponse(
        matching_document_chunks=[
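
# Illustrative sketch (not part of this commit): the call sites in this commit only
# depend on a small surface of SearchPipeline -- construct it from a SearchRequest plus
# user and db_session, then read lazily computed properties. A Protocol capturing just
# that assumed surface is shown below; the return types are inferred from how the
# values are used, and the real class in danswer/search/pipeline.py may expose more.
from typing import Protocol

from danswer.indexing.models import InferenceChunk


class SearchPipelineLike(Protocol):
    @property
    def reranked_docs(self) -> list[InferenceChunk]: ...

    @property
    def chunk_relevance_list(self) -> list[bool]: ...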

View File

@ -15,11 +15,11 @@ from danswer.document_index.factory import get_default_document_index
from danswer.document_index.vespa.index import VespaIndex
from danswer.one_shot_answer.answer_question import stream_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.models import IndexFilters
from danswer.search.models import SearchDoc
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.search.preprocessing.danswer_helper import recommend_search_flow
from danswer.search.utils import chunks_to_search_docs
from danswer.secondary_llm_flows.query_validation import get_query_answerability
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
from danswer.server.query_and_chat.models import AdminSearchRequest

View File

@ -8,15 +8,12 @@ from typing import TextIO
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import get_chunks_for_qa
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.search.models import IndexFilters
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.search_runner import full_chunk_search
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
from danswer.utils.callbacks import MetricsHander
@ -81,35 +78,22 @@ def get_search_results(
    RetrievalMetricsContainer | None,
    RerankMetricsContainer | None,
]:
    filters = IndexFilters(
        source_type=None,
        document_set=None,
        time_cutoff=None,
        access_control_list=None,
    )
    search_query = SearchQuery(
        query=query,
        filters=filters,
        recency_bias_multiplier=1.0,
    )
    retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
    rerank_metrics = MetricsHander[RerankMetricsContainer]()
    with Session(get_sqlalchemy_engine()) as db_session:
        embedding_model = get_current_db_embedding_model(db_session)
        search_pipeline = SearchPipeline(
            search_request=SearchRequest(
                query=query,
            ),
            user=None,
            db_session=db_session,
            retrieval_metrics_callback=retrieval_metrics.record_metric,
            rerank_metrics_callback=rerank_metrics.record_metric,
        )
        document_index = get_default_document_index(
            primary_index_name=embedding_model.index_name, secondary_index_name=None
        )
        top_chunks, llm_chunk_selection = full_chunk_search(
            query=search_query,
            document_index=document_index,
            db_session=db_session,
            retrieval_metrics_callback=retrieval_metrics.record_metric,
            rerank_metrics_callback=rerank_metrics.record_metric,
        )
        top_chunks = search_pipeline.reranked_docs
        llm_chunk_selection = search_pipeline.chunk_relevance_list

    llm_chunks_indices = get_chunks_for_qa(
        chunks=top_chunks,