diff --git a/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py b/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py index 272335ca0..1e2e7cd3c 100644 --- a/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py +++ b/backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py @@ -9,7 +9,7 @@ from alembic import op import sqlalchemy as sa from danswer.db.models import IndexModelStatus -from danswer.search.models import RecencyBiasSetting +from danswer.search.enums import RecencyBiasSetting from danswer.search.models import SearchType # revision identifiers, used by Alembic. diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py index 0800abb70..ccc754437 100644 --- a/backend/danswer/chat/load_yamls.py +++ b/backend/danswer/chat/load_yamls.py @@ -13,7 +13,7 @@ from danswer.db.document_set import get_or_create_document_set_by_name from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import DocumentSet as DocumentSetDBModel from danswer.db.models import Prompt as PromptDBModel -from danswer.search.models import RecencyBiasSetting +from danswer.search.enums import RecencyBiasSetting def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None: diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index de3f7e4f0..47d554de7 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -5,10 +5,10 @@ from typing import Any from pydantic import BaseModel from danswer.configs.constants import DocumentSource -from danswer.search.models import QueryFlow +from danswer.search.enums import QueryFlow +from danswer.search.enums import SearchType from danswer.search.models import RetrievalDocs from danswer.search.models import SearchResponse -from danswer.search.models import SearchType class LlmDoc(BaseModel): diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index aafe5d000..9cd78c963 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -53,11 +53,10 @@ from danswer.llm.utils import tokenizer_trim_content from danswer.llm.utils import translate_history_to_basemessages from danswer.prompts.prompt_utils import build_doc_context_str from danswer.search.models import OptionalSearchSetting -from danswer.search.models import RetrievalDetails -from danswer.search.request_preprocessing import retrieval_preprocessing -from danswer.search.search_runner import chunks_to_search_docs -from danswer.search.search_runner import full_chunk_search_generator -from danswer.search.search_runner import inference_documents_from_ids +from danswer.search.models import SearchRequest +from danswer.search.pipeline import SearchPipeline +from danswer.search.retrieval.search_runner import inference_documents_from_ids +from danswer.search.utils import chunks_to_search_docs from danswer.secondary_llm_flows.choose_search import check_if_need_search from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase from danswer.server.query_and_chat.models import ChatMessageDetail @@ -377,37 +376,25 @@ def stream_chat_message_objects( else query_override ) - ( - retrieval_request, - predicted_search_type, - predicted_flow, - ) = retrieval_preprocessing( - query=rephrased_query, - retrieval_details=cast(RetrievalDetails, retrieval_options), - persona=persona, + search_pipeline = SearchPipeline( + search_request=SearchRequest( + query=rephrased_query, + 
human_selected_filters=retrieval_options.filters + if retrieval_options + else None, + persona=persona, + offset=retrieval_options.offset if retrieval_options else None, + limit=retrieval_options.limit if retrieval_options else None, + ), user=user, db_session=db_session, ) - documents_generator = full_chunk_search_generator( - search_query=retrieval_request, - document_index=document_index, - db_session=db_session, - ) - time_cutoff = retrieval_request.filters.time_cutoff - recency_bias_multiplier = retrieval_request.recency_bias_multiplier - run_llm_chunk_filter = not retrieval_request.skip_llm_chunk_filter - - # First fetch and return the top chunks to the UI so the user can - # immediately see some results - top_chunks = cast(list[InferenceChunk], next(documents_generator)) + top_chunks = search_pipeline.reranked_docs + top_docs = chunks_to_search_docs(top_chunks) # Get ranking of the documents for citation purposes later - doc_id_to_rank_map = map_document_id_order( - cast(list[InferenceChunk | LlmDoc], top_chunks) - ) - - top_docs = chunks_to_search_docs(top_chunks) + doc_id_to_rank_map = map_document_id_order(top_chunks) reference_db_search_docs = [ create_db_search_doc(server_search_doc=top_doc, db_session=db_session) @@ -422,24 +409,17 @@ def stream_chat_message_objects( initial_response = QADocsResponse( rephrased_query=rephrased_query, top_documents=response_docs, - predicted_flow=predicted_flow, - predicted_search=predicted_search_type, - applied_source_filters=retrieval_request.filters.source_type, - applied_time_cutoff=time_cutoff, - recency_bias_multiplier=recency_bias_multiplier, + predicted_flow=search_pipeline.predicted_flow, + predicted_search=search_pipeline.predicted_search_type, + applied_source_filters=search_pipeline.search_query.filters.source_type, + applied_time_cutoff=search_pipeline.search_query.filters.time_cutoff, + recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier, ) yield initial_response - # Get the final ordering of chunks for the LLM call - llm_chunk_selection = cast(list[bool], next(documents_generator)) - # Yield the list of LLM selected chunks for showing the LLM selected icons in the UI llm_relevance_filtering_response = LLMRelevanceFilterResponse( - relevant_chunk_indices=[ - index for index, value in enumerate(llm_chunk_selection) if value - ] - if run_llm_chunk_filter - else [] + relevant_chunk_indices=search_pipeline.relevant_chunk_indicies ) yield llm_relevance_filtering_response @@ -467,7 +447,7 @@ def stream_chat_message_objects( ) llm_chunks_indices = get_chunks_for_qa( chunks=top_chunks, - llm_chunk_selection=llm_chunk_selection, + llm_chunk_selection=search_pipeline.chunk_relevance_list, token_limit=chunk_token_limit, llm_tokenizer=llm_tokenizer, ) diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 343912e27..6dfa02c2f 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -27,7 +27,7 @@ from danswer.db.models import SearchDoc from danswer.db.models import SearchDoc as DBSearchDoc from danswer.db.models import StarterMessage from danswer.db.models import User__UserGroup -from danswer.search.models import RecencyBiasSetting +from danswer.search.enums import RecencyBiasSetting from danswer.search.models import RetrievalDocs from danswer.search.models import SavedSearchDoc from danswer.search.models import SearchDoc as ServerSearchDoc diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index abe189c45..faafd2aed 100644 --- 
a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -36,8 +36,8 @@ from danswer.configs.constants import MessageType from danswer.configs.constants import SearchFeedbackType from danswer.connectors.models import InputType from danswer.dynamic_configs.interface import JSON_ro -from danswer.search.models import RecencyBiasSetting -from danswer.search.models import SearchType +from danswer.search.enums import RecencyBiasSetting +from danswer.search.enums import SearchType class IndexingStatus(str, PyEnum): diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py index 3e93a76cf..c3b463e35 100644 --- a/backend/danswer/db/slack_bot_config.py +++ b/backend/danswer/db/slack_bot_config.py @@ -12,7 +12,7 @@ from danswer.db.models import Persona from danswer.db.models import Persona__DocumentSet from danswer.db.models import SlackBotConfig from danswer.db.models import SlackBotResponseType -from danswer.search.models import RecencyBiasSetting +from danswer.search.enums import RecencyBiasSetting def _build_persona_name(channel_names: list[str]) -> str: diff --git a/backend/danswer/document_index/vespa/index.py b/backend/danswer/document_index/vespa/index.py index 178aadf3e..9f78f05c2 100644 --- a/backend/danswer/document_index/vespa/index.py +++ b/backend/danswer/document_index/vespa/index.py @@ -64,8 +64,8 @@ from danswer.document_index.vespa.utils import remove_invalid_unicode_chars from danswer.indexing.models import DocMetadataAwareIndexChunk from danswer.indexing.models import InferenceChunk from danswer.search.models import IndexFilters -from danswer.search.search_runner import query_processing -from danswer.search.search_runner import remove_stop_words_and_punctuation +from danswer.search.retrieval.search_runner import query_processing +from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation from danswer.utils.batching import batch_generator from danswer.utils.logger import setup_logger from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 529180a79..db5ef6f0f 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -1,7 +1,6 @@ import itertools from collections.abc import Callable from collections.abc import Iterator -from typing import cast from langchain.schema.messages import BaseMessage from langchain.schema.messages import HumanMessage @@ -33,11 +32,9 @@ from danswer.db.chat import get_or_create_root_message from danswer.db.chat import get_persona_by_id from danswer.db.chat import get_prompt_by_id from danswer.db.chat import translate_db_message_to_chat_message_detail -from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_session_context_manager from danswer.db.models import Prompt from danswer.db.models import User -from danswer.document_index.factory import get_default_document_index from danswer.indexing.models import InferenceChunk from danswer.llm.factory import get_default_llm from danswer.llm.utils import get_default_llm_token_encode @@ -55,9 +52,9 @@ from danswer.prompts.prompt_utils import build_task_prompt_reminders from danswer.search.models import RerankMetricsContainer from danswer.search.models import RetrievalMetricsContainer from danswer.search.models import SavedSearchDoc -from danswer.search.request_preprocessing import 
retrieval_preprocessing -from danswer.search.search_runner import chunks_to_search_docs -from danswer.search.search_runner import full_chunk_search_generator +from danswer.search.models import SearchRequest +from danswer.search.pipeline import SearchPipeline +from danswer.search.utils import chunks_to_search_docs from danswer.secondary_llm_flows.answer_validation import get_answer_validity from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase from danswer.server.query_and_chat.models import ChatMessageDetail @@ -221,12 +218,6 @@ def stream_answer_objects( llm_tokenizer = get_default_llm_token_encode() - embedding_model = get_current_db_embedding_model(db_session) - - document_index = get_default_document_index( - primary_index_name=embedding_model.index_name, secondary_index_name=None - ) - # Create a chat session which will just store the root message, the query, and the AI response root_message = get_or_create_root_message( chat_session_id=chat_session.id, db_session=db_session @@ -244,33 +235,23 @@ def stream_answer_objects( # In chat flow it's given back along with the documents yield QueryRephrase(rephrased_query=rephrased_query) - ( - retrieval_request, - predicted_search_type, - predicted_flow, - ) = retrieval_preprocessing( - query=rephrased_query, - retrieval_details=query_req.retrieval_options, - persona=chat_session.persona, + search_pipeline = SearchPipeline( + search_request=SearchRequest( + query=rephrased_query, + human_selected_filters=query_req.retrieval_options.filters, + persona=chat_session.persona, + offset=query_req.retrieval_options.offset, + limit=query_req.retrieval_options.limit, + ), user=user, db_session=db_session, bypass_acl=bypass_acl, - ) - - documents_generator = full_chunk_search_generator( - search_query=retrieval_request, - document_index=document_index, - db_session=db_session, retrieval_metrics_callback=retrieval_metrics_callback, rerank_metrics_callback=rerank_metrics_callback, ) - applied_time_cutoff = retrieval_request.filters.time_cutoff - recency_bias_multiplier = retrieval_request.recency_bias_multiplier - run_llm_chunk_filter = not retrieval_request.skip_llm_chunk_filter # First fetch and return the top chunks so the user can immediately see some results - top_chunks = cast(list[InferenceChunk], next(documents_generator)) - + top_chunks = search_pipeline.reranked_docs top_docs = chunks_to_search_docs(top_chunks) fake_saved_docs = [SavedSearchDoc.from_search_doc(doc) for doc in top_docs] @@ -278,24 +259,17 @@ def stream_answer_objects( initial_response = QADocsResponse( rephrased_query=rephrased_query, top_documents=fake_saved_docs, - predicted_flow=predicted_flow, - predicted_search=predicted_search_type, - applied_source_filters=retrieval_request.filters.source_type, - applied_time_cutoff=applied_time_cutoff, - recency_bias_multiplier=recency_bias_multiplier, + predicted_flow=search_pipeline.predicted_flow, + predicted_search=search_pipeline.predicted_search_type, + applied_source_filters=search_pipeline.search_query.filters.source_type, + applied_time_cutoff=search_pipeline.search_query.filters.time_cutoff, + recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier, ) yield initial_response - # Get the final ordering of chunks for the LLM call - llm_chunk_selection = cast(list[bool], next(documents_generator)) - # Yield the list of LLM selected chunks for showing the LLM selected icons in the UI llm_relevance_filtering_response = LLMRelevanceFilterResponse( - relevant_chunk_indices=[ - index 
for index, value in enumerate(llm_chunk_selection) if value - ] - if run_llm_chunk_filter - else [] + relevant_chunk_indices=search_pipeline.relevant_chunk_indicies ) yield llm_relevance_filtering_response @@ -317,7 +291,7 @@ def stream_answer_objects( llm_chunks_indices = get_chunks_for_qa( chunks=top_chunks, - llm_chunk_selection=llm_chunk_selection, + llm_chunk_selection=search_pipeline.chunk_relevance_list, token_limit=chunk_token_limit, ) llm_chunks = [top_chunks[i] for i in llm_chunks_indices] diff --git a/backend/danswer/search/enums.py b/backend/danswer/search/enums.py new file mode 100644 index 000000000..9ba44ada2 --- /dev/null +++ b/backend/danswer/search/enums.py @@ -0,0 +1,30 @@ +"""NOTE: this needs to be separate from models.py because of circular imports. +Both search/models.py and db/models.py import enums from this file AND +search/models.py imports from db/models.py.""" +from enum import Enum + + +class OptionalSearchSetting(str, Enum): + ALWAYS = "always" + NEVER = "never" + # Determine whether to run search based on history and latest query + AUTO = "auto" + + +class RecencyBiasSetting(str, Enum): + FAVOR_RECENT = "favor_recent" # 2x decay rate + BASE_DECAY = "base_decay" + NO_DECAY = "no_decay" + # Determine based on query if to use base_decay or favor_recent + AUTO = "auto" + + +class SearchType(str, Enum): + KEYWORD = "keyword" + SEMANTIC = "semantic" + HYBRID = "hybrid" + + +class QueryFlow(str, Enum): + SEARCH = "search" + QUESTION_ANSWER = "question-answer" diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index db3dc31f8..d2ad74c34 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -1,46 +1,24 @@ from datetime import datetime -from enum import Enum from typing import Any from pydantic import BaseModel from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER +from danswer.configs.chat_configs import HYBRID_ALPHA from danswer.configs.chat_configs import NUM_RERANKED_RESULTS from danswer.configs.chat_configs import NUM_RETURNED_HITS from danswer.configs.constants import DocumentSource from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW +from danswer.db.models import Persona +from danswer.search.enums import OptionalSearchSetting +from danswer.search.enums import SearchType + MAX_METRICS_CONTENT = ( 200 # Just need enough characters to identify where in the doc the chunk is ) -class OptionalSearchSetting(str, Enum): - ALWAYS = "always" - NEVER = "never" - # Determine whether to run search based on history and latest query - AUTO = "auto" - - -class RecencyBiasSetting(str, Enum): - FAVOR_RECENT = "favor_recent" # 2x decay rate - BASE_DECAY = "base_decay" - NO_DECAY = "no_decay" - # Determine based on query if to use base_decay or favor_recent - AUTO = "auto" - - -class SearchType(str, Enum): - KEYWORD = "keyword" - SEMANTIC = "semantic" - HYBRID = "hybrid" - - -class QueryFlow(str, Enum): - SEARCH = "search" - QUESTION_ANSWER = "question-answer" - - class Tag(BaseModel): tag_key: str tag_value: str @@ -64,6 +42,28 @@ class ChunkMetric(BaseModel): score: float +class SearchRequest(BaseModel): + """Input to the SearchPipeline.""" + + query: str + search_type: SearchType = SearchType.HYBRID + + human_selected_filters: BaseFilters | None = None + enable_auto_detect_filters: bool | None = None + persona: Persona | None = None + + # if None, no offset / limit + offset: int | None = None + limit: int | None = None + + recency_bias_multiplier: float = 1.0 + hybrid_alpha: float = 
HYBRID_ALPHA + skip_rerank: bool = True + + class Config: + arbitrary_types_allowed = True + + class SearchQuery(BaseModel): query: str filters: IndexFilters diff --git a/backend/danswer/search/pipeline.py b/backend/danswer/search/pipeline.py new file mode 100644 index 000000000..972f510db --- /dev/null +++ b/backend/danswer/search/pipeline.py @@ -0,0 +1,152 @@ +from collections.abc import Callable +from typing import cast + +from sqlalchemy.orm import Session + +from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION +from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.db.models import User +from danswer.document_index.factory import get_default_document_index +from danswer.indexing.models import InferenceChunk +from danswer.search.enums import QueryFlow +from danswer.search.enums import SearchType +from danswer.search.models import RerankMetricsContainer +from danswer.search.models import RetrievalMetricsContainer +from danswer.search.models import SearchQuery +from danswer.search.models import SearchRequest +from danswer.search.postprocessing.postprocessing import search_postprocessing +from danswer.search.preprocessing.preprocessing import retrieval_preprocessing +from danswer.search.retrieval.search_runner import retrieve_chunks + + +class SearchPipeline: + def __init__( + self, + search_request: SearchRequest, + user: User | None, + db_session: Session, + bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION + retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] + | None = None, + rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, + ): + self.search_request = search_request + self.user = user + self.db_session = db_session + self.bypass_acl = bypass_acl + self.retrieval_metrics_callback = retrieval_metrics_callback + self.rerank_metrics_callback = rerank_metrics_callback + + self.embedding_model = get_current_db_embedding_model(db_session) + self.document_index = get_default_document_index( + primary_index_name=self.embedding_model.index_name, + secondary_index_name=None, + ) + + self._search_query: SearchQuery | None = None + self._predicted_search_type: SearchType | None = None + self._predicted_flow: QueryFlow | None = None + + self._retrieved_docs: list[InferenceChunk] | None = None + self._reranked_docs: list[InferenceChunk] | None = None + self._relevant_chunk_indicies: list[int] | None = None + + """Pre-processing""" + + def _run_preprocessing(self) -> None: + ( + final_search_query, + predicted_search_type, + predicted_flow, + ) = retrieval_preprocessing( + search_request=self.search_request, + user=self.user, + db_session=self.db_session, + bypass_acl=self.bypass_acl, + ) + self._predicted_search_type = predicted_search_type + self._predicted_flow = predicted_flow + self._search_query = final_search_query + + @property + def search_query(self) -> SearchQuery: + if self._search_query is not None: + return self._search_query + + self._run_preprocessing() + return cast(SearchQuery, self._search_query) + + @property + def predicted_search_type(self) -> SearchType: + if self._predicted_search_type is not None: + return self._predicted_search_type + + self._run_preprocessing() + return cast(SearchType, self._predicted_search_type) + + @property + def predicted_flow(self) -> QueryFlow: + if self._predicted_flow is not None: + return self._predicted_flow + + self._run_preprocessing() + return cast(QueryFlow, self._predicted_flow) + + """Retrieval""" + + @property + def 
retrieved_docs(self) -> list[InferenceChunk]: + if self._retrieved_docs is not None: + return self._retrieved_docs + + self._retrieved_docs = retrieve_chunks( + query=self.search_query, + document_index=self.document_index, + db_session=self.db_session, + hybrid_alpha=self.search_request.hybrid_alpha, + multilingual_expansion_str=MULTILINGUAL_QUERY_EXPANSION, + retrieval_metrics_callback=self.retrieval_metrics_callback, + ) + + # self._retrieved_docs = chunks_to_search_docs(retrieved_chunks) + return cast(list[InferenceChunk], self._retrieved_docs) + + """Post-Processing""" + + def _run_postprocessing(self) -> None: + postprocessing_generator = search_postprocessing( + search_query=self.search_query, + retrieved_chunks=self.retrieved_docs, + rerank_metrics_callback=self.rerank_metrics_callback, + ) + self._reranked_docs = cast(list[InferenceChunk], next(postprocessing_generator)) + + relevant_chunk_ids = cast(list[str], next(postprocessing_generator)) + self._relevant_chunk_indicies = [ + ind + for ind, chunk in enumerate(self._reranked_docs) + if chunk.unique_id in relevant_chunk_ids + ] + + @property + def reranked_docs(self) -> list[InferenceChunk]: + if self._reranked_docs is not None: + return self._reranked_docs + + self._run_postprocessing() + return cast(list[InferenceChunk], self._reranked_docs) + + @property + def relevant_chunk_indicies(self) -> list[int]: + if self._relevant_chunk_indicies is not None: + return self._relevant_chunk_indicies + + self._run_postprocessing() + return cast(list[int], self._relevant_chunk_indicies) + + @property + def chunk_relevance_list(self) -> list[bool]: + return [ + True if ind in self.relevant_chunk_indicies else False + for ind in range(len(self.reranked_docs)) + ] diff --git a/backend/danswer/search/postprocessing/postprocessing.py b/backend/danswer/search/postprocessing/postprocessing.py new file mode 100644 index 000000000..e1cee4bd6 --- /dev/null +++ b/backend/danswer/search/postprocessing/postprocessing.py @@ -0,0 +1,222 @@ +from collections.abc import Callable +from collections.abc import Generator +from typing import cast + +import numpy + +from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX +from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN +from danswer.document_index.document_index_utils import ( + translate_boost_count_to_multiplier, +) +from danswer.indexing.models import InferenceChunk +from danswer.search.models import ChunkMetric +from danswer.search.models import MAX_METRICS_CONTENT +from danswer.search.models import RerankMetricsContainer +from danswer.search.models import SearchQuery +from danswer.search.models import SearchType +from danswer.search.search_nlp_models import CrossEncoderEnsembleModel +from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks +from danswer.utils.logger import setup_logger +from danswer.utils.threadpool_concurrency import FunctionCall +from danswer.utils.threadpool_concurrency import run_functions_in_parallel +from danswer.utils.timing import log_function_time + + +logger = setup_logger() + + +def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None: + top_links = [ + c.source_links[0] if c.source_links is not None else "No Link" for c in chunks + ] + logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}") + + +def should_rerank(query: SearchQuery) -> bool: + # Don't re-rank for keyword search + return query.search_type != SearchType.KEYWORD and not query.skip_rerank + + +def 
should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool: + return not query.skip_llm_chunk_filter + + +@log_function_time(print_only=True) +def semantic_reranking( + query: str, + chunks: list[InferenceChunk], + model_min: int = CROSS_ENCODER_RANGE_MIN, + model_max: int = CROSS_ENCODER_RANGE_MAX, + rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, +) -> tuple[list[InferenceChunk], list[int]]: + """Reranks chunks based on cross-encoder models. Additionally provides the original indices + of the chunks in their new sorted order. + + Note: this updates the chunks in place, it updates the chunk scores which came from retrieval + """ + cross_encoders = CrossEncoderEnsembleModel() + passages = [chunk.content for chunk in chunks] + sim_scores_floats = cross_encoders.predict(query=query, passages=passages) + + sim_scores = [numpy.array(scores) for scores in sim_scores_floats] + + raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores)) + + cross_models_min = numpy.min(sim_scores) + + shifted_sim_scores = sum( + [enc_n_scores - cross_models_min for enc_n_scores in sim_scores] + ) / len(sim_scores) + + boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks] + recency_multiplier = [chunk.recency_bias for chunk in chunks] + boosted_sim_scores = shifted_sim_scores * boosts * recency_multiplier + normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / ( + model_max - model_min + ) + orig_indices = [i for i in range(len(normalized_b_s_scores))] + scored_results = list( + zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices) + ) + scored_results.sort(key=lambda x: x[0], reverse=True) + ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip( + *scored_results + ) + + logger.debug( + f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}" + ) + + # Assign new chunk scores based on reranking + for ind, chunk in enumerate(ranked_chunks): + chunk.score = ranked_sim_scores[ind] + + if rerank_metrics_callback is not None: + chunk_metrics = [ + ChunkMetric( + document_id=chunk.document_id, + chunk_content_start=chunk.content[:MAX_METRICS_CONTENT], + first_link=chunk.source_links[0] if chunk.source_links else None, + score=chunk.score if chunk.score is not None else 0, + ) + for chunk in ranked_chunks + ] + + rerank_metrics_callback( + RerankMetricsContainer( + metrics=chunk_metrics, raw_similarity_scores=ranked_raw_scores # type: ignore + ) + ) + + return list(ranked_chunks), list(ranked_indices) + + +def rerank_chunks( + query: SearchQuery, + chunks_to_rerank: list[InferenceChunk], + rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, +) -> list[InferenceChunk]: + ranked_chunks, _ = semantic_reranking( + query=query.query, + chunks=chunks_to_rerank[: query.num_rerank], + rerank_metrics_callback=rerank_metrics_callback, + ) + lower_chunks = chunks_to_rerank[query.num_rerank :] + # Scores from rerank cannot be meaningfully combined with scores without rerank + for lower_chunk in lower_chunks: + lower_chunk.score = None + ranked_chunks.extend(lower_chunks) + return ranked_chunks + + +@log_function_time(print_only=True) +def filter_chunks( + query: SearchQuery, + chunks_to_filter: list[InferenceChunk], +) -> list[str]: + """Filters chunks based on whether the LLM thought they were relevant to the query. 
+ + Returns a list of the unique chunk IDs that were marked as relevant""" + chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks] + llm_chunk_selection = llm_batch_eval_chunks( + query=query.query, + chunk_contents=[chunk.content for chunk in chunks_to_filter], + ) + return [ + chunk.unique_id + for ind, chunk in enumerate(chunks_to_filter) + if llm_chunk_selection[ind] + ] + + +def search_postprocessing( + search_query: SearchQuery, + retrieved_chunks: list[InferenceChunk], + rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, +) -> Generator[list[InferenceChunk] | list[str], None, None]: + post_processing_tasks: list[FunctionCall] = [] + + rerank_task_id = None + if should_rerank(search_query): + post_processing_tasks.append( + FunctionCall( + rerank_chunks, + ( + search_query, + retrieved_chunks, + rerank_metrics_callback, + ), + ) + ) + rerank_task_id = post_processing_tasks[-1].result_id + else: + final_chunks = retrieved_chunks + # NOTE: if we don't rerank, we can return the chunks immediately + # since we know this is the final order + _log_top_chunk_links(search_query.search_type.value, final_chunks) + yield final_chunks + chunks_yielded = True + + llm_filter_task_id = None + if should_apply_llm_based_relevance_filter(search_query): + post_processing_tasks.append( + FunctionCall( + filter_chunks, + (search_query, retrieved_chunks[: search_query.max_llm_filter_chunks]), + ) + ) + llm_filter_task_id = post_processing_tasks[-1].result_id + + post_processing_results = ( + run_functions_in_parallel(post_processing_tasks) + if post_processing_tasks + else {} + ) + reranked_chunks = cast( + list[InferenceChunk] | None, + post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None, + ) + if reranked_chunks: + if chunks_yielded: + logger.error( + "Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen." 
+ ) + else: + _log_top_chunk_links(search_query.search_type.value, reranked_chunks) + yield reranked_chunks + + llm_chunk_selection = cast( + list[str] | None, + post_processing_results.get(str(llm_filter_task_id)) + if llm_filter_task_id + else None, + ) + if llm_chunk_selection is not None: + yield [ + chunk.unique_id + for chunk in reranked_chunks or retrieved_chunks + if chunk.unique_id in llm_chunk_selection + ] + else: + yield [] diff --git a/backend/danswer/search/access_filters.py b/backend/danswer/search/preprocessing/access_filters.py similarity index 100% rename from backend/danswer/search/access_filters.py rename to backend/danswer/search/preprocessing/access_filters.py diff --git a/backend/danswer/search/danswer_helper.py b/backend/danswer/search/preprocessing/danswer_helper.py similarity index 96% rename from backend/danswer/search/danswer_helper.py rename to backend/danswer/search/preprocessing/danswer_helper.py index d5dbeb8a3..88e465dac 100644 --- a/backend/danswer/search/danswer_helper.py +++ b/backend/danswer/search/preprocessing/danswer_helper.py @@ -1,10 +1,10 @@ from typing import TYPE_CHECKING -from danswer.search.models import QueryFlow +from danswer.search.enums import QueryFlow from danswer.search.models import SearchType +from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation from danswer.search.search_nlp_models import get_default_tokenizer from danswer.search.search_nlp_models import IntentModel -from danswer.search.search_runner import remove_stop_words_and_punctuation from danswer.server.query_and_chat.models import HelperResponse from danswer.utils.logger import setup_logger diff --git a/backend/danswer/search/request_preprocessing.py b/backend/danswer/search/preprocessing/preprocessing.py similarity index 76% rename from backend/danswer/search/request_preprocessing.py rename to backend/danswer/search/preprocessing/preprocessing.py index e74618d39..f35afe438 100644 --- a/backend/danswer/search/request_preprocessing.py +++ b/backend/danswer/search/preprocessing/preprocessing.py @@ -5,19 +5,16 @@ from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER from danswer.configs.chat_configs import NUM_RETURNED_HITS -from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW -from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW -from danswer.db.models import Persona from danswer.db.models import User -from danswer.search.access_filters import build_access_filters_for_user -from danswer.search.danswer_helper import query_intent +from danswer.search.enums import QueryFlow +from danswer.search.enums import RecencyBiasSetting from danswer.search.models import BaseFilters from danswer.search.models import IndexFilters -from danswer.search.models import QueryFlow -from danswer.search.models import RecencyBiasSetting -from danswer.search.models import RetrievalDetails from danswer.search.models import SearchQuery +from danswer.search.models import SearchRequest from danswer.search.models import SearchType +from danswer.search.preprocessing.access_filters import build_access_filters_for_user +from danswer.search.preprocessing.danswer_helper import query_intent from danswer.secondary_llm_flows.source_filter import extract_source_filter from danswer.secondary_llm_flows.time_filter import extract_time_filter from danswer.utils.logger import setup_logger @@ -31,15 
+28,12 @@ logger = setup_logger() @log_function_time(print_only=True) def retrieval_preprocessing( - query: str, - retrieval_details: RetrievalDetails, - persona: Persona, + search_request: SearchRequest, user: User | None, db_session: Session, bypass_acl: bool = False, include_query_intent: bool = True, - skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW, - skip_rerank_non_realtime: bool = not ENABLE_RERANKING_ASYNC_FLOW, + enable_auto_detect_filters: bool = False, disable_llm_filter_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION, disable_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER, base_recency_decay: float = BASE_RECENCY_DECAY, @@ -50,8 +44,12 @@ def retrieval_preprocessing( Then any filters or settings as part of the query are used Then defaults to Persona settings if not specified by the query """ + query = search_request.query + limit = search_request.limit + offset = search_request.offset + persona = search_request.persona - preset_filters = retrieval_details.filters or BaseFilters() + preset_filters = search_request.human_selected_filters or BaseFilters() if persona and persona.document_sets and preset_filters.document_set is None: preset_filters.document_set = [ document_set.name for document_set in persona.document_sets @@ -65,16 +63,20 @@ def retrieval_preprocessing( if disable_llm_filter_extraction: auto_detect_time_filter = False auto_detect_source_filter = False - elif retrieval_details.enable_auto_detect_filters is False: + elif enable_auto_detect_filters is False: logger.debug("Retrieval details disables auto detect filters") auto_detect_time_filter = False auto_detect_source_filter = False - elif persona.llm_filter_extraction is False: + elif persona and persona.llm_filter_extraction is False: logger.debug("Persona disables auto detect filters") auto_detect_time_filter = False auto_detect_source_filter = False - if time_filter is not None and persona.recency_bias != RecencyBiasSetting.AUTO: + if ( + time_filter is not None + and persona + and persona.recency_bias != RecencyBiasSetting.AUTO + ): auto_detect_time_filter = False logger.debug("Not extract time filter - already provided") if source_filter is not None: @@ -138,24 +140,18 @@ def retrieval_preprocessing( access_control_list=user_acl_filters, ) - # Tranformer-based re-ranking to run at same time as LLM chunk relevance filter - # This one is only set globally, not via query or Persona settings - skip_reranking = ( - skip_rerank_realtime - if retrieval_details.real_time - else skip_rerank_non_realtime - ) - - llm_chunk_filter = persona.llm_relevance_filter + llm_chunk_filter = False + if persona: + llm_chunk_filter = persona.llm_relevance_filter if disable_llm_chunk_filter: llm_chunk_filter = False # Decays at 1 / (1 + (multiplier * num years)) - if persona.recency_bias == RecencyBiasSetting.NO_DECAY: + if persona and persona.recency_bias == RecencyBiasSetting.NO_DECAY: recency_bias_multiplier = 0.0 - elif persona.recency_bias == RecencyBiasSetting.BASE_DECAY: + elif persona and persona.recency_bias == RecencyBiasSetting.BASE_DECAY: recency_bias_multiplier = base_recency_decay - elif persona.recency_bias == RecencyBiasSetting.FAVOR_RECENT: + elif persona and persona.recency_bias == RecencyBiasSetting.FAVOR_RECENT: recency_bias_multiplier = base_recency_decay * favor_recent_decay_multiplier else: if predicted_favor_recent: @@ -166,14 +162,12 @@ def retrieval_preprocessing( return ( SearchQuery( query=query, - search_type=persona.search_type, + search_type=persona.search_type if persona else 
SearchType.HYBRID, filters=final_filters, recency_bias_multiplier=recency_bias_multiplier, - num_hits=retrieval_details.limit - if retrieval_details.limit is not None - else NUM_RETURNED_HITS, - offset=retrieval_details.offset or 0, - skip_rerank=skip_reranking, + num_hits=limit if limit is not None else NUM_RETURNED_HITS, + offset=offset or 0, + skip_rerank=search_request.skip_rerank, skip_llm_chunk_filter=not llm_chunk_filter, ), predicted_search_type, diff --git a/backend/danswer/search/retrieval/search_runner.py b/backend/danswer/search/retrieval/search_runner.py new file mode 100644 index 000000000..3dff76d96 --- /dev/null +++ b/backend/danswer/search/retrieval/search_runner.py @@ -0,0 +1,256 @@ +import string +from collections.abc import Callable + +from nltk.corpus import stopwords # type:ignore +from nltk.stem import WordNetLemmatizer # type:ignore +from nltk.tokenize import word_tokenize # type:ignore +from sqlalchemy.orm import Session + +from danswer.chat.models import LlmDoc +from danswer.configs.app_configs import MODEL_SERVER_HOST +from danswer.configs.app_configs import MODEL_SERVER_PORT +from danswer.configs.chat_configs import HYBRID_ALPHA +from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION +from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.document_index.interfaces import DocumentIndex +from danswer.indexing.models import InferenceChunk +from danswer.search.models import ChunkMetric +from danswer.search.models import IndexFilters +from danswer.search.models import MAX_METRICS_CONTENT +from danswer.search.models import RetrievalMetricsContainer +from danswer.search.models import SearchQuery +from danswer.search.models import SearchType +from danswer.search.search_nlp_models import EmbeddingModel +from danswer.search.search_nlp_models import EmbedTextType +from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion +from danswer.utils.logger import setup_logger +from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel +from danswer.utils.timing import log_function_time + + +logger = setup_logger() + + +def lemmatize_text(text: str) -> list[str]: + lemmatizer = WordNetLemmatizer() + word_tokens = word_tokenize(text) + return [lemmatizer.lemmatize(word) for word in word_tokens] + + +def remove_stop_words_and_punctuation(text: str) -> list[str]: + stop_words = set(stopwords.words("english")) + word_tokens = word_tokenize(text) + text_trimmed = [ + word + for word in word_tokens + if (word.casefold() not in stop_words and word not in string.punctuation) + ] + return text_trimmed or word_tokens + + +def query_processing( + query: str, +) -> str: + query = " ".join(remove_stop_words_and_punctuation(query)) + query = " ".join(lemmatize_text(query)) + return query + + +def combine_retrieval_results( + chunk_sets: list[list[InferenceChunk]], +) -> list[InferenceChunk]: + all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set] + + unique_chunks: dict[tuple[str, int], InferenceChunk] = {} + for chunk in all_chunks: + key = (chunk.document_id, chunk.chunk_id) + if key not in unique_chunks: + unique_chunks[key] = chunk + continue + + stored_chunk_score = unique_chunks[key].score or 0 + this_chunk_score = chunk.score or 0 + if stored_chunk_score < this_chunk_score: + unique_chunks[key] = chunk + + sorted_chunks = sorted( + unique_chunks.values(), key=lambda x: x.score or 0, reverse=True + ) + + return sorted_chunks + + +@log_function_time(print_only=True) +def 
doc_index_retrieval( + query: SearchQuery, + document_index: DocumentIndex, + db_session: Session, + hybrid_alpha: float = HYBRID_ALPHA, +) -> list[InferenceChunk]: + if query.search_type == SearchType.KEYWORD: + top_chunks = document_index.keyword_retrieval( + query=query.query, + filters=query.filters, + time_decay_multiplier=query.recency_bias_multiplier, + num_to_retrieve=query.num_hits, + ) + else: + db_embedding_model = get_current_db_embedding_model(db_session) + + model = EmbeddingModel( + model_name=db_embedding_model.model_name, + query_prefix=db_embedding_model.query_prefix, + passage_prefix=db_embedding_model.passage_prefix, + normalize=db_embedding_model.normalize, + # The below are globally set, this flow always uses the indexing one + server_host=MODEL_SERVER_HOST, + server_port=MODEL_SERVER_PORT, + ) + + query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0] + + if query.search_type == SearchType.SEMANTIC: + top_chunks = document_index.semantic_retrieval( + query=query.query, + query_embedding=query_embedding, + filters=query.filters, + time_decay_multiplier=query.recency_bias_multiplier, + num_to_retrieve=query.num_hits, + ) + + elif query.search_type == SearchType.HYBRID: + top_chunks = document_index.hybrid_retrieval( + query=query.query, + query_embedding=query_embedding, + filters=query.filters, + time_decay_multiplier=query.recency_bias_multiplier, + num_to_retrieve=query.num_hits, + offset=query.offset, + hybrid_alpha=hybrid_alpha, + ) + + else: + raise RuntimeError("Invalid Search Flow") + + return top_chunks + + +def _simplify_text(text: str) -> str: + return "".join( + char for char in text if char not in string.punctuation and not char.isspace() + ).lower() + + +def retrieve_chunks( + query: SearchQuery, + document_index: DocumentIndex, + db_session: Session, + hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search + multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION, + retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] + | None = None, +) -> list[InferenceChunk]: + """Returns a list of the best chunks from an initial keyword/semantic/ hybrid search.""" + # Don't do query expansion on complex queries, rephrasings likely would not work well + if not multilingual_expansion_str or "\n" in query.query or "\r" in query.query: + top_chunks = doc_index_retrieval( + query=query, + document_index=document_index, + db_session=db_session, + hybrid_alpha=hybrid_alpha, + ) + else: + simplified_queries = set() + run_queries: list[tuple[Callable, tuple]] = [] + + # Currently only uses query expansion on multilingual use cases + query_rephrases = multilingual_query_expansion( + query.query, multilingual_expansion_str + ) + # Just to be extra sure, add the original query. 
+ query_rephrases.append(query.query) + for rephrase in set(query_rephrases): + # Sometimes the model rephrases the query in the same language with minor changes + # Avoid doing an extra search with the minor changes as this biases the results + simplified_rephrase = _simplify_text(rephrase) + if simplified_rephrase in simplified_queries: + continue + simplified_queries.add(simplified_rephrase) + + q_copy = query.copy(update={"query": rephrase}, deep=True) + run_queries.append( + ( + doc_index_retrieval, + (q_copy, document_index, db_session, hybrid_alpha), + ) + ) + parallel_search_results = run_functions_tuples_in_parallel(run_queries) + top_chunks = combine_retrieval_results(parallel_search_results) + + if not top_chunks: + logger.info( + f"{query.search_type.value.capitalize()} search returned no results " + f"with filters: {query.filters}" + ) + return [] + + if retrieval_metrics_callback is not None: + chunk_metrics = [ + ChunkMetric( + document_id=chunk.document_id, + chunk_content_start=chunk.content[:MAX_METRICS_CONTENT], + first_link=chunk.source_links[0] if chunk.source_links else None, + score=chunk.score if chunk.score is not None else 0, + ) + for chunk in top_chunks + ] + retrieval_metrics_callback( + RetrievalMetricsContainer( + search_type=query.search_type, metrics=chunk_metrics + ) + ) + + return top_chunks + + +def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc: + if not inf_chunks: + raise ValueError("Cannot combine empty list of chunks") + + # Use the first link of the document + first_chunk = inf_chunks[0] + chunk_texts = [chunk.content for chunk in inf_chunks] + return LlmDoc( + document_id=first_chunk.document_id, + content="\n".join(chunk_texts), + semantic_identifier=first_chunk.semantic_identifier, + source_type=first_chunk.source_type, + metadata=first_chunk.metadata, + updated_at=first_chunk.updated_at, + link=first_chunk.source_links[0] if first_chunk.source_links else None, + ) + + +def inference_documents_from_ids( + doc_identifiers: list[tuple[str, int]], + document_index: DocumentIndex, +) -> list[LlmDoc]: + # Currently only fetches whole docs + doc_ids_set = set(doc_id for doc_id, chunk_id in doc_identifiers) + + # No need for ACL here because the doc ids were validated beforehand + filters = IndexFilters(access_control_list=None) + + functions_with_args: list[tuple[Callable, tuple]] = [ + (document_index.id_based_retrieval, (doc_id, None, filters)) + for doc_id in doc_ids_set + ] + + parallel_results = run_functions_tuples_in_parallel( + functions_with_args, allow_failures=True + ) + + # Any failures to retrieve would give a None, drop the Nones and empty lists + inference_chunks_sets = [res for res in parallel_results if res] + + return [combine_inference_chunks(chunk_set) for chunk_set in inference_chunks_sets] diff --git a/backend/danswer/search/search_runner.py b/backend/danswer/search/search_runner.py deleted file mode 100644 index 18bfa1a3c..000000000 --- a/backend/danswer/search/search_runner.py +++ /dev/null @@ -1,645 +0,0 @@ -import string -from collections.abc import Callable -from collections.abc import Iterator -from typing import cast - -import numpy -from nltk.corpus import stopwords # type:ignore -from nltk.stem import WordNetLemmatizer # type:ignore -from nltk.tokenize import word_tokenize # type:ignore -from sqlalchemy.orm import Session - -from danswer.chat.models import LlmDoc -from danswer.configs.app_configs import MODEL_SERVER_HOST -from danswer.configs.app_configs import MODEL_SERVER_PORT -from 
danswer.configs.chat_configs import HYBRID_ALPHA -from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION -from danswer.configs.chat_configs import NUM_RERANKED_RESULTS -from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX -from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN -from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH -from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW -from danswer.db.embedding_model import get_current_db_embedding_model -from danswer.document_index.document_index_utils import ( - translate_boost_count_to_multiplier, -) -from danswer.document_index.interfaces import DocumentIndex -from danswer.indexing.models import InferenceChunk -from danswer.search.models import ChunkMetric -from danswer.search.models import IndexFilters -from danswer.search.models import MAX_METRICS_CONTENT -from danswer.search.models import RerankMetricsContainer -from danswer.search.models import RetrievalMetricsContainer -from danswer.search.models import SearchDoc -from danswer.search.models import SearchQuery -from danswer.search.models import SearchType -from danswer.search.search_nlp_models import CrossEncoderEnsembleModel -from danswer.search.search_nlp_models import EmbeddingModel -from danswer.search.search_nlp_models import EmbedTextType -from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks -from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion -from danswer.utils.logger import setup_logger -from danswer.utils.threadpool_concurrency import FunctionCall -from danswer.utils.threadpool_concurrency import run_functions_in_parallel -from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel -from danswer.utils.timing import log_function_time - - -logger = setup_logger() - - -def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None: - top_links = [ - c.source_links[0] if c.source_links is not None else "No Link" for c in chunks - ] - logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}") - - -def lemmatize_text(text: str) -> list[str]: - lemmatizer = WordNetLemmatizer() - word_tokens = word_tokenize(text) - return [lemmatizer.lemmatize(word) for word in word_tokens] - - -def remove_stop_words_and_punctuation(text: str) -> list[str]: - stop_words = set(stopwords.words("english")) - word_tokens = word_tokenize(text) - text_trimmed = [ - word - for word in word_tokens - if (word.casefold() not in stop_words and word not in string.punctuation) - ] - return text_trimmed or word_tokens - - -def query_processing( - query: str, -) -> str: - query = " ".join(remove_stop_words_and_punctuation(query)) - query = " ".join(lemmatize_text(query)) - return query - - -def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]: - search_docs = ( - [ - SearchDoc( - document_id=chunk.document_id, - chunk_ind=chunk.chunk_id, - semantic_identifier=chunk.semantic_identifier or "Unknown", - link=chunk.source_links.get(0) if chunk.source_links else None, - blurb=chunk.blurb, - source_type=chunk.source_type, - boost=chunk.boost, - hidden=chunk.hidden, - metadata=chunk.metadata, - score=chunk.score, - match_highlights=chunk.match_highlights, - updated_at=chunk.updated_at, - primary_owners=chunk.primary_owners, - secondary_owners=chunk.secondary_owners, - ) - for chunk in chunks - ] - if chunks - else [] - ) - return search_docs - - -def combine_retrieval_results( - chunk_sets: list[list[InferenceChunk]], -) 
-> list[InferenceChunk]: - all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set] - - unique_chunks: dict[tuple[str, int], InferenceChunk] = {} - for chunk in all_chunks: - key = (chunk.document_id, chunk.chunk_id) - if key not in unique_chunks: - unique_chunks[key] = chunk - continue - - stored_chunk_score = unique_chunks[key].score or 0 - this_chunk_score = chunk.score or 0 - if stored_chunk_score < this_chunk_score: - unique_chunks[key] = chunk - - sorted_chunks = sorted( - unique_chunks.values(), key=lambda x: x.score or 0, reverse=True - ) - - return sorted_chunks - - -@log_function_time(print_only=True) -def doc_index_retrieval( - query: SearchQuery, - document_index: DocumentIndex, - db_session: Session, - hybrid_alpha: float = HYBRID_ALPHA, -) -> list[InferenceChunk]: - if query.search_type == SearchType.KEYWORD: - top_chunks = document_index.keyword_retrieval( - query=query.query, - filters=query.filters, - time_decay_multiplier=query.recency_bias_multiplier, - num_to_retrieve=query.num_hits, - ) - else: - db_embedding_model = get_current_db_embedding_model(db_session) - - model = EmbeddingModel( - model_name=db_embedding_model.model_name, - query_prefix=db_embedding_model.query_prefix, - passage_prefix=db_embedding_model.passage_prefix, - normalize=db_embedding_model.normalize, - # The below are globally set, this flow always uses the indexing one - server_host=MODEL_SERVER_HOST, - server_port=MODEL_SERVER_PORT, - ) - - query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0] - - if query.search_type == SearchType.SEMANTIC: - top_chunks = document_index.semantic_retrieval( - query=query.query, - query_embedding=query_embedding, - filters=query.filters, - time_decay_multiplier=query.recency_bias_multiplier, - num_to_retrieve=query.num_hits, - ) - - elif query.search_type == SearchType.HYBRID: - top_chunks = document_index.hybrid_retrieval( - query=query.query, - query_embedding=query_embedding, - filters=query.filters, - time_decay_multiplier=query.recency_bias_multiplier, - num_to_retrieve=query.num_hits, - offset=query.offset, - hybrid_alpha=hybrid_alpha, - ) - - else: - raise RuntimeError("Invalid Search Flow") - - return top_chunks - - -@log_function_time(print_only=True) -def semantic_reranking( - query: str, - chunks: list[InferenceChunk], - rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, - model_min: int = CROSS_ENCODER_RANGE_MIN, - model_max: int = CROSS_ENCODER_RANGE_MAX, -) -> tuple[list[InferenceChunk], list[int]]: - """Reranks chunks based on cross-encoder models. Additionally provides the original indices - of the chunks in their new sorted order. 
-
-    Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
-    """
-    cross_encoders = CrossEncoderEnsembleModel()
-    passages = [chunk.content for chunk in chunks]
-    sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
-
-    sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
-
-    raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores))
-
-    cross_models_min = numpy.min(sim_scores)
-
-    shifted_sim_scores = sum(
-        [enc_n_scores - cross_models_min for enc_n_scores in sim_scores]
-    ) / len(sim_scores)
-
-    boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
-    recency_multiplier = [chunk.recency_bias for chunk in chunks]
-    boosted_sim_scores = shifted_sim_scores * boosts * recency_multiplier
-    normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / (
-        model_max - model_min
-    )
-    orig_indices = [i for i in range(len(normalized_b_s_scores))]
-    scored_results = list(
-        zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices)
-    )
-    scored_results.sort(key=lambda x: x[0], reverse=True)
-    ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip(
-        *scored_results
-    )
-
-    logger.debug(
-        f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}"
-    )
-
-    # Assign new chunk scores based on reranking
-    for ind, chunk in enumerate(ranked_chunks):
-        chunk.score = ranked_sim_scores[ind]
-
-    if rerank_metrics_callback is not None:
-        chunk_metrics = [
-            ChunkMetric(
-                document_id=chunk.document_id,
-                chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
-                first_link=chunk.source_links[0] if chunk.source_links else None,
-                score=chunk.score if chunk.score is not None else 0,
-            )
-            for chunk in ranked_chunks
-        ]
-
-        rerank_metrics_callback(
-            RerankMetricsContainer(
-                metrics=chunk_metrics, raw_similarity_scores=ranked_raw_scores  # type: ignore
-            )
-        )
-
-    return list(ranked_chunks), list(ranked_indices)
-
-
-def apply_boost_legacy(
-    chunks: list[InferenceChunk],
-    norm_min: float = SIM_SCORE_RANGE_LOW,
-    norm_max: float = SIM_SCORE_RANGE_HIGH,
-) -> list[InferenceChunk]:
-    scores = [chunk.score or 0 for chunk in chunks]
-    boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
-
-    logger.debug(f"Raw similarity scores: {scores}")
-
-    score_min = min(scores)
-    score_max = max(scores)
-    score_range = score_max - score_min
-
-    if score_range != 0:
-        boosted_scores = [
-            ((score - score_min) / score_range) * boost
-            for score, boost in zip(scores, boosts)
-        ]
-        unnormed_boosted_scores = [
-            score * score_range + score_min for score in boosted_scores
-        ]
-    else:
-        unnormed_boosted_scores = [
-            score * boost for score, boost in zip(scores, boosts)
-        ]
-
-    norm_min = min(norm_min, min(scores))
-    norm_max = max(norm_max, max(scores))
-    # This should never be 0 unless user has done some weird/wrong settings
-    norm_range = norm_max - norm_min
-
-    # For score display purposes
-    if norm_range != 0:
-        re_normed_scores = [
-            ((score - norm_min) / norm_range) for score in unnormed_boosted_scores
-        ]
-    else:
-        re_normed_scores = unnormed_boosted_scores
-
-    rescored_chunks = list(zip(re_normed_scores, chunks))
-    rescored_chunks.sort(key=lambda x: x[0], reverse=True)
-    sorted_boosted_scores, boost_sorted_chunks = zip(*rescored_chunks)
-
-    final_chunks = list(boost_sorted_chunks)
-    final_scores = list(sorted_boosted_scores)
-    for ind, chunk in enumerate(final_chunks):
-        chunk.score = final_scores[ind]
-
-    logger.debug(f"Boost sorted similary scores: {list(final_scores)}")
-
-    return final_chunks
-
-
-def apply_boost(
-    chunks: list[InferenceChunk],
-    # Need the range of values to not be too spread out for applying boost
-    # therefore norm across only the top few results
-    norm_cutoff: int = NUM_RERANKED_RESULTS,
-    norm_min: float = SIM_SCORE_RANGE_LOW,
-    norm_max: float = SIM_SCORE_RANGE_HIGH,
-) -> list[InferenceChunk]:
-    scores = [chunk.score or 0.0 for chunk in chunks]
-    logger.debug(f"Raw similarity scores: {scores}")
-
-    boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
-    recency_multiplier = [chunk.recency_bias for chunk in chunks]
-
-    norm_min = min(norm_min, min(scores[:norm_cutoff]))
-    norm_max = max(norm_max, max(scores[:norm_cutoff]))
-    # This should never be 0 unless user has done some weird/wrong settings
-    norm_range = norm_max - norm_min
-
-    boosted_scores = [
-        max(0, (score - norm_min) * boost * recency / norm_range)
-        for score, boost, recency in zip(scores, boosts, recency_multiplier)
-    ]
-
-    rescored_chunks = list(zip(boosted_scores, chunks))
-    rescored_chunks.sort(key=lambda x: x[0], reverse=True)
-    sorted_boosted_scores, boost_sorted_chunks = zip(*rescored_chunks)
-
-    final_chunks = list(boost_sorted_chunks)
-    final_scores = list(sorted_boosted_scores)
-    for ind, chunk in enumerate(final_chunks):
-        chunk.score = final_scores[ind]
-
-    logger.debug(
-        f"Boosted + Time Weighted sorted similarity scores: {list(final_scores)}"
-    )
-
-    return final_chunks
-
-
-def _simplify_text(text: str) -> str:
-    return "".join(
-        char for char in text if char not in string.punctuation and not char.isspace()
-    ).lower()
-
-
-def retrieve_chunks(
-    query: SearchQuery,
-    document_index: DocumentIndex,
-    db_session: Session,
-    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
-    multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
-    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
-    | None = None,
-) -> list[InferenceChunk]:
-    """Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
-    # Don't do query expansion on complex queries, rephrasings likely would not work well
-    if not multilingual_expansion_str or "\n" in query.query or "\r" in query.query:
-        top_chunks = doc_index_retrieval(
-            query=query,
-            document_index=document_index,
-            db_session=db_session,
-            hybrid_alpha=hybrid_alpha,
-        )
-    else:
-        simplified_queries = set()
-        run_queries: list[tuple[Callable, tuple]] = []
-
-        # Currently only uses query expansion on multilingual use cases
-        query_rephrases = multilingual_query_expansion(
-            query.query, multilingual_expansion_str
-        )
-        # Just to be extra sure, add the original query.
-        query_rephrases.append(query.query)
-        for rephrase in set(query_rephrases):
-            # Sometimes the model rephrases the query in the same language with minor changes
-            # Avoid doing an extra search with the minor changes as this biases the results
-            simplified_rephrase = _simplify_text(rephrase)
-            if simplified_rephrase in simplified_queries:
-                continue
-            simplified_queries.add(simplified_rephrase)
-
-            q_copy = query.copy(update={"query": rephrase}, deep=True)
-            run_queries.append(
-                (
-                    doc_index_retrieval,
-                    (q_copy, document_index, db_session, hybrid_alpha),
-                )
-            )
-        parallel_search_results = run_functions_tuples_in_parallel(run_queries)
-        top_chunks = combine_retrieval_results(parallel_search_results)
-
-    if not top_chunks:
-        logger.info(
-            f"{query.search_type.value.capitalize()} search returned no results "
-            f"with filters: {query.filters}"
-        )
-        return []
-
-    if retrieval_metrics_callback is not None:
-        chunk_metrics = [
-            ChunkMetric(
-                document_id=chunk.document_id,
-                chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
-                first_link=chunk.source_links[0] if chunk.source_links else None,
-                score=chunk.score if chunk.score is not None else 0,
-            )
-            for chunk in top_chunks
-        ]
-        retrieval_metrics_callback(
-            RetrievalMetricsContainer(
-                search_type=query.search_type, metrics=chunk_metrics
-            )
-        )
-
-    return top_chunks
-
-
-def should_rerank(query: SearchQuery) -> bool:
-    # Don't re-rank for keyword search
-    return query.search_type != SearchType.KEYWORD and not query.skip_rerank
-
-
-def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool:
-    return not query.skip_llm_chunk_filter
-
-
-def rerank_chunks(
-    query: SearchQuery,
-    chunks_to_rerank: list[InferenceChunk],
-    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
-) -> list[InferenceChunk]:
-    ranked_chunks, _ = semantic_reranking(
-        query=query.query,
-        chunks=chunks_to_rerank[: query.num_rerank],
-        rerank_metrics_callback=rerank_metrics_callback,
-    )
-    lower_chunks = chunks_to_rerank[query.num_rerank :]
-    # Scores from rerank cannot be meaningfully combined with scores without rerank
-    for lower_chunk in lower_chunks:
-        lower_chunk.score = None
-    ranked_chunks.extend(lower_chunks)
-    return ranked_chunks
-
-
-@log_function_time(print_only=True)
-def filter_chunks(
-    query: SearchQuery,
-    chunks_to_filter: list[InferenceChunk],
-) -> list[str]:
-    """Filters chunks based on whether the LLM thought they were relevant to the query.
-
-    Returns a list of the unique chunk IDs that were marked as relevant"""
-    chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks]
-    llm_chunk_selection = llm_batch_eval_chunks(
-        query=query.query,
-        chunk_contents=[chunk.content for chunk in chunks_to_filter],
-    )
-    return [
-        chunk.unique_id
-        for ind, chunk in enumerate(chunks_to_filter)
-        if llm_chunk_selection[ind]
-    ]
-
-
-def full_chunk_search(
-    query: SearchQuery,
-    document_index: DocumentIndex,
-    db_session: Session,
-    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
-    multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
-    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
-    | None = None,
-    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
-) -> tuple[list[InferenceChunk], list[bool]]:
-    """A utility which provides an easier interface than `full_chunk_search_generator`.
-    Rather than returning the chunks and llm relevance filter results in two separate
-    yields, just returns them both at once."""
-    search_generator = full_chunk_search_generator(
-        search_query=query,
-        document_index=document_index,
-        db_session=db_session,
-        hybrid_alpha=hybrid_alpha,
-        multilingual_expansion_str=multilingual_expansion_str,
-        retrieval_metrics_callback=retrieval_metrics_callback,
-        rerank_metrics_callback=rerank_metrics_callback,
-    )
-    top_chunks = cast(list[InferenceChunk], next(search_generator))
-    llm_chunk_selection = cast(list[bool], next(search_generator))
-    return top_chunks, llm_chunk_selection
-
-
-def empty_search_generator() -> Iterator[list[InferenceChunk] | list[bool]]:
-    yield cast(list[InferenceChunk], [])
-    yield cast(list[bool], [])
-
-
-def full_chunk_search_generator(
-    search_query: SearchQuery,
-    document_index: DocumentIndex,
-    db_session: Session,
-    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
-    multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
-    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
-    | None = None,
-    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
-) -> Iterator[list[InferenceChunk] | list[bool]]:
-    """Always yields twice. Once with the selected chunks and once with the LLM relevance filter result.
-    If LLM filter results are turned off, returns a list of False
-    """
-    chunks_yielded = False
-
-    retrieved_chunks = retrieve_chunks(
-        query=search_query,
-        document_index=document_index,
-        db_session=db_session,
-        hybrid_alpha=hybrid_alpha,
-        multilingual_expansion_str=multilingual_expansion_str,
-        retrieval_metrics_callback=retrieval_metrics_callback,
-    )
-
-    if not retrieved_chunks:
-        yield cast(list[InferenceChunk], [])
-        yield cast(list[bool], [])
-        return
-
-    post_processing_tasks: list[FunctionCall] = []
-
-    rerank_task_id = None
-    if should_rerank(search_query):
-        post_processing_tasks.append(
-            FunctionCall(
-                rerank_chunks,
-                (
-                    search_query,
-                    retrieved_chunks,
-                    rerank_metrics_callback,
-                ),
-            )
-        )
-        rerank_task_id = post_processing_tasks[-1].result_id
-    else:
-        final_chunks = retrieved_chunks
-        # NOTE: if we don't rerank, we can return the chunks immediately
-        # since we know this is the final order
-        _log_top_chunk_links(search_query.search_type.value, final_chunks)
-        yield final_chunks
-        chunks_yielded = True
-
-    llm_filter_task_id = None
-    if should_apply_llm_based_relevance_filter(search_query):
-        post_processing_tasks.append(
-            FunctionCall(
-                filter_chunks,
-                (search_query, retrieved_chunks[: search_query.max_llm_filter_chunks]),
-            )
-        )
-        llm_filter_task_id = post_processing_tasks[-1].result_id
-
-    post_processing_results = (
-        run_functions_in_parallel(post_processing_tasks)
-        if post_processing_tasks
-        else {}
-    )
-    reranked_chunks = cast(
-        list[InferenceChunk] | None,
-        post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None,
-    )
-    if reranked_chunks:
-        if chunks_yielded:
-            logger.error(
-                "Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen."
-            )
-        else:
-            _log_top_chunk_links(search_query.search_type.value, reranked_chunks)
-            yield reranked_chunks
-
-    llm_chunk_selection = cast(
-        list[str] | None,
-        post_processing_results.get(str(llm_filter_task_id))
-        if llm_filter_task_id
-        else None,
-    )
-    if llm_chunk_selection is not None:
-        yield [
-            chunk.unique_id in llm_chunk_selection
-            for chunk in reranked_chunks or retrieved_chunks
-        ]
-    else:
-        yield [False for _ in reranked_chunks or retrieved_chunks]
-
-
-def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc:
-    if not inf_chunks:
-        raise ValueError("Cannot combine empty list of chunks")
-
-    # Use the first link of the document
-    first_chunk = inf_chunks[0]
-    chunk_texts = [chunk.content for chunk in inf_chunks]
-    return LlmDoc(
-        document_id=first_chunk.document_id,
-        content="\n".join(chunk_texts),
-        semantic_identifier=first_chunk.semantic_identifier,
-        source_type=first_chunk.source_type,
-        metadata=first_chunk.metadata,
-        updated_at=first_chunk.updated_at,
-        link=first_chunk.source_links[0] if first_chunk.source_links else None,
-    )
-
-
-def inference_documents_from_ids(
-    doc_identifiers: list[tuple[str, int]],
-    document_index: DocumentIndex,
-) -> list[LlmDoc]:
-    # Currently only fetches whole docs
-    doc_ids_set = set(doc_id for doc_id, chunk_id in doc_identifiers)
-
-    # No need for ACL here because the doc ids were validated beforehand
-    filters = IndexFilters(access_control_list=None)
-
-    functions_with_args: list[tuple[Callable, tuple]] = [
-        (document_index.id_based_retrieval, (doc_id, None, filters))
-        for doc_id in doc_ids_set
-    ]
-
-    parallel_results = run_functions_tuples_in_parallel(
-        functions_with_args, allow_failures=True
-    )
-
-    # Any failures to retrieve would give a None, drop the Nones and empty lists
-    inference_chunks_sets = [res for res in parallel_results if res]
-
-    return [combine_inference_chunks(chunk_set) for chunk_set in inference_chunks_sets]
diff --git a/backend/danswer/search/utils.py b/backend/danswer/search/utils.py
new file mode 100644
index 000000000..4b01f70eb
--- /dev/null
+++ b/backend/danswer/search/utils.py
@@ -0,0 +1,29 @@
+from danswer.indexing.models import InferenceChunk
+from danswer.search.models import SearchDoc
+
+
+def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
+    search_docs = (
+        [
+            SearchDoc(
+                document_id=chunk.document_id,
+                chunk_ind=chunk.chunk_id,
+                semantic_identifier=chunk.semantic_identifier or "Unknown",
+                link=chunk.source_links.get(0) if chunk.source_links else None,
+                blurb=chunk.blurb,
+                source_type=chunk.source_type,
+                boost=chunk.boost,
+                hidden=chunk.hidden,
+                metadata=chunk.metadata,
+                score=chunk.score,
+                match_highlights=chunk.match_highlights,
+                updated_at=chunk.updated_at,
+                primary_owners=chunk.primary_owners,
+                secondary_owners=chunk.secondary_owners,
+            )
+            for chunk in chunks
+        ]
+        if chunks
+        else []
+    )
+    return search_docs
diff --git a/backend/danswer/server/documents/document.py b/backend/danswer/server/documents/document.py
index ea080b033..3abab3302 100644
--- a/backend/danswer/server/documents/document.py
+++ b/backend/danswer/server/documents/document.py
@@ -11,8 +11,8 @@
 from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
 from danswer.llm.utils import get_default_llm_token_encode
 from danswer.prompts.prompt_utils import build_doc_context_str
-from danswer.search.access_filters import build_access_filters_for_user
 from danswer.search.models import IndexFilters
+from danswer.search.preprocessing.access_filters import build_access_filters_for_user
 from danswer.server.documents.models import ChunkInfo
 from danswer.server.documents.models import DocumentInfo
diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py
index a724ac5f3..4cc80eec0 100644
--- a/backend/danswer/server/features/persona/models.py
+++ b/backend/danswer/server/features/persona/models.py
@@ -4,7 +4,7 @@
 from pydantic import BaseModel
 
 from danswer.db.models import Persona
 from danswer.db.models import StarterMessage
-from danswer.search.models import RecencyBiasSetting
+from danswer.search.enums import RecencyBiasSetting
 from danswer.server.features.document_set.models import DocumentSet
 from danswer.server.features.prompt.models import PromptSnapshot
diff --git a/backend/danswer/server/gpts/api.py b/backend/danswer/server/gpts/api.py
index 980003252..bfada9b55 100644
--- a/backend/danswer/server/gpts/api.py
+++ b/backend/danswer/server/gpts/api.py
@@ -6,13 +6,9 @@
 from fastapi import Depends
 from pydantic import BaseModel
 from sqlalchemy.orm import Session
 
-from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.engine import get_session
-from danswer.document_index.factory import get_default_document_index
-from danswer.search.access_filters import build_access_filters_for_user
-from danswer.search.models import IndexFilters
-from danswer.search.models import SearchQuery
-from danswer.search.search_runner import full_chunk_search
+from danswer.search.models import SearchRequest
+from danswer.search.pipeline import SearchPipeline
 from danswer.server.danswer_api.ingestion import api_key_dep
 from danswer.utils.logger import setup_logger
@@ -70,27 +66,13 @@ def gpt_search(
     _: str | None = Depends(api_key_dep),
     db_session: Session = Depends(get_session),
 ) -> GptSearchResponse:
-    query = search_request.query
-
-    user_acl_filters = build_access_filters_for_user(None, db_session)
-    final_filters = IndexFilters(access_control_list=user_acl_filters)
-
-    search_query = SearchQuery(
-        query=query,
-        filters=final_filters,
-        recency_bias_multiplier=1.0,
-        skip_llm_chunk_filter=True,
-    )
-
-    embedding_model = get_current_db_embedding_model(db_session)
-
-    document_index = get_default_document_index(
-        primary_index_name=embedding_model.index_name, secondary_index_name=None
-    )
-
-    top_chunks, __ = full_chunk_search(
-        query=search_query, document_index=document_index, db_session=db_session
-    )
+    top_chunks = SearchPipeline(
+        search_request=SearchRequest(
+            query=search_request.query,
+        ),
+        user=None,
+        db_session=db_session,
+    ).reranked_docs
 
     return GptSearchResponse(
         matching_document_chunks=[
diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py
index 6d8529486..5150eb9ce 100644
--- a/backend/danswer/server/query_and_chat/query_backend.py
+++ b/backend/danswer/server/query_and_chat/query_backend.py
@@ -15,11 +15,11 @@ from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.vespa.index import VespaIndex
 from danswer.one_shot_answer.answer_question import stream_search_answer
 from danswer.one_shot_answer.models import DirectQARequest
-from danswer.search.access_filters import build_access_filters_for_user
-from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.models import IndexFilters
 from danswer.search.models import SearchDoc
-from danswer.search.search_runner import chunks_to_search_docs
+from danswer.search.preprocessing.access_filters import build_access_filters_for_user
+from danswer.search.preprocessing.danswer_helper import recommend_search_flow
+from danswer.search.utils import chunks_to_search_docs
 from danswer.secondary_llm_flows.query_validation import get_query_answerability
 from danswer.secondary_llm_flows.query_validation import stream_query_answerability
 from danswer.server.query_and_chat.models import AdminSearchRequest
diff --git a/backend/tests/regression/search_quality/eval_search.py b/backend/tests/regression/search_quality/eval_search.py
index 7cd3e6068..d40ae1348 100644
--- a/backend/tests/regression/search_quality/eval_search.py
+++ b/backend/tests/regression/search_quality/eval_search.py
@@ -8,15 +8,12 @@
 from typing import TextIO
 
 from sqlalchemy.orm import Session
 
 from danswer.chat.chat_utils import get_chunks_for_qa
-from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.engine import get_sqlalchemy_engine
-from danswer.document_index.factory import get_default_document_index
 from danswer.indexing.models import InferenceChunk
-from danswer.search.models import IndexFilters
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
-from danswer.search.models import SearchQuery
-from danswer.search.search_runner import full_chunk_search
+from danswer.search.models import SearchRequest
+from danswer.search.pipeline import SearchPipeline
 from danswer.utils.callbacks import MetricsHander
@@ -81,35 +78,22 @@
     RetrievalMetricsContainer | None,
     RerankMetricsContainer | None,
 ]:
-    filters = IndexFilters(
-        source_type=None,
-        document_set=None,
-        time_cutoff=None,
-        access_control_list=None,
-    )
-    search_query = SearchQuery(
-        query=query,
-        filters=filters,
-        recency_bias_multiplier=1.0,
-    )
-
     retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
     rerank_metrics = MetricsHander[RerankMetricsContainer]()
 
     with Session(get_sqlalchemy_engine()) as db_session:
-        embedding_model = get_current_db_embedding_model(db_session)
+        search_pipeline = SearchPipeline(
+            search_request=SearchRequest(
+                query=query,
+            ),
+            user=None,
+            db_session=db_session,
+            retrieval_metrics_callback=retrieval_metrics.record_metric,
+            rerank_metrics_callback=rerank_metrics.record_metric,
+        )
-        document_index = get_default_document_index(
-            primary_index_name=embedding_model.index_name, secondary_index_name=None
-        )
-
-        top_chunks, llm_chunk_selection = full_chunk_search(
-            query=search_query,
-            document_index=document_index,
-            db_session=db_session,
-            retrieval_metrics_callback=retrieval_metrics.record_metric,
-            rerank_metrics_callback=rerank_metrics.record_metric,
-        )
+        top_chunks = search_pipeline.reranked_docs
+        llm_chunk_selection = search_pipeline.chunk_relevance_list
 
         llm_chunks_indices = get_chunks_for_qa(
             chunks=top_chunks,