Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-10 12:59:59 +02:00)

Refactor search pipeline
commit 1ba74ee4df, parent 7a861ecec4
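At a high level, this commit replaces the old two-step call pattern (retrieval_preprocessing() followed by full_chunk_search_generator()) with a single SearchPipeline object driven by a SearchRequest. A minimal sketch of the new call pattern, based on the call sites changed below (the query string is invented and db_session is assumed to be an open SQLAlchemy Session supplied by the caller):

from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline

search_pipeline = SearchPipeline(
    search_request=SearchRequest(query="how do I add a connector?"),
    user=None,  # the real call sites pass the authenticated User
    db_session=db_session,
)
top_chunks = search_pipeline.reranked_docs                   # preprocessing + retrieval + rerank
relevant_indices = search_pipeline.relevant_chunk_indicies   # LLM relevance filter results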
@@ -9,7 +9,7 @@ from alembic import op
 import sqlalchemy as sa

 from danswer.db.models import IndexModelStatus
-from danswer.search.models import RecencyBiasSetting
+from danswer.search.enums import RecencyBiasSetting
 from danswer.search.models import SearchType

 # revision identifiers, used by Alembic.
@@ -13,7 +13,7 @@ from danswer.db.document_set import get_or_create_document_set_by_name
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.models import DocumentSet as DocumentSetDBModel
 from danswer.db.models import Prompt as PromptDBModel
-from danswer.search.models import RecencyBiasSetting
+from danswer.search.enums import RecencyBiasSetting


 def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
@@ -5,10 +5,10 @@ from typing import Any
 from pydantic import BaseModel

 from danswer.configs.constants import DocumentSource
-from danswer.search.models import QueryFlow
+from danswer.search.enums import QueryFlow
+from danswer.search.enums import SearchType
 from danswer.search.models import RetrievalDocs
 from danswer.search.models import SearchResponse
-from danswer.search.models import SearchType


 class LlmDoc(BaseModel):
@@ -53,11 +53,10 @@ from danswer.llm.utils import tokenizer_trim_content
 from danswer.llm.utils import translate_history_to_basemessages
 from danswer.prompts.prompt_utils import build_doc_context_str
-from danswer.search.models import OptionalSearchSetting
 from danswer.search.models import RetrievalDetails
-from danswer.search.request_preprocessing import retrieval_preprocessing
-from danswer.search.search_runner import chunks_to_search_docs
-from danswer.search.search_runner import full_chunk_search_generator
-from danswer.search.search_runner import inference_documents_from_ids
+from danswer.search.models import SearchRequest
+from danswer.search.pipeline import SearchPipeline
+from danswer.search.retrieval.search_runner import inference_documents_from_ids
+from danswer.search.utils import chunks_to_search_docs
 from danswer.secondary_llm_flows.choose_search import check_if_need_search
 from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
 from danswer.server.query_and_chat.models import ChatMessageDetail
@@ -377,37 +376,25 @@ def stream_chat_message_objects(
         else query_override
     )

-    (
-        retrieval_request,
-        predicted_search_type,
-        predicted_flow,
-    ) = retrieval_preprocessing(
-        query=rephrased_query,
-        retrieval_details=cast(RetrievalDetails, retrieval_options),
-        persona=persona,
+    search_pipeline = SearchPipeline(
+        search_request=SearchRequest(
+            query=rephrased_query,
+            human_selected_filters=retrieval_options.filters
+            if retrieval_options
+            else None,
+            persona=persona,
+            offset=retrieval_options.offset if retrieval_options else None,
+            limit=retrieval_options.limit if retrieval_options else None,
+        ),
         user=user,
         db_session=db_session,
     )

-    documents_generator = full_chunk_search_generator(
-        search_query=retrieval_request,
-        document_index=document_index,
-        db_session=db_session,
-    )
-    time_cutoff = retrieval_request.filters.time_cutoff
-    recency_bias_multiplier = retrieval_request.recency_bias_multiplier
-    run_llm_chunk_filter = not retrieval_request.skip_llm_chunk_filter
-
     # First fetch and return the top chunks to the UI so the user can
     # immediately see some results
-    top_chunks = cast(list[InferenceChunk], next(documents_generator))
-    top_docs = chunks_to_search_docs(top_chunks)
-
-    # Get ranking of the documents for citation purposes later
-    doc_id_to_rank_map = map_document_id_order(
-        cast(list[InferenceChunk | LlmDoc], top_chunks)
-    )
+    top_chunks = search_pipeline.reranked_docs
+
+    top_docs = chunks_to_search_docs(top_chunks)
+    doc_id_to_rank_map = map_document_id_order(top_chunks)

     reference_db_search_docs = [
         create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
@@ -422,24 +409,17 @@ def stream_chat_message_objects(
     initial_response = QADocsResponse(
         rephrased_query=rephrased_query,
         top_documents=response_docs,
-        predicted_flow=predicted_flow,
-        predicted_search=predicted_search_type,
-        applied_source_filters=retrieval_request.filters.source_type,
-        applied_time_cutoff=time_cutoff,
-        recency_bias_multiplier=recency_bias_multiplier,
+        predicted_flow=search_pipeline.predicted_flow,
+        predicted_search=search_pipeline.predicted_search_type,
+        applied_source_filters=search_pipeline.search_query.filters.source_type,
+        applied_time_cutoff=search_pipeline.search_query.filters.time_cutoff,
+        recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
     )
     yield initial_response

-    # Get the final ordering of chunks for the LLM call
-    llm_chunk_selection = cast(list[bool], next(documents_generator))
-
     # Yield the list of LLM selected chunks for showing the LLM selected icons in the UI
     llm_relevance_filtering_response = LLMRelevanceFilterResponse(
-        relevant_chunk_indices=[
-            index for index, value in enumerate(llm_chunk_selection) if value
-        ]
-        if run_llm_chunk_filter
-        else []
+        relevant_chunk_indices=search_pipeline.relevant_chunk_indicies
    )
    yield llm_relevance_filtering_response
@@ -467,7 +447,7 @@ def stream_chat_message_objects(
     )
     llm_chunks_indices = get_chunks_for_qa(
         chunks=top_chunks,
-        llm_chunk_selection=llm_chunk_selection,
+        llm_chunk_selection=search_pipeline.chunk_relevance_list,
         token_limit=chunk_token_limit,
         llm_tokenizer=llm_tokenizer,
     )
@@ -27,7 +27,7 @@ from danswer.db.models import SearchDoc
 from danswer.db.models import SearchDoc as DBSearchDoc
 from danswer.db.models import StarterMessage
 from danswer.db.models import User__UserGroup
-from danswer.search.models import RecencyBiasSetting
+from danswer.search.enums import RecencyBiasSetting
 from danswer.search.models import RetrievalDocs
 from danswer.search.models import SavedSearchDoc
 from danswer.search.models import SearchDoc as ServerSearchDoc
@@ -36,8 +36,8 @@ from danswer.configs.constants import MessageType
 from danswer.configs.constants import SearchFeedbackType
 from danswer.connectors.models import InputType
 from danswer.dynamic_configs.interface import JSON_ro
-from danswer.search.models import RecencyBiasSetting
-from danswer.search.models import SearchType
+from danswer.search.enums import RecencyBiasSetting
+from danswer.search.enums import SearchType


 class IndexingStatus(str, PyEnum):
@@ -12,7 +12,7 @@ from danswer.db.models import Persona
 from danswer.db.models import Persona__DocumentSet
 from danswer.db.models import SlackBotConfig
 from danswer.db.models import SlackBotResponseType
-from danswer.search.models import RecencyBiasSetting
+from danswer.search.enums import RecencyBiasSetting


 def _build_persona_name(channel_names: list[str]) -> str:
@@ -64,8 +64,8 @@ from danswer.document_index.vespa.utils import remove_invalid_unicode_chars
 from danswer.indexing.models import DocMetadataAwareIndexChunk
 from danswer.indexing.models import InferenceChunk
 from danswer.search.models import IndexFilters
-from danswer.search.search_runner import query_processing
-from danswer.search.search_runner import remove_stop_words_and_punctuation
+from danswer.search.retrieval.search_runner import query_processing
+from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
 from danswer.utils.batching import batch_generator
 from danswer.utils.logger import setup_logger
 from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -1,7 +1,6 @@
import itertools
from collections.abc import Callable
from collections.abc import Iterator
from typing import cast

from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
|
||||
from danswer.db.chat import get_persona_by_id
|
||||
from danswer.db.chat import get_prompt_by_id
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import User
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.utils import get_default_llm_token_encode
|
||||
@@ -55,9 +52,9 @@ from danswer.prompts.prompt_utils import build_task_prompt_reminders
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
 from danswer.search.models import SavedSearchDoc
-from danswer.search.request_preprocessing import retrieval_preprocessing
-from danswer.search.search_runner import chunks_to_search_docs
-from danswer.search.search_runner import full_chunk_search_generator
+from danswer.search.models import SearchRequest
+from danswer.search.pipeline import SearchPipeline
+from danswer.search.utils import chunks_to_search_docs
 from danswer.secondary_llm_flows.answer_validation import get_answer_validity
 from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
 from danswer.server.query_and_chat.models import ChatMessageDetail
@@ -221,12 +218,6 @@ def stream_answer_objects(

     llm_tokenizer = get_default_llm_token_encode()

-    embedding_model = get_current_db_embedding_model(db_session)
-
-    document_index = get_default_document_index(
-        primary_index_name=embedding_model.index_name, secondary_index_name=None
-    )
-
     # Create a chat session which will just store the root message, the query, and the AI response
     root_message = get_or_create_root_message(
         chat_session_id=chat_session.id, db_session=db_session
@@ -244,33 +235,23 @@ def stream_answer_objects(
     # In chat flow it's given back along with the documents
     yield QueryRephrase(rephrased_query=rephrased_query)

-    (
-        retrieval_request,
-        predicted_search_type,
-        predicted_flow,
-    ) = retrieval_preprocessing(
-        query=rephrased_query,
-        retrieval_details=query_req.retrieval_options,
-        persona=chat_session.persona,
+    search_pipeline = SearchPipeline(
+        search_request=SearchRequest(
+            query=rephrased_query,
+            human_selected_filters=query_req.retrieval_options.filters,
+            persona=chat_session.persona,
+            offset=query_req.retrieval_options.offset,
+            limit=query_req.retrieval_options.limit,
+        ),
         user=user,
         db_session=db_session,
         bypass_acl=bypass_acl,
-    )
-
-    documents_generator = full_chunk_search_generator(
-        search_query=retrieval_request,
-        document_index=document_index,
-        db_session=db_session,
         retrieval_metrics_callback=retrieval_metrics_callback,
         rerank_metrics_callback=rerank_metrics_callback,
     )
-    applied_time_cutoff = retrieval_request.filters.time_cutoff
-    recency_bias_multiplier = retrieval_request.recency_bias_multiplier
-    run_llm_chunk_filter = not retrieval_request.skip_llm_chunk_filter

     # First fetch and return the top chunks so the user can immediately see some results
-    top_chunks = cast(list[InferenceChunk], next(documents_generator))
-
+    top_chunks = search_pipeline.reranked_docs
     top_docs = chunks_to_search_docs(top_chunks)
     fake_saved_docs = [SavedSearchDoc.from_search_doc(doc) for doc in top_docs]
@@ -278,24 +259,17 @@ def stream_answer_objects(
     initial_response = QADocsResponse(
         rephrased_query=rephrased_query,
         top_documents=fake_saved_docs,
-        predicted_flow=predicted_flow,
-        predicted_search=predicted_search_type,
-        applied_source_filters=retrieval_request.filters.source_type,
-        applied_time_cutoff=applied_time_cutoff,
-        recency_bias_multiplier=recency_bias_multiplier,
+        predicted_flow=search_pipeline.predicted_flow,
+        predicted_search=search_pipeline.predicted_search_type,
+        applied_source_filters=search_pipeline.search_query.filters.source_type,
+        applied_time_cutoff=search_pipeline.search_query.filters.time_cutoff,
+        recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
     )
     yield initial_response

-    # Get the final ordering of chunks for the LLM call
-    llm_chunk_selection = cast(list[bool], next(documents_generator))
-
     # Yield the list of LLM selected chunks for showing the LLM selected icons in the UI
     llm_relevance_filtering_response = LLMRelevanceFilterResponse(
-        relevant_chunk_indices=[
-            index for index, value in enumerate(llm_chunk_selection) if value
-        ]
-        if run_llm_chunk_filter
-        else []
+        relevant_chunk_indices=search_pipeline.relevant_chunk_indicies
     )
     yield llm_relevance_filtering_response

@@ -317,7 +291,7 @@ def stream_answer_objects(

     llm_chunks_indices = get_chunks_for_qa(
         chunks=top_chunks,
-        llm_chunk_selection=llm_chunk_selection,
+        llm_chunk_selection=search_pipeline.chunk_relevance_list,
         token_limit=chunk_token_limit,
     )
     llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
backend/danswer/search/enums.py (new file, 30 lines)
@@ -0,0 +1,30 @@
"""NOTE: this needs to be separate from models.py because of circular imports.
Both search/models.py and db/models.py import enums from this file AND
search/models.py imports from db/models.py."""
from enum import Enum


class OptionalSearchSetting(str, Enum):
    ALWAYS = "always"
    NEVER = "never"
    # Determine whether to run search based on history and latest query
    AUTO = "auto"


class RecencyBiasSetting(str, Enum):
    FAVOR_RECENT = "favor_recent"  # 2x decay rate
    BASE_DECAY = "base_decay"
    NO_DECAY = "no_decay"
    # Determine based on query if to use base_decay or favor_recent
    AUTO = "auto"


class SearchType(str, Enum):
    KEYWORD = "keyword"
    SEMANTIC = "semantic"
    HYBRID = "hybrid"


class QueryFlow(str, Enum):
    SEARCH = "search"
    QUESTION_ANSWER = "question-answer"
@@ -1,46 +1,24 @@
 from datetime import datetime
-from enum import Enum
 from typing import Any

 from pydantic import BaseModel

 from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
+from danswer.configs.chat_configs import HYBRID_ALPHA
 from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
 from danswer.configs.chat_configs import NUM_RETURNED_HITS
 from danswer.configs.constants import DocumentSource
 from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
+from danswer.db.models import Persona
+from danswer.search.enums import OptionalSearchSetting
+from danswer.search.enums import SearchType


 MAX_METRICS_CONTENT = (
     200  # Just need enough characters to identify where in the doc the chunk is
 )


-class OptionalSearchSetting(str, Enum):
-    ALWAYS = "always"
-    NEVER = "never"
-    # Determine whether to run search based on history and latest query
-    AUTO = "auto"
-
-
-class RecencyBiasSetting(str, Enum):
-    FAVOR_RECENT = "favor_recent"  # 2x decay rate
-    BASE_DECAY = "base_decay"
-    NO_DECAY = "no_decay"
-    # Determine based on query if to use base_decay or favor_recent
-    AUTO = "auto"
-
-
-class SearchType(str, Enum):
-    KEYWORD = "keyword"
-    SEMANTIC = "semantic"
-    HYBRID = "hybrid"
-
-
-class QueryFlow(str, Enum):
-    SEARCH = "search"
-    QUESTION_ANSWER = "question-answer"
-
-
 class Tag(BaseModel):
     tag_key: str
     tag_value: str

@@ -64,6 +42,28 @@ class ChunkMetric(BaseModel):
     score: float


+class SearchRequest(BaseModel):
+    """Input to the SearchPipeline."""
+
+    query: str
+    search_type: SearchType = SearchType.HYBRID
+
+    human_selected_filters: BaseFilters | None = None
+    enable_auto_detect_filters: bool | None = None
+    persona: Persona | None = None
+
+    # if None, no offset / limit
+    offset: int | None = None
+    limit: int | None = None
+
+    recency_bias_multiplier: float = 1.0
+    hybrid_alpha: float = HYBRID_ALPHA
+    skip_rerank: bool = True
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
 class SearchQuery(BaseModel):
     query: str
     filters: IndexFilters
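For orientation, SearchRequest is an ordinary Pydantic model, so callers only have to supply the query and can rely on the defaults above for everything else. A small illustrative sketch (the document set name is invented):

from danswer.search.models import BaseFilters
from danswer.search.models import SearchRequest

# Only the query is required; search_type defaults to SearchType.HYBRID.
request = SearchRequest(query="quarterly planning notes")

# Hypothetical human-selected filter restricting the search to one document set.
filtered_request = SearchRequest(
    query="quarterly planning notes",
    human_selected_filters=BaseFilters(document_set=["Engineering Handbook"]),
    limit=20,
)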
backend/danswer/search/pipeline.py (new file, 152 lines)
@@ -0,0 +1,152 @@
from collections.abc import Callable
from typing import cast

from sqlalchemy.orm import Session

from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchRequest
from danswer.search.postprocessing.postprocessing import search_postprocessing
from danswer.search.preprocessing.preprocessing import retrieval_preprocessing
from danswer.search.retrieval.search_runner import retrieve_chunks


class SearchPipeline:
    def __init__(
        self,
        search_request: SearchRequest,
        user: User | None,
        db_session: Session,
        bypass_acl: bool = False,  # NOTE: VERY DANGEROUS, USE WITH CAUTION
        retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
        | None = None,
        rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
    ):
        self.search_request = search_request
        self.user = user
        self.db_session = db_session
        self.bypass_acl = bypass_acl
        self.retrieval_metrics_callback = retrieval_metrics_callback
        self.rerank_metrics_callback = rerank_metrics_callback

        self.embedding_model = get_current_db_embedding_model(db_session)
        self.document_index = get_default_document_index(
            primary_index_name=self.embedding_model.index_name,
            secondary_index_name=None,
        )

        self._search_query: SearchQuery | None = None
        self._predicted_search_type: SearchType | None = None
        self._predicted_flow: QueryFlow | None = None

        self._retrieved_docs: list[InferenceChunk] | None = None
        self._reranked_docs: list[InferenceChunk] | None = None
        self._relevant_chunk_indicies: list[int] | None = None

    """Pre-processing"""

    def _run_preprocessing(self) -> None:
        (
            final_search_query,
            predicted_search_type,
            predicted_flow,
        ) = retrieval_preprocessing(
            search_request=self.search_request,
            user=self.user,
            db_session=self.db_session,
            bypass_acl=self.bypass_acl,
        )
        self._predicted_search_type = predicted_search_type
        self._predicted_flow = predicted_flow
        self._search_query = final_search_query

    @property
    def search_query(self) -> SearchQuery:
        if self._search_query is not None:
            return self._search_query

        self._run_preprocessing()
        return cast(SearchQuery, self._search_query)

    @property
    def predicted_search_type(self) -> SearchType:
        if self._predicted_search_type is not None:
            return self._predicted_search_type

        self._run_preprocessing()
        return cast(SearchType, self._predicted_search_type)

    @property
    def predicted_flow(self) -> QueryFlow:
        if self._predicted_flow is not None:
            return self._predicted_flow

        self._run_preprocessing()
        return cast(QueryFlow, self._predicted_flow)

    """Retrieval"""

    @property
    def retrieved_docs(self) -> list[InferenceChunk]:
        if self._retrieved_docs is not None:
            return self._retrieved_docs

        self._retrieved_docs = retrieve_chunks(
            query=self.search_query,
            document_index=self.document_index,
            db_session=self.db_session,
            hybrid_alpha=self.search_request.hybrid_alpha,
            multilingual_expansion_str=MULTILINGUAL_QUERY_EXPANSION,
            retrieval_metrics_callback=self.retrieval_metrics_callback,
        )

        # self._retrieved_docs = chunks_to_search_docs(retrieved_chunks)
        return cast(list[InferenceChunk], self._retrieved_docs)

    """Post-Processing"""

    def _run_postprocessing(self) -> None:
        postprocessing_generator = search_postprocessing(
            search_query=self.search_query,
            retrieved_chunks=self.retrieved_docs,
            rerank_metrics_callback=self.rerank_metrics_callback,
        )
        self._reranked_docs = cast(list[InferenceChunk], next(postprocessing_generator))

        relevant_chunk_ids = cast(list[str], next(postprocessing_generator))
        self._relevant_chunk_indicies = [
            ind
            for ind, chunk in enumerate(self._reranked_docs)
            if chunk.unique_id in relevant_chunk_ids
        ]

    @property
    def reranked_docs(self) -> list[InferenceChunk]:
        if self._reranked_docs is not None:
            return self._reranked_docs

        self._run_postprocessing()
        return cast(list[InferenceChunk], self._reranked_docs)

    @property
    def relevant_chunk_indicies(self) -> list[int]:
        if self._relevant_chunk_indicies is not None:
            return self._relevant_chunk_indicies

        self._run_postprocessing()
        return cast(list[int], self._relevant_chunk_indicies)

    @property
    def chunk_relevance_list(self) -> list[bool]:
        return [
            True if ind in self.relevant_chunk_indicies else False
            for ind in range(len(self.reranked_docs))
        ]
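Each stage of the class above is memoized, so the first property access runs only as much of the pipeline as it needs. A rough sketch of that behavior, assuming a pipeline built as at the call sites earlier in this diff (db_session is an assumed open SQLAlchemy Session):

from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline

pipeline = SearchPipeline(
    search_request=SearchRequest(query="pto policy"),
    user=None,
    db_session=db_session,
)
flow = pipeline.predicted_flow        # runs retrieval_preprocessing() once and caches the result
chunks = pipeline.reranked_docs       # additionally runs retrieve_chunks() and search_postprocessing()
mask = pipeline.chunk_relevance_list  # derived from relevant_chunk_indicies; reuses cached results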
222
backend/danswer/search/postprocessing/postprocessing.py
Normal file
222
backend/danswer/search/postprocessing/postprocessing.py
Normal file
@ -0,0 +1,222 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
|
||||
import numpy
|
||||
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
|
||||
from danswer.document_index.document_index_utils import (
|
||||
translate_boost_count_to_multiplier,
|
||||
)
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.search.models import ChunkMetric
|
||||
from danswer.search.models import MAX_METRICS_CONTENT
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
|
||||
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import FunctionCall
|
||||
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
|
||||
top_links = [
|
||||
c.source_links[0] if c.source_links is not None else "No Link" for c in chunks
|
||||
]
|
||||
logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
|
||||
|
||||
|
||||
def should_rerank(query: SearchQuery) -> bool:
|
||||
# Don't re-rank for keyword search
|
||||
return query.search_type != SearchType.KEYWORD and not query.skip_rerank
|
||||
|
||||
|
||||
def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool:
|
||||
return not query.skip_llm_chunk_filter
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def semantic_reranking(
|
||||
query: str,
|
||||
chunks: list[InferenceChunk],
|
||||
model_min: int = CROSS_ENCODER_RANGE_MIN,
|
||||
model_max: int = CROSS_ENCODER_RANGE_MAX,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> tuple[list[InferenceChunk], list[int]]:
|
||||
"""Reranks chunks based on cross-encoder models. Additionally provides the original indices
|
||||
of the chunks in their new sorted order.
|
||||
|
||||
Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
|
||||
"""
|
||||
cross_encoders = CrossEncoderEnsembleModel()
|
||||
passages = [chunk.content for chunk in chunks]
|
||||
sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
|
||||
|
||||
sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
|
||||
|
||||
raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores))
|
||||
|
||||
cross_models_min = numpy.min(sim_scores)
|
||||
|
||||
shifted_sim_scores = sum(
|
||||
[enc_n_scores - cross_models_min for enc_n_scores in sim_scores]
|
||||
) / len(sim_scores)
|
||||
|
||||
boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
|
||||
recency_multiplier = [chunk.recency_bias for chunk in chunks]
|
||||
boosted_sim_scores = shifted_sim_scores * boosts * recency_multiplier
|
||||
normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / (
|
||||
model_max - model_min
|
||||
)
|
||||
orig_indices = [i for i in range(len(normalized_b_s_scores))]
|
||||
scored_results = list(
|
||||
zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices)
|
||||
)
|
||||
scored_results.sort(key=lambda x: x[0], reverse=True)
|
||||
ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip(
|
||||
*scored_results
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}"
|
||||
)
|
||||
|
||||
# Assign new chunk scores based on reranking
|
||||
for ind, chunk in enumerate(ranked_chunks):
|
||||
chunk.score = ranked_sim_scores[ind]
|
||||
|
||||
if rerank_metrics_callback is not None:
|
||||
chunk_metrics = [
|
||||
ChunkMetric(
|
||||
document_id=chunk.document_id,
|
||||
chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
|
||||
first_link=chunk.source_links[0] if chunk.source_links else None,
|
||||
score=chunk.score if chunk.score is not None else 0,
|
||||
)
|
||||
for chunk in ranked_chunks
|
||||
]
|
||||
|
||||
rerank_metrics_callback(
|
||||
RerankMetricsContainer(
|
||||
metrics=chunk_metrics, raw_similarity_scores=ranked_raw_scores # type: ignore
|
||||
)
|
||||
)
|
||||
|
||||
return list(ranked_chunks), list(ranked_indices)
|
||||
|
||||
|
||||
def rerank_chunks(
|
||||
query: SearchQuery,
|
||||
chunks_to_rerank: list[InferenceChunk],
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
ranked_chunks, _ = semantic_reranking(
|
||||
query=query.query,
|
||||
chunks=chunks_to_rerank[: query.num_rerank],
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
)
|
||||
lower_chunks = chunks_to_rerank[query.num_rerank :]
|
||||
# Scores from rerank cannot be meaningfully combined with scores without rerank
|
||||
for lower_chunk in lower_chunks:
|
||||
lower_chunk.score = None
|
||||
ranked_chunks.extend(lower_chunks)
|
||||
return ranked_chunks
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def filter_chunks(
|
||||
query: SearchQuery,
|
||||
chunks_to_filter: list[InferenceChunk],
|
||||
) -> list[str]:
|
||||
"""Filters chunks based on whether the LLM thought they were relevant to the query.
|
||||
|
||||
Returns a list of the unique chunk IDs that were marked as relevant"""
|
||||
chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks]
|
||||
llm_chunk_selection = llm_batch_eval_chunks(
|
||||
query=query.query,
|
||||
chunk_contents=[chunk.content for chunk in chunks_to_filter],
|
||||
)
|
||||
return [
|
||||
chunk.unique_id
|
||||
for ind, chunk in enumerate(chunks_to_filter)
|
||||
if llm_chunk_selection[ind]
|
||||
]
|
||||
|
||||
|
||||
def search_postprocessing(
|
||||
search_query: SearchQuery,
|
||||
retrieved_chunks: list[InferenceChunk],
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> Generator[list[InferenceChunk] | list[str], None, None]:
|
||||
post_processing_tasks: list[FunctionCall] = []
|
||||
|
||||
rerank_task_id = None
|
||||
if should_rerank(search_query):
|
||||
post_processing_tasks.append(
|
||||
FunctionCall(
|
||||
rerank_chunks,
|
||||
(
|
||||
search_query,
|
||||
retrieved_chunks,
|
||||
rerank_metrics_callback,
|
||||
),
|
||||
)
|
||||
)
|
||||
rerank_task_id = post_processing_tasks[-1].result_id
|
||||
else:
|
||||
final_chunks = retrieved_chunks
|
||||
# NOTE: if we don't rerank, we can return the chunks immediately
|
||||
# since we know this is the final order
|
||||
_log_top_chunk_links(search_query.search_type.value, final_chunks)
|
||||
yield final_chunks
|
||||
chunks_yielded = True
|
||||
|
||||
llm_filter_task_id = None
|
||||
if should_apply_llm_based_relevance_filter(search_query):
|
||||
post_processing_tasks.append(
|
||||
FunctionCall(
|
||||
filter_chunks,
|
||||
(search_query, retrieved_chunks[: search_query.max_llm_filter_chunks]),
|
||||
)
|
||||
)
|
||||
llm_filter_task_id = post_processing_tasks[-1].result_id
|
||||
|
||||
post_processing_results = (
|
||||
run_functions_in_parallel(post_processing_tasks)
|
||||
if post_processing_tasks
|
||||
else {}
|
||||
)
|
||||
reranked_chunks = cast(
|
||||
list[InferenceChunk] | None,
|
||||
post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None,
|
||||
)
|
||||
if reranked_chunks:
|
||||
if chunks_yielded:
|
||||
logger.error(
|
||||
"Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen."
|
||||
)
|
||||
else:
|
||||
_log_top_chunk_links(search_query.search_type.value, reranked_chunks)
|
||||
yield reranked_chunks
|
||||
|
||||
llm_chunk_selection = cast(
|
||||
list[str] | None,
|
||||
post_processing_results.get(str(llm_filter_task_id))
|
||||
if llm_filter_task_id
|
||||
else None,
|
||||
)
|
||||
if llm_chunk_selection is not None:
|
||||
yield [
|
||||
chunk.unique_id
|
||||
for chunk in reranked_chunks or retrieved_chunks
|
||||
if chunk.unique_id in llm_chunk_selection
|
||||
]
|
||||
else:
|
||||
yield []
|
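search_postprocessing above is written as a generator so a caller can surface the final chunk ordering before the slower LLM relevance filter finishes. A minimal consumption sketch, mirroring what SearchPipeline._run_postprocessing does (the search_query and retrieved_chunks variables are assumed to come from retrieval_preprocessing and retrieve_chunks):

from typing import cast

from danswer.indexing.models import InferenceChunk
from danswer.search.postprocessing.postprocessing import search_postprocessing

postprocessing_generator = search_postprocessing(
    search_query=search_query,
    retrieved_chunks=retrieved_chunks,
)
# First yield: the reranked (or, if reranking is skipped, original) chunk ordering.
reranked_chunks = cast(list[InferenceChunk], next(postprocessing_generator))
# Second yield: unique IDs of chunks the LLM judged relevant (empty if the filter is disabled).
relevant_chunk_ids = cast(list[str], next(postprocessing_generator))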
@@ -1,10 +1,10 @@
 from typing import TYPE_CHECKING

-from danswer.search.models import QueryFlow
+from danswer.search.enums import QueryFlow
 from danswer.search.models import SearchType
+from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
 from danswer.search.search_nlp_models import get_default_tokenizer
 from danswer.search.search_nlp_models import IntentModel
-from danswer.search.search_runner import remove_stop_words_and_punctuation
 from danswer.server.query_and_chat.models import HelperResponse
 from danswer.utils.logger import setup_logger
@@ -5,19 +5,16 @@ from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
 from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION
 from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
 from danswer.configs.chat_configs import NUM_RETURNED_HITS
-from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
-from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
-from danswer.db.models import Persona
 from danswer.db.models import User
-from danswer.search.access_filters import build_access_filters_for_user
-from danswer.search.danswer_helper import query_intent
+from danswer.search.enums import QueryFlow
+from danswer.search.enums import RecencyBiasSetting
 from danswer.search.models import BaseFilters
 from danswer.search.models import IndexFilters
-from danswer.search.models import QueryFlow
-from danswer.search.models import RecencyBiasSetting
-from danswer.search.models import RetrievalDetails
 from danswer.search.models import SearchQuery
+from danswer.search.models import SearchRequest
 from danswer.search.models import SearchType
+from danswer.search.preprocessing.access_filters import build_access_filters_for_user
+from danswer.search.preprocessing.danswer_helper import query_intent
 from danswer.secondary_llm_flows.source_filter import extract_source_filter
 from danswer.secondary_llm_flows.time_filter import extract_time_filter
 from danswer.utils.logger import setup_logger
@@ -31,15 +28,12 @@ logger = setup_logger()

 @log_function_time(print_only=True)
 def retrieval_preprocessing(
-    query: str,
-    retrieval_details: RetrievalDetails,
-    persona: Persona,
+    search_request: SearchRequest,
     user: User | None,
     db_session: Session,
     bypass_acl: bool = False,
     include_query_intent: bool = True,
-    skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW,
-    skip_rerank_non_realtime: bool = not ENABLE_RERANKING_ASYNC_FLOW,
+    enable_auto_detect_filters: bool = False,
     disable_llm_filter_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION,
     disable_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
     base_recency_decay: float = BASE_RECENCY_DECAY,
@@ -50,8 +44,12 @@ def retrieval_preprocessing(
     Then any filters or settings as part of the query are used
     Then defaults to Persona settings if not specified by the query
     """
+    query = search_request.query
+    limit = search_request.limit
+    offset = search_request.offset
+    persona = search_request.persona

-    preset_filters = retrieval_details.filters or BaseFilters()
+    preset_filters = search_request.human_selected_filters or BaseFilters()
     if persona and persona.document_sets and preset_filters.document_set is None:
         preset_filters.document_set = [
             document_set.name for document_set in persona.document_sets
@@ -65,16 +63,20 @@ def retrieval_preprocessing(
     if disable_llm_filter_extraction:
         auto_detect_time_filter = False
         auto_detect_source_filter = False
-    elif retrieval_details.enable_auto_detect_filters is False:
+    elif enable_auto_detect_filters is False:
         logger.debug("Retrieval details disables auto detect filters")
         auto_detect_time_filter = False
         auto_detect_source_filter = False
-    elif persona.llm_filter_extraction is False:
+    elif persona and persona.llm_filter_extraction is False:
         logger.debug("Persona disables auto detect filters")
         auto_detect_time_filter = False
         auto_detect_source_filter = False

-    if time_filter is not None and persona.recency_bias != RecencyBiasSetting.AUTO:
+    if (
+        time_filter is not None
+        and persona
+        and persona.recency_bias != RecencyBiasSetting.AUTO
+    ):
         auto_detect_time_filter = False
         logger.debug("Not extract time filter - already provided")
     if source_filter is not None:
@@ -138,24 +140,18 @@ def retrieval_preprocessing(
         access_control_list=user_acl_filters,
     )

-    # Tranformer-based re-ranking to run at same time as LLM chunk relevance filter
-    # This one is only set globally, not via query or Persona settings
-    skip_reranking = (
-        skip_rerank_realtime
-        if retrieval_details.real_time
-        else skip_rerank_non_realtime
-    )
-
-    llm_chunk_filter = persona.llm_relevance_filter
+    llm_chunk_filter = False
+    if persona:
+        llm_chunk_filter = persona.llm_relevance_filter
     if disable_llm_chunk_filter:
         llm_chunk_filter = False

     # Decays at 1 / (1 + (multiplier * num years))
-    if persona.recency_bias == RecencyBiasSetting.NO_DECAY:
+    if persona and persona.recency_bias == RecencyBiasSetting.NO_DECAY:
         recency_bias_multiplier = 0.0
-    elif persona.recency_bias == RecencyBiasSetting.BASE_DECAY:
+    elif persona and persona.recency_bias == RecencyBiasSetting.BASE_DECAY:
         recency_bias_multiplier = base_recency_decay
-    elif persona.recency_bias == RecencyBiasSetting.FAVOR_RECENT:
+    elif persona and persona.recency_bias == RecencyBiasSetting.FAVOR_RECENT:
         recency_bias_multiplier = base_recency_decay * favor_recent_decay_multiplier
     else:
         if predicted_favor_recent:
@@ -166,14 +162,12 @@ def retrieval_preprocessing(
     return (
         SearchQuery(
             query=query,
-            search_type=persona.search_type,
+            search_type=persona.search_type if persona else SearchType.HYBRID,
             filters=final_filters,
             recency_bias_multiplier=recency_bias_multiplier,
-            num_hits=retrieval_details.limit
-            if retrieval_details.limit is not None
-            else NUM_RETURNED_HITS,
-            offset=retrieval_details.offset or 0,
-            skip_rerank=skip_reranking,
+            num_hits=limit if limit is not None else NUM_RETURNED_HITS,
+            offset=offset or 0,
+            skip_rerank=search_request.skip_rerank,
             skip_llm_chunk_filter=not llm_chunk_filter,
         ),
         predicted_search_type,
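One behavioral consequence of making the persona optional: with no persona and no overrides on the SearchRequest, retrieval_preprocessing falls back to hybrid search, leaves the LLM chunk filter off, and derives recency handling from the query alone. A rough sketch of a direct call under those assumptions (normally SearchPipeline makes this call; db_session is assumed to be an open session):

from danswer.search.models import SearchRequest
from danswer.search.preprocessing.preprocessing import retrieval_preprocessing

search_query, predicted_search_type, predicted_flow = retrieval_preprocessing(
    search_request=SearchRequest(query="onboarding checklist"),
    user=None,
    db_session=db_session,
)
# With no persona, search_query.search_type is SearchType.HYBRID and
# skip_llm_chunk_filter is True, since there is no persona to supply
# llm_relevance_filter or a preferred search_type.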
256
backend/danswer/search/retrieval/search_runner.py
Normal file
256
backend/danswer/search/retrieval/search_runner.py
Normal file
@ -0,0 +1,256 @@
|
||||
import string
|
||||
from collections.abc import Callable
|
||||
|
||||
from nltk.corpus import stopwords # type:ignore
|
||||
from nltk.stem import WordNetLemmatizer # type:ignore
|
||||
from nltk.tokenize import word_tokenize # type:ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.app_configs import MODEL_SERVER_HOST
|
||||
from danswer.configs.app_configs import MODEL_SERVER_PORT
|
||||
from danswer.configs.chat_configs import HYBRID_ALPHA
|
||||
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.search.models import ChunkMetric
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import MAX_METRICS_CONTENT
|
||||
from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.search.search_nlp_models import EmbedTextType
|
||||
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def lemmatize_text(text: str) -> list[str]:
|
||||
lemmatizer = WordNetLemmatizer()
|
||||
word_tokens = word_tokenize(text)
|
||||
return [lemmatizer.lemmatize(word) for word in word_tokens]
|
||||
|
||||
|
||||
def remove_stop_words_and_punctuation(text: str) -> list[str]:
|
||||
stop_words = set(stopwords.words("english"))
|
||||
word_tokens = word_tokenize(text)
|
||||
text_trimmed = [
|
||||
word
|
||||
for word in word_tokens
|
||||
if (word.casefold() not in stop_words and word not in string.punctuation)
|
||||
]
|
||||
return text_trimmed or word_tokens
|
||||
|
||||
|
||||
def query_processing(
|
||||
query: str,
|
||||
) -> str:
|
||||
query = " ".join(remove_stop_words_and_punctuation(query))
|
||||
query = " ".join(lemmatize_text(query))
|
||||
return query
|
||||
|
||||
|
||||
def combine_retrieval_results(
|
||||
chunk_sets: list[list[InferenceChunk]],
|
||||
) -> list[InferenceChunk]:
|
||||
all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set]
|
||||
|
||||
unique_chunks: dict[tuple[str, int], InferenceChunk] = {}
|
||||
for chunk in all_chunks:
|
||||
key = (chunk.document_id, chunk.chunk_id)
|
||||
if key not in unique_chunks:
|
||||
unique_chunks[key] = chunk
|
||||
continue
|
||||
|
||||
stored_chunk_score = unique_chunks[key].score or 0
|
||||
this_chunk_score = chunk.score or 0
|
||||
if stored_chunk_score < this_chunk_score:
|
||||
unique_chunks[key] = chunk
|
||||
|
||||
sorted_chunks = sorted(
|
||||
unique_chunks.values(), key=lambda x: x.score or 0, reverse=True
|
||||
)
|
||||
|
||||
return sorted_chunks
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def doc_index_retrieval(
|
||||
query: SearchQuery,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
hybrid_alpha: float = HYBRID_ALPHA,
|
||||
) -> list[InferenceChunk]:
|
||||
if query.search_type == SearchType.KEYWORD:
|
||||
top_chunks = document_index.keyword_retrieval(
|
||||
query=query.query,
|
||||
filters=query.filters,
|
||||
time_decay_multiplier=query.recency_bias_multiplier,
|
||||
num_to_retrieve=query.num_hits,
|
||||
)
|
||||
else:
|
||||
db_embedding_model = get_current_db_embedding_model(db_session)
|
||||
|
||||
model = EmbeddingModel(
|
||||
model_name=db_embedding_model.model_name,
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
normalize=db_embedding_model.normalize,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]
|
||||
|
||||
if query.search_type == SearchType.SEMANTIC:
|
||||
top_chunks = document_index.semantic_retrieval(
|
||||
query=query.query,
|
||||
query_embedding=query_embedding,
|
||||
filters=query.filters,
|
||||
time_decay_multiplier=query.recency_bias_multiplier,
|
||||
num_to_retrieve=query.num_hits,
|
||||
)
|
||||
|
||||
elif query.search_type == SearchType.HYBRID:
|
||||
top_chunks = document_index.hybrid_retrieval(
|
||||
query=query.query,
|
||||
query_embedding=query_embedding,
|
||||
filters=query.filters,
|
||||
time_decay_multiplier=query.recency_bias_multiplier,
|
||||
num_to_retrieve=query.num_hits,
|
||||
offset=query.offset,
|
||||
hybrid_alpha=hybrid_alpha,
|
||||
)
|
||||
|
||||
else:
|
||||
raise RuntimeError("Invalid Search Flow")
|
||||
|
||||
return top_chunks
|
||||
|
||||
|
||||
def _simplify_text(text: str) -> str:
|
||||
return "".join(
|
||||
char for char in text if char not in string.punctuation and not char.isspace()
|
||||
).lower()
|
||||
|
||||
|
||||
def retrieve_chunks(
|
||||
query: SearchQuery,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search
|
||||
multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
|
||||
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
||||
| None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
|
||||
# Don't do query expansion on complex queries, rephrasings likely would not work well
|
||||
if not multilingual_expansion_str or "\n" in query.query or "\r" in query.query:
|
||||
top_chunks = doc_index_retrieval(
|
||||
query=query,
|
||||
document_index=document_index,
|
||||
db_session=db_session,
|
||||
hybrid_alpha=hybrid_alpha,
|
||||
)
|
||||
else:
|
||||
simplified_queries = set()
|
||||
run_queries: list[tuple[Callable, tuple]] = []
|
||||
|
||||
# Currently only uses query expansion on multilingual use cases
|
||||
query_rephrases = multilingual_query_expansion(
|
||||
query.query, multilingual_expansion_str
|
||||
)
|
||||
# Just to be extra sure, add the original query.
|
||||
query_rephrases.append(query.query)
|
||||
for rephrase in set(query_rephrases):
|
||||
# Sometimes the model rephrases the query in the same language with minor changes
|
||||
# Avoid doing an extra search with the minor changes as this biases the results
|
||||
simplified_rephrase = _simplify_text(rephrase)
|
||||
if simplified_rephrase in simplified_queries:
|
||||
continue
|
||||
simplified_queries.add(simplified_rephrase)
|
||||
|
||||
q_copy = query.copy(update={"query": rephrase}, deep=True)
|
||||
run_queries.append(
|
||||
(
|
||||
doc_index_retrieval,
|
||||
(q_copy, document_index, db_session, hybrid_alpha),
|
||||
)
|
||||
)
|
||||
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
|
||||
top_chunks = combine_retrieval_results(parallel_search_results)
|
||||
|
||||
if not top_chunks:
|
||||
logger.info(
|
||||
f"{query.search_type.value.capitalize()} search returned no results "
|
||||
f"with filters: {query.filters}"
|
||||
)
|
||||
return []
|
||||
|
||||
if retrieval_metrics_callback is not None:
|
||||
chunk_metrics = [
|
||||
ChunkMetric(
|
||||
document_id=chunk.document_id,
|
||||
chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
|
||||
first_link=chunk.source_links[0] if chunk.source_links else None,
|
||||
score=chunk.score if chunk.score is not None else 0,
|
||||
)
|
||||
for chunk in top_chunks
|
||||
]
|
||||
retrieval_metrics_callback(
|
||||
RetrievalMetricsContainer(
|
||||
search_type=query.search_type, metrics=chunk_metrics
|
||||
)
|
||||
)
|
||||
|
||||
return top_chunks
|
||||
|
||||
|
||||
def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc:
|
||||
if not inf_chunks:
|
||||
raise ValueError("Cannot combine empty list of chunks")
|
||||
|
||||
# Use the first link of the document
|
||||
first_chunk = inf_chunks[0]
|
||||
chunk_texts = [chunk.content for chunk in inf_chunks]
|
||||
return LlmDoc(
|
||||
document_id=first_chunk.document_id,
|
||||
content="\n".join(chunk_texts),
|
||||
semantic_identifier=first_chunk.semantic_identifier,
|
||||
source_type=first_chunk.source_type,
|
||||
metadata=first_chunk.metadata,
|
||||
updated_at=first_chunk.updated_at,
|
||||
link=first_chunk.source_links[0] if first_chunk.source_links else None,
|
||||
)
|
||||
|
||||
|
||||
def inference_documents_from_ids(
|
||||
doc_identifiers: list[tuple[str, int]],
|
||||
document_index: DocumentIndex,
|
||||
) -> list[LlmDoc]:
|
||||
# Currently only fetches whole docs
|
||||
doc_ids_set = set(doc_id for doc_id, chunk_id in doc_identifiers)
|
||||
|
||||
# No need for ACL here because the doc ids were validated beforehand
|
||||
filters = IndexFilters(access_control_list=None)
|
||||
|
||||
functions_with_args: list[tuple[Callable, tuple]] = [
|
||||
(document_index.id_based_retrieval, (doc_id, None, filters))
|
||||
for doc_id in doc_ids_set
|
||||
]
|
||||
|
||||
parallel_results = run_functions_tuples_in_parallel(
|
||||
functions_with_args, allow_failures=True
|
||||
)
|
||||
|
||||
# Any failures to retrieve would give a None, drop the Nones and empty lists
|
||||
inference_chunks_sets = [res for res in parallel_results if res]
|
||||
|
||||
return [combine_inference_chunks(chunk_set) for chunk_set in inference_chunks_sets]
|
@ -1,645 +0,0 @@
|
||||
import string
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
|
||||
import numpy
|
||||
from nltk.corpus import stopwords # type:ignore
|
||||
from nltk.stem import WordNetLemmatizer # type:ignore
|
||||
from nltk.tokenize import word_tokenize # type:ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.app_configs import MODEL_SERVER_HOST
|
||||
from danswer.configs.app_configs import MODEL_SERVER_PORT
|
||||
from danswer.configs.chat_configs import HYBRID_ALPHA
|
||||
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
|
||||
from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
|
||||
from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
|
||||
from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.document_index.document_index_utils import (
|
||||
translate_boost_count_to_multiplier,
|
||||
)
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.search.models import ChunkMetric
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import MAX_METRICS_CONTENT
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.search.models import SearchDoc
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.search.search_nlp_models import EmbedTextType
|
||||
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
|
||||
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import FunctionCall
|
||||
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
|
||||
top_links = [
|
||||
c.source_links[0] if c.source_links is not None else "No Link" for c in chunks
|
||||
]
|
||||
logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
|
||||
|
||||
|
||||
def lemmatize_text(text: str) -> list[str]:
|
||||
lemmatizer = WordNetLemmatizer()
|
||||
word_tokens = word_tokenize(text)
|
||||
return [lemmatizer.lemmatize(word) for word in word_tokens]
|
||||
|
||||
|
||||
def remove_stop_words_and_punctuation(text: str) -> list[str]:
|
||||
stop_words = set(stopwords.words("english"))
|
||||
word_tokens = word_tokenize(text)
|
||||
text_trimmed = [
|
||||
word
|
||||
for word in word_tokens
|
||||
if (word.casefold() not in stop_words and word not in string.punctuation)
|
||||
]
|
||||
return text_trimmed or word_tokens
|
||||
|
||||
|
||||
def query_processing(
|
||||
query: str,
|
||||
) -> str:
|
||||
query = " ".join(remove_stop_words_and_punctuation(query))
|
||||
query = " ".join(lemmatize_text(query))
|
||||
return query
|
||||
|
||||
|
||||
def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
|
||||
search_docs = (
|
||||
[
|
||||
SearchDoc(
|
||||
document_id=chunk.document_id,
|
||||
chunk_ind=chunk.chunk_id,
|
||||
semantic_identifier=chunk.semantic_identifier or "Unknown",
|
||||
link=chunk.source_links.get(0) if chunk.source_links else None,
|
||||
blurb=chunk.blurb,
|
||||
source_type=chunk.source_type,
|
||||
boost=chunk.boost,
|
||||
hidden=chunk.hidden,
|
||||
metadata=chunk.metadata,
|
||||
score=chunk.score,
|
||||
match_highlights=chunk.match_highlights,
|
||||
updated_at=chunk.updated_at,
|
||||
primary_owners=chunk.primary_owners,
|
||||
secondary_owners=chunk.secondary_owners,
|
||||
)
|
||||
for chunk in chunks
|
||||
]
|
||||
if chunks
|
||||
else []
|
||||
)
|
||||
return search_docs
|
||||
|
||||
|
||||
def combine_retrieval_results(
|
||||
chunk_sets: list[list[InferenceChunk]],
|
||||
) -> list[InferenceChunk]:
|
||||
all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set]
|
||||
|
||||
unique_chunks: dict[tuple[str, int], InferenceChunk] = {}
|
||||
for chunk in all_chunks:
|
||||
key = (chunk.document_id, chunk.chunk_id)
|
||||
if key not in unique_chunks:
|
||||
unique_chunks[key] = chunk
|
||||
continue
|
||||
|
||||
stored_chunk_score = unique_chunks[key].score or 0
|
||||
this_chunk_score = chunk.score or 0
|
||||
if stored_chunk_score < this_chunk_score:
|
||||
unique_chunks[key] = chunk
|
||||
|
||||
sorted_chunks = sorted(
|
||||
unique_chunks.values(), key=lambda x: x.score or 0, reverse=True
|
||||
)
|
||||
|
||||
return sorted_chunks
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def doc_index_retrieval(
|
||||
query: SearchQuery,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
hybrid_alpha: float = HYBRID_ALPHA,
|
||||
) -> list[InferenceChunk]:
|
||||
if query.search_type == SearchType.KEYWORD:
|
||||
top_chunks = document_index.keyword_retrieval(
|
||||
query=query.query,
|
||||
filters=query.filters,
|
||||
time_decay_multiplier=query.recency_bias_multiplier,
|
||||
num_to_retrieve=query.num_hits,
|
||||
)
|
||||
else:
|
||||
db_embedding_model = get_current_db_embedding_model(db_session)
|
||||
|
||||
model = EmbeddingModel(
|
||||
model_name=db_embedding_model.model_name,
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
normalize=db_embedding_model.normalize,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]
|
||||
|
||||
if query.search_type == SearchType.SEMANTIC:
|
||||
top_chunks = document_index.semantic_retrieval(
|
||||
query=query.query,
|
||||
query_embedding=query_embedding,
|
||||
filters=query.filters,
|
||||
time_decay_multiplier=query.recency_bias_multiplier,
|
||||
num_to_retrieve=query.num_hits,
|
||||
)
|
||||
|
||||
elif query.search_type == SearchType.HYBRID:
|
||||
top_chunks = document_index.hybrid_retrieval(
|
||||
query=query.query,
|
||||
query_embedding=query_embedding,
|
||||
filters=query.filters,
|
||||
time_decay_multiplier=query.recency_bias_multiplier,
|
||||
num_to_retrieve=query.num_hits,
|
||||
offset=query.offset,
|
||||
hybrid_alpha=hybrid_alpha,
|
||||
)
|
||||
|
||||
else:
|
||||
raise RuntimeError("Invalid Search Flow")
|
||||
|
||||
return top_chunks
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
def semantic_reranking(
    query: str,
    chunks: list[InferenceChunk],
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
    model_min: int = CROSS_ENCODER_RANGE_MIN,
    model_max: int = CROSS_ENCODER_RANGE_MAX,
) -> tuple[list[InferenceChunk], list[int]]:
    """Reranks chunks based on cross-encoder models. Additionally provides the original indices
    of the chunks in their new sorted order.

    Note: this updates the chunks in place; the chunk scores that came from retrieval are overwritten.
    """
    cross_encoders = CrossEncoderEnsembleModel()
    passages = [chunk.content for chunk in chunks]
    sim_scores_floats = cross_encoders.predict(query=query, passages=passages)

    sim_scores = [numpy.array(scores) for scores in sim_scores_floats]

    raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores))

    cross_models_min = numpy.min(sim_scores)

    shifted_sim_scores = sum(
        [enc_n_scores - cross_models_min for enc_n_scores in sim_scores]
    ) / len(sim_scores)

    boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
    recency_multiplier = [chunk.recency_bias for chunk in chunks]
    boosted_sim_scores = shifted_sim_scores * boosts * recency_multiplier
    normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / (
        model_max - model_min
    )
    orig_indices = [i for i in range(len(normalized_b_s_scores))]
    scored_results = list(
        zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices)
    )
    scored_results.sort(key=lambda x: x[0], reverse=True)
    ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip(
        *scored_results
    )

    logger.debug(
        f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}"
    )

    # Assign new chunk scores based on reranking
    for ind, chunk in enumerate(ranked_chunks):
        chunk.score = ranked_sim_scores[ind]

    if rerank_metrics_callback is not None:
        chunk_metrics = [
            ChunkMetric(
                document_id=chunk.document_id,
                chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
                first_link=chunk.source_links[0] if chunk.source_links else None,
                score=chunk.score if chunk.score is not None else 0,
            )
            for chunk in ranked_chunks
        ]

        rerank_metrics_callback(
            RerankMetricsContainer(
                metrics=chunk_metrics, raw_similarity_scores=ranked_raw_scores  # type: ignore
            )
        )

    return list(ranked_chunks), list(ranked_indices)

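A minimal numeric sketch of the score handling above, with made-up ensemble scores and an assumed cross-encoder range of -1 to 1: per-encoder scores are shifted to be non-negative, averaged, weighted by boost and recency, then mapped back through the model min/max:

import numpy

model_min, model_max = -1.0, 1.0  # assumed CROSS_ENCODER_RANGE_MIN / MAX values
sim_scores = [numpy.array([0.2, -0.4]), numpy.array([0.4, -0.2])]  # two encoders, two chunks
boosts = numpy.array([1.0, 1.5])   # translate_boost_count_to_multiplier output
recency = numpy.array([1.0, 0.8])  # chunk.recency_bias

cross_models_min = numpy.min(sim_scores)  # -0.4
shifted = sum(s - cross_models_min for s in sim_scores) / len(sim_scores)  # [0.7, 0.1]
boosted = shifted * boosts * recency                                       # [0.7, 0.12]
normalized = (boosted + cross_models_min - model_min) / (model_max - model_min)
print(normalized)  # [0.65, 0.36] -> the per-chunk scores used for the rerank ordering
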
def apply_boost_legacy(
    chunks: list[InferenceChunk],
    norm_min: float = SIM_SCORE_RANGE_LOW,
    norm_max: float = SIM_SCORE_RANGE_HIGH,
) -> list[InferenceChunk]:
    scores = [chunk.score or 0 for chunk in chunks]
    boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]

    logger.debug(f"Raw similarity scores: {scores}")

    score_min = min(scores)
    score_max = max(scores)
    score_range = score_max - score_min

    if score_range != 0:
        boosted_scores = [
            ((score - score_min) / score_range) * boost
            for score, boost in zip(scores, boosts)
        ]
        unnormed_boosted_scores = [
            score * score_range + score_min for score in boosted_scores
        ]
    else:
        unnormed_boosted_scores = [
            score * boost for score, boost in zip(scores, boosts)
        ]

    norm_min = min(norm_min, min(scores))
    norm_max = max(norm_max, max(scores))
    # This should never be 0 unless user has done some weird/wrong settings
    norm_range = norm_max - norm_min

    # For score display purposes
    if norm_range != 0:
        re_normed_scores = [
            ((score - norm_min) / norm_range) for score in unnormed_boosted_scores
        ]
    else:
        re_normed_scores = unnormed_boosted_scores

    rescored_chunks = list(zip(re_normed_scores, chunks))
    rescored_chunks.sort(key=lambda x: x[0], reverse=True)
    sorted_boosted_scores, boost_sorted_chunks = zip(*rescored_chunks)

    final_chunks = list(boost_sorted_chunks)
    final_scores = list(sorted_boosted_scores)
    for ind, chunk in enumerate(final_chunks):
        chunk.score = final_scores[ind]

    logger.debug(f"Boost sorted similarity scores: {list(final_scores)}")

    return final_chunks

def apply_boost(
    chunks: list[InferenceChunk],
    # Need the range of values to not be too spread out for applying boost
    # therefore norm across only the top few results
    norm_cutoff: int = NUM_RERANKED_RESULTS,
    norm_min: float = SIM_SCORE_RANGE_LOW,
    norm_max: float = SIM_SCORE_RANGE_HIGH,
) -> list[InferenceChunk]:
    scores = [chunk.score or 0.0 for chunk in chunks]
    logger.debug(f"Raw similarity scores: {scores}")

    boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
    recency_multiplier = [chunk.recency_bias for chunk in chunks]

    norm_min = min(norm_min, min(scores[:norm_cutoff]))
    norm_max = max(norm_max, max(scores[:norm_cutoff]))
    # This should never be 0 unless user has done some weird/wrong settings
    norm_range = norm_max - norm_min

    boosted_scores = [
        max(0, (score - norm_min) * boost * recency / norm_range)
        for score, boost, recency in zip(scores, boosts, recency_multiplier)
    ]

    rescored_chunks = list(zip(boosted_scores, chunks))
    rescored_chunks.sort(key=lambda x: x[0], reverse=True)
    sorted_boosted_scores, boost_sorted_chunks = zip(*rescored_chunks)

    final_chunks = list(boost_sorted_chunks)
    final_scores = list(sorted_boosted_scores)
    for ind, chunk in enumerate(final_chunks):
        chunk.score = final_scores[ind]

    logger.debug(
        f"Boosted + Time Weighted sorted similarity scores: {list(final_scores)}"
    )

    return final_chunks

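Worked arithmetic for the boost formula above, assuming SIM_SCORE_RANGE_LOW/HIGH of 0.0 and 1.0 and a norm_cutoff larger than the list (all values are made up):

# Illustrative arithmetic only; SIM_SCORE_RANGE_LOW / HIGH assumed to be 0.0 / 1.0 here.
scores = [0.62, 0.58, 0.40]
boosts = [1.0, 1.2, 1.0]   # translate_boost_count_to_multiplier output
recency = [1.0, 1.0, 0.7]  # chunk.recency_bias

norm_min = min(0.0, min(scores))  # 0.0
norm_max = max(1.0, max(scores))  # 1.0
norm_range = norm_max - norm_min  # 1.0

boosted = [
    max(0, (s - norm_min) * b * r / norm_range)
    for s, b, r in zip(scores, boosts, recency)
]
print(boosted)  # [0.62, 0.696, 0.28] -> the boosted second doc now outranks the first
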
def _simplify_text(text: str) -> str:
    return "".join(
        char for char in text if char not in string.punctuation and not char.isspace()
    ).lower()

def retrieve_chunks(
    query: SearchQuery,
    document_index: DocumentIndex,
    db_session: Session,
    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
    multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
) -> list[InferenceChunk]:
    """Returns a list of the best chunks from an initial keyword/semantic/hybrid search."""
    # Don't do query expansion on complex queries; rephrasings likely would not work well
    if not multilingual_expansion_str or "\n" in query.query or "\r" in query.query:
        top_chunks = doc_index_retrieval(
            query=query,
            document_index=document_index,
            db_session=db_session,
            hybrid_alpha=hybrid_alpha,
        )
    else:
        simplified_queries = set()
        run_queries: list[tuple[Callable, tuple]] = []

        # Currently only uses query expansion on multilingual use cases
        query_rephrases = multilingual_query_expansion(
            query.query, multilingual_expansion_str
        )
        # Just to be extra sure, add the original query.
        query_rephrases.append(query.query)
        for rephrase in set(query_rephrases):
            # Sometimes the model rephrases the query in the same language with minor changes
            # Avoid doing an extra search with the minor changes as this biases the results
            simplified_rephrase = _simplify_text(rephrase)
            if simplified_rephrase in simplified_queries:
                continue
            simplified_queries.add(simplified_rephrase)

            q_copy = query.copy(update={"query": rephrase}, deep=True)
            run_queries.append(
                (
                    doc_index_retrieval,
                    (q_copy, document_index, db_session, hybrid_alpha),
                )
            )
        parallel_search_results = run_functions_tuples_in_parallel(run_queries)
        top_chunks = combine_retrieval_results(parallel_search_results)

    if not top_chunks:
        logger.info(
            f"{query.search_type.value.capitalize()} search returned no results "
            f"with filters: {query.filters}"
        )
        return []

    if retrieval_metrics_callback is not None:
        chunk_metrics = [
            ChunkMetric(
                document_id=chunk.document_id,
                chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
                first_link=chunk.source_links[0] if chunk.source_links else None,
                score=chunk.score if chunk.score is not None else 0,
            )
            for chunk in top_chunks
        ]
        retrieval_metrics_callback(
            RetrievalMetricsContainer(
                search_type=query.search_type, metrics=chunk_metrics
            )
        )

    return top_chunks

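A quick, runnable illustration of the rephrase de-duplication above, using the same punctuation/whitespace stripping as _simplify_text (the rephrasings are made up):

import string


def _simplify(text: str) -> str:
    return "".join(
        c for c in text if c not in string.punctuation and not c.isspace()
    ).lower()


rephrases = [
    "How do I reset my password?",
    "how do I reset my password",
    "Wie setze ich mein Passwort zurueck?",
]
seen: set[str] = set()
to_run: list[str] = []
for r in set(rephrases):
    key = _simplify(r)
    if key in seen:
        continue  # near-duplicate rephrase, skip the extra search
    seen.add(key)
    to_run.append(r)
# Only two distinct searches would be launched: one English, one German.
print(len(to_run))
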
def should_rerank(query: SearchQuery) -> bool:
    # Don't re-rank for keyword search
    return query.search_type != SearchType.KEYWORD and not query.skip_rerank


def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool:
    return not query.skip_llm_chunk_filter

def rerank_chunks(
    query: SearchQuery,
    chunks_to_rerank: list[InferenceChunk],
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> list[InferenceChunk]:
    ranked_chunks, _ = semantic_reranking(
        query=query.query,
        chunks=chunks_to_rerank[: query.num_rerank],
        rerank_metrics_callback=rerank_metrics_callback,
    )
    lower_chunks = chunks_to_rerank[query.num_rerank :]
    # Scores from rerank cannot be meaningfully combined with scores without rerank
    for lower_chunk in lower_chunks:
        lower_chunk.score = None
    ranked_chunks.extend(lower_chunks)
    return ranked_chunks

@log_function_time(print_only=True)
def filter_chunks(
    query: SearchQuery,
    chunks_to_filter: list[InferenceChunk],
) -> list[str]:
    """Filters chunks based on whether the LLM thought they were relevant to the query.

    Returns a list of the unique chunk IDs that were marked as relevant"""
    chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks]
    llm_chunk_selection = llm_batch_eval_chunks(
        query=query.query,
        chunk_contents=[chunk.content for chunk in chunks_to_filter],
    )
    return [
        chunk.unique_id
        for ind, chunk in enumerate(chunks_to_filter)
        if llm_chunk_selection[ind]
    ]

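The IDs returned here are later turned into a per-chunk boolean mask (see full_chunk_search_generator below); a minimal sketch with made-up IDs:

relevant_ids = ["doc-7__2", "doc-3__0"]               # filter_chunks output (illustrative)
candidate_ids = ["doc-3__0", "doc-7__2", "doc-9__5"]  # chunk.unique_id of the ranked chunks

relevance_mask = [cid in relevant_ids for cid in candidate_ids]
print(relevance_mask)  # [True, True, False]
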
def full_chunk_search(
    query: SearchQuery,
    document_index: DocumentIndex,
    db_session: Session,
    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
    multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk], list[bool]]:
    """A utility which provides an easier interface than `full_chunk_search_generator`.
    Rather than returning the chunks and llm relevance filter results in two separate
    yields, just returns them both at once."""
    search_generator = full_chunk_search_generator(
        search_query=query,
        document_index=document_index,
        db_session=db_session,
        hybrid_alpha=hybrid_alpha,
        multilingual_expansion_str=multilingual_expansion_str,
        retrieval_metrics_callback=retrieval_metrics_callback,
        rerank_metrics_callback=rerank_metrics_callback,
    )
    top_chunks = cast(list[InferenceChunk], next(search_generator))
    llm_chunk_selection = cast(list[bool], next(search_generator))
    return top_chunks, llm_chunk_selection

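A hedged call sketch for this tuple-returning wrapper (the collaborators are placeholders, assumed to be built as in the callers shown later in this diff):

top_chunks, llm_chunk_selection = full_chunk_search(
    query=search_query,            # SearchQuery built by the caller
    document_index=document_index,
    db_session=db_session,
)
relevant_chunks = [c for c, keep in zip(top_chunks, llm_chunk_selection) if keep]
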
def empty_search_generator() -> Iterator[list[InferenceChunk] | list[bool]]:
    yield cast(list[InferenceChunk], [])
    yield cast(list[bool], [])

def full_chunk_search_generator(
    search_query: SearchQuery,
    document_index: DocumentIndex,
    db_session: Session,
    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
    multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> Iterator[list[InferenceChunk] | list[bool]]:
    """Always yields twice: once with the selected chunks and once with the LLM relevance filter result.
    If the LLM filter is turned off, the second yield is a list of False values.
    """
    chunks_yielded = False

    retrieved_chunks = retrieve_chunks(
        query=search_query,
        document_index=document_index,
        db_session=db_session,
        hybrid_alpha=hybrid_alpha,
        multilingual_expansion_str=multilingual_expansion_str,
        retrieval_metrics_callback=retrieval_metrics_callback,
    )

    if not retrieved_chunks:
        yield cast(list[InferenceChunk], [])
        yield cast(list[bool], [])
        return

    post_processing_tasks: list[FunctionCall] = []

    rerank_task_id = None
    if should_rerank(search_query):
        post_processing_tasks.append(
            FunctionCall(
                rerank_chunks,
                (
                    search_query,
                    retrieved_chunks,
                    rerank_metrics_callback,
                ),
            )
        )
        rerank_task_id = post_processing_tasks[-1].result_id
    else:
        final_chunks = retrieved_chunks
        # NOTE: if we don't rerank, we can return the chunks immediately
        # since we know this is the final order
        _log_top_chunk_links(search_query.search_type.value, final_chunks)
        yield final_chunks
        chunks_yielded = True

    llm_filter_task_id = None
    if should_apply_llm_based_relevance_filter(search_query):
        post_processing_tasks.append(
            FunctionCall(
                filter_chunks,
                (search_query, retrieved_chunks[: search_query.max_llm_filter_chunks]),
            )
        )
        llm_filter_task_id = post_processing_tasks[-1].result_id

    post_processing_results = (
        run_functions_in_parallel(post_processing_tasks)
        if post_processing_tasks
        else {}
    )
    reranked_chunks = cast(
        list[InferenceChunk] | None,
        post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None,
    )
    if reranked_chunks:
        if chunks_yielded:
            logger.error(
                "Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen."
            )
        else:
            _log_top_chunk_links(search_query.search_type.value, reranked_chunks)
            yield reranked_chunks

    llm_chunk_selection = cast(
        list[str] | None,
        post_processing_results.get(str(llm_filter_task_id))
        if llm_filter_task_id
        else None,
    )
    if llm_chunk_selection is not None:
        yield [
            chunk.unique_id in llm_chunk_selection
            for chunk in reranked_chunks or retrieved_chunks
        ]
    else:
        yield [False for _ in reranked_chunks or retrieved_chunks]

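A consumer-side sketch of the two-yield contract: the chunks can be acted on as soon as the first yield arrives, before the LLM relevance filter finishes (search_query and the other collaborators are assumed to be built by the caller; send_sources_to_client is a hypothetical downstream step):

search_generator = full_chunk_search_generator(
    search_query=search_query,
    document_index=document_index,
    db_session=db_session,
)
top_chunks = cast(list[InferenceChunk], next(search_generator))
send_sources_to_client(top_chunks)  # hypothetical: stream citations early

chunk_relevance_list = cast(list[bool], next(search_generator))
useful_chunks = [c for c, ok in zip(top_chunks, chunk_relevance_list) if ok]
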
def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc:
    if not inf_chunks:
        raise ValueError("Cannot combine empty list of chunks")

    # Use the first link of the document
    first_chunk = inf_chunks[0]
    chunk_texts = [chunk.content for chunk in inf_chunks]
    return LlmDoc(
        document_id=first_chunk.document_id,
        content="\n".join(chunk_texts),
        semantic_identifier=first_chunk.semantic_identifier,
        source_type=first_chunk.source_type,
        metadata=first_chunk.metadata,
        updated_at=first_chunk.updated_at,
        link=first_chunk.source_links[0] if first_chunk.source_links else None,
    )

def inference_documents_from_ids(
    doc_identifiers: list[tuple[str, int]],
    document_index: DocumentIndex,
) -> list[LlmDoc]:
    # Currently only fetches whole docs
    doc_ids_set = set(doc_id for doc_id, chunk_id in doc_identifiers)

    # No need for ACL here because the doc ids were validated beforehand
    filters = IndexFilters(access_control_list=None)

    functions_with_args: list[tuple[Callable, tuple]] = [
        (document_index.id_based_retrieval, (doc_id, None, filters))
        for doc_id in doc_ids_set
    ]

    parallel_results = run_functions_tuples_in_parallel(
        functions_with_args, allow_failures=True
    )

    # Any failures to retrieve would give a None, drop the Nones and empty lists
    inference_chunks_sets = [res for res in parallel_results if res]

    return [combine_inference_chunks(chunk_set) for chunk_set in inference_chunks_sets]

29 backend/danswer/search/utils.py Normal file
@ -0,0 +1,29 @@
from danswer.indexing.models import InferenceChunk
from danswer.search.models import SearchDoc


def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
    search_docs = (
        [
            SearchDoc(
                document_id=chunk.document_id,
                chunk_ind=chunk.chunk_id,
                semantic_identifier=chunk.semantic_identifier or "Unknown",
                link=chunk.source_links.get(0) if chunk.source_links else None,
                blurb=chunk.blurb,
                source_type=chunk.source_type,
                boost=chunk.boost,
                hidden=chunk.hidden,
                metadata=chunk.metadata,
                score=chunk.score,
                match_highlights=chunk.match_highlights,
                updated_at=chunk.updated_at,
                primary_owners=chunk.primary_owners,
                secondary_owners=chunk.secondary_owners,
            )
            for chunk in chunks
        ]
        if chunks
        else []
    )
    return search_docs

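The None/empty handling means callers can pass a failed or empty retrieval straight through; the `.get(0)` call suggests source_links is a mapping of which only the entry at key 0 is surfaced. A minimal sketch:

# With no chunks, the helper simply returns an empty list.
assert chunks_to_search_docs(None) == []
assert chunks_to_search_docs([]) == []
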
@ -11,8 +11,8 @@ from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.llm.utils import get_default_llm_token_encode
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.models import IndexFilters
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.server.documents.models import ChunkInfo
from danswer.server.documents.models import DocumentInfo

@ -4,7 +4,7 @@ from pydantic import BaseModel
from danswer.db.models import Persona
from danswer.db.models import StarterMessage
from danswer.search.models import RecencyBiasSetting
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.document_set.models import DocumentSet
from danswer.server.features.prompt.models import PromptSnapshot

@ -6,13 +6,9 @@ from fastapi import Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.document_index.factory import get_default_document_index
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.models import IndexFilters
from danswer.search.models import SearchQuery
from danswer.search.search_runner import full_chunk_search
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
from danswer.server.danswer_api.ingestion import api_key_dep
from danswer.utils.logger import setup_logger

@ -70,27 +66,13 @@ def gpt_search(
    _: str | None = Depends(api_key_dep),
    db_session: Session = Depends(get_session),
) -> GptSearchResponse:
    query = search_request.query

    user_acl_filters = build_access_filters_for_user(None, db_session)
    final_filters = IndexFilters(access_control_list=user_acl_filters)

    search_query = SearchQuery(
        query=query,
        filters=final_filters,
        recency_bias_multiplier=1.0,
        skip_llm_chunk_filter=True,
    )

    embedding_model = get_current_db_embedding_model(db_session)

    document_index = get_default_document_index(
        primary_index_name=embedding_model.index_name, secondary_index_name=None
    )

    top_chunks, __ = full_chunk_search(
        query=search_query, document_index=document_index, db_session=db_session
    )
    top_chunks = SearchPipeline(
        search_request=SearchRequest(
            query=search_request.query,
        ),
        user=None,
        db_session=db_session,
    ).reranked_docs

    return GptSearchResponse(
        matching_document_chunks=[
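For reference, the new call pattern introduced by this refactor, limited to the constructor arguments and properties that appear in the hunks of this diff (other options may exist; db_session is assumed to come from the surrounding FastAPI dependency):

pipeline = SearchPipeline(
    search_request=SearchRequest(query="how do I rotate my API key?"),
    user=None,
    db_session=db_session,
)
top_chunks = pipeline.reranked_docs              # ranked InferenceChunks
relevance_flags = pipeline.chunk_relevance_list  # per-chunk LLM relevance booleans
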
@ -15,11 +15,11 @@ from danswer.document_index.factory import get_default_document_index
from danswer.document_index.vespa.index import VespaIndex
from danswer.one_shot_answer.answer_question import stream_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.models import IndexFilters
from danswer.search.models import SearchDoc
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.search.preprocessing.danswer_helper import recommend_search_flow
from danswer.search.utils import chunks_to_search_docs
from danswer.secondary_llm_flows.query_validation import get_query_answerability
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
from danswer.server.query_and_chat.models import AdminSearchRequest

@ -8,15 +8,12 @@ from typing import TextIO
from sqlalchemy.orm import Session

from danswer.chat.chat_utils import get_chunks_for_qa
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.search.models import IndexFilters
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.search_runner import full_chunk_search
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
from danswer.utils.callbacks import MetricsHander

@ -81,35 +78,22 @@ def get_search_results(
    RetrievalMetricsContainer | None,
    RerankMetricsContainer | None,
]:
    filters = IndexFilters(
        source_type=None,
        document_set=None,
        time_cutoff=None,
        access_control_list=None,
    )
    search_query = SearchQuery(
        query=query,
        filters=filters,
        recency_bias_multiplier=1.0,
    )

    retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
    rerank_metrics = MetricsHander[RerankMetricsContainer]()

    with Session(get_sqlalchemy_engine()) as db_session:
        embedding_model = get_current_db_embedding_model(db_session)
        search_pipeline = SearchPipeline(
            search_request=SearchRequest(
                query=query,
            ),
            user=None,
            db_session=db_session,
            retrieval_metrics_callback=retrieval_metrics.record_metric,
            rerank_metrics_callback=rerank_metrics.record_metric,
        )

        document_index = get_default_document_index(
            primary_index_name=embedding_model.index_name, secondary_index_name=None
        )

        top_chunks, llm_chunk_selection = full_chunk_search(
            query=search_query,
            document_index=document_index,
            db_session=db_session,
            retrieval_metrics_callback=retrieval_metrics.record_metric,
            rerank_metrics_callback=rerank_metrics.record_metric,
        )
        top_chunks = search_pipeline.reranked_docs
        llm_chunk_selection = search_pipeline.chunk_relevance_list

        llm_chunks_indices = get_chunks_for_qa(
            chunks=top_chunks,