Refactor search pipeline
@@ -9,7 +9,7 @@ from alembic import op
 import sqlalchemy as sa
 
 from danswer.db.models import IndexModelStatus
-from danswer.search.models import RecencyBiasSetting
+from danswer.search.enums import RecencyBiasSetting
 from danswer.search.models import SearchType
 
 # revision identifiers, used by Alembic.
@@ -13,7 +13,7 @@ from danswer.db.document_set import get_or_create_document_set_by_name
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.models import DocumentSet as DocumentSetDBModel
 from danswer.db.models import Prompt as PromptDBModel
-from danswer.search.models import RecencyBiasSetting
+from danswer.search.enums import RecencyBiasSetting
 
 
 def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
@@ -5,10 +5,10 @@ from typing import Any
 from pydantic import BaseModel
 
 from danswer.configs.constants import DocumentSource
-from danswer.search.models import QueryFlow
+from danswer.search.enums import QueryFlow
+from danswer.search.enums import SearchType
 from danswer.search.models import RetrievalDocs
 from danswer.search.models import SearchResponse
-from danswer.search.models import SearchType
 
 
 class LlmDoc(BaseModel):
@@ -53,11 +53,10 @@ from danswer.llm.utils import tokenizer_trim_content
 from danswer.llm.utils import translate_history_to_basemessages
 from danswer.prompts.prompt_utils import build_doc_context_str
 from danswer.search.models import OptionalSearchSetting
-from danswer.search.models import RetrievalDetails
-from danswer.search.request_preprocessing import retrieval_preprocessing
-from danswer.search.search_runner import chunks_to_search_docs
-from danswer.search.search_runner import full_chunk_search_generator
-from danswer.search.search_runner import inference_documents_from_ids
+from danswer.search.models import SearchRequest
+from danswer.search.pipeline import SearchPipeline
+from danswer.search.retrieval.search_runner import inference_documents_from_ids
+from danswer.search.utils import chunks_to_search_docs
 from danswer.secondary_llm_flows.choose_search import check_if_need_search
 from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
 from danswer.server.query_and_chat.models import ChatMessageDetail
@@ -377,37 +376,25 @@ def stream_chat_message_objects(
         else query_override
     )
 
-    (
-        retrieval_request,
-        predicted_search_type,
-        predicted_flow,
-    ) = retrieval_preprocessing(
-        query=rephrased_query,
-        retrieval_details=cast(RetrievalDetails, retrieval_options),
-        persona=persona,
+    search_pipeline = SearchPipeline(
+        search_request=SearchRequest(
+            query=rephrased_query,
+            human_selected_filters=retrieval_options.filters
+            if retrieval_options
+            else None,
+            persona=persona,
+            offset=retrieval_options.offset if retrieval_options else None,
+            limit=retrieval_options.limit if retrieval_options else None,
+        ),
         user=user,
         db_session=db_session,
     )
 
-    documents_generator = full_chunk_search_generator(
-        search_query=retrieval_request,
-        document_index=document_index,
-        db_session=db_session,
-    )
-    time_cutoff = retrieval_request.filters.time_cutoff
-    recency_bias_multiplier = retrieval_request.recency_bias_multiplier
-    run_llm_chunk_filter = not retrieval_request.skip_llm_chunk_filter
-
-    # First fetch and return the top chunks to the UI so the user can
-    # immediately see some results
-    top_chunks = cast(list[InferenceChunk], next(documents_generator))
+    top_chunks = search_pipeline.reranked_docs
+    top_docs = chunks_to_search_docs(top_chunks)
 
     # Get ranking of the documents for citation purposes later
-    doc_id_to_rank_map = map_document_id_order(
-        cast(list[InferenceChunk | LlmDoc], top_chunks)
-    )
-
-    top_docs = chunks_to_search_docs(top_chunks)
+    doc_id_to_rank_map = map_document_id_order(top_chunks)
 
     reference_db_search_docs = [
         create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
@@ -422,24 +409,17 @@ def stream_chat_message_objects(
     initial_response = QADocsResponse(
         rephrased_query=rephrased_query,
         top_documents=response_docs,
-        predicted_flow=predicted_flow,
-        predicted_search=predicted_search_type,
-        applied_source_filters=retrieval_request.filters.source_type,
-        applied_time_cutoff=time_cutoff,
-        recency_bias_multiplier=recency_bias_multiplier,
+        predicted_flow=search_pipeline.predicted_flow,
+        predicted_search=search_pipeline.predicted_search_type,
+        applied_source_filters=search_pipeline.search_query.filters.source_type,
+        applied_time_cutoff=search_pipeline.search_query.filters.time_cutoff,
+        recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
     )
     yield initial_response
 
-    # Get the final ordering of chunks for the LLM call
-    llm_chunk_selection = cast(list[bool], next(documents_generator))
-
     # Yield the list of LLM selected chunks for showing the LLM selected icons in the UI
     llm_relevance_filtering_response = LLMRelevanceFilterResponse(
-        relevant_chunk_indices=[
-            index for index, value in enumerate(llm_chunk_selection) if value
-        ]
-        if run_llm_chunk_filter
-        else []
+        relevant_chunk_indices=search_pipeline.relevant_chunk_indicies
     )
     yield llm_relevance_filtering_response
 
@@ -467,7 +447,7 @@ def stream_chat_message_objects(
     )
     llm_chunks_indices = get_chunks_for_qa(
         chunks=top_chunks,
-        llm_chunk_selection=llm_chunk_selection,
+        llm_chunk_selection=search_pipeline.chunk_relevance_list,
         token_limit=chunk_token_limit,
         llm_tokenizer=llm_tokenizer,
     )
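For orientation, the hunks above replace the old two-step flow (retrieval_preprocessing followed by full_chunk_search_generator) with a single SearchPipeline object in the chat path. A minimal sketch of the new call pattern follows; the run_search helper and its print statement are illustrative only and not part of this commit.

from sqlalchemy.orm import Session

from danswer.db.models import User
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline


def run_search(query: str, user: User | None, db_session: Session) -> None:
    # Hypothetical helper mirroring the refactored chat flow above.
    pipeline = SearchPipeline(
        search_request=SearchRequest(query=query),
        user=user,
        db_session=db_session,
    )
    # Properties are lazy: the first access runs preprocessing, retrieval,
    # reranking and the LLM relevance filter, then caches the results.
    top_chunks = pipeline.reranked_docs
    relevant_indices = pipeline.relevant_chunk_indicies
    print(f"{len(top_chunks)} chunks retrieved, {len(relevant_indices)} marked relevant")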
@@ -27,7 +27,7 @@ from danswer.db.models import SearchDoc
 from danswer.db.models import SearchDoc as DBSearchDoc
 from danswer.db.models import StarterMessage
 from danswer.db.models import User__UserGroup
-from danswer.search.models import RecencyBiasSetting
+from danswer.search.enums import RecencyBiasSetting
 from danswer.search.models import RetrievalDocs
 from danswer.search.models import SavedSearchDoc
 from danswer.search.models import SearchDoc as ServerSearchDoc
@@ -36,8 +36,8 @@ from danswer.configs.constants import MessageType
 from danswer.configs.constants import SearchFeedbackType
 from danswer.connectors.models import InputType
 from danswer.dynamic_configs.interface import JSON_ro
-from danswer.search.models import RecencyBiasSetting
-from danswer.search.models import SearchType
+from danswer.search.enums import RecencyBiasSetting
+from danswer.search.enums import SearchType
 
 
 class IndexingStatus(str, PyEnum):
@@ -12,7 +12,7 @@ from danswer.db.models import Persona
 from danswer.db.models import Persona__DocumentSet
 from danswer.db.models import SlackBotConfig
 from danswer.db.models import SlackBotResponseType
-from danswer.search.models import RecencyBiasSetting
+from danswer.search.enums import RecencyBiasSetting
 
 
 def _build_persona_name(channel_names: list[str]) -> str:
@@ -64,8 +64,8 @@ from danswer.document_index.vespa.utils import remove_invalid_unicode_chars
 from danswer.indexing.models import DocMetadataAwareIndexChunk
 from danswer.indexing.models import InferenceChunk
 from danswer.search.models import IndexFilters
-from danswer.search.search_runner import query_processing
-from danswer.search.search_runner import remove_stop_words_and_punctuation
+from danswer.search.retrieval.search_runner import query_processing
+from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
 from danswer.utils.batching import batch_generator
 from danswer.utils.logger import setup_logger
 from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -1,7 +1,6 @@
 import itertools
 from collections.abc import Callable
 from collections.abc import Iterator
-from typing import cast
 
 from langchain.schema.messages import BaseMessage
 from langchain.schema.messages import HumanMessage
@@ -33,11 +32,9 @@ from danswer.db.chat import get_or_create_root_message
 from danswer.db.chat import get_persona_by_id
 from danswer.db.chat import get_prompt_by_id
 from danswer.db.chat import translate_db_message_to_chat_message_detail
-from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.engine import get_session_context_manager
 from danswer.db.models import Prompt
 from danswer.db.models import User
-from danswer.document_index.factory import get_default_document_index
 from danswer.indexing.models import InferenceChunk
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import get_default_llm_token_encode
@@ -55,9 +52,9 @@ from danswer.prompts.prompt_utils import build_task_prompt_reminders
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
 from danswer.search.models import SavedSearchDoc
-from danswer.search.request_preprocessing import retrieval_preprocessing
-from danswer.search.search_runner import chunks_to_search_docs
-from danswer.search.search_runner import full_chunk_search_generator
+from danswer.search.models import SearchRequest
+from danswer.search.pipeline import SearchPipeline
+from danswer.search.utils import chunks_to_search_docs
 from danswer.secondary_llm_flows.answer_validation import get_answer_validity
 from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
 from danswer.server.query_and_chat.models import ChatMessageDetail
@@ -221,12 +218,6 @@ def stream_answer_objects(
 
     llm_tokenizer = get_default_llm_token_encode()
 
-    embedding_model = get_current_db_embedding_model(db_session)
-
-    document_index = get_default_document_index(
-        primary_index_name=embedding_model.index_name, secondary_index_name=None
-    )
-
     # Create a chat session which will just store the root message, the query, and the AI response
     root_message = get_or_create_root_message(
         chat_session_id=chat_session.id, db_session=db_session
@@ -244,33 +235,23 @@ def stream_answer_objects(
     # In chat flow it's given back along with the documents
     yield QueryRephrase(rephrased_query=rephrased_query)
 
-    (
-        retrieval_request,
-        predicted_search_type,
-        predicted_flow,
-    ) = retrieval_preprocessing(
-        query=rephrased_query,
-        retrieval_details=query_req.retrieval_options,
-        persona=chat_session.persona,
+    search_pipeline = SearchPipeline(
+        search_request=SearchRequest(
+            query=rephrased_query,
+            human_selected_filters=query_req.retrieval_options.filters,
+            persona=chat_session.persona,
+            offset=query_req.retrieval_options.offset,
+            limit=query_req.retrieval_options.limit,
+        ),
         user=user,
         db_session=db_session,
         bypass_acl=bypass_acl,
-    )
-
-    documents_generator = full_chunk_search_generator(
-        search_query=retrieval_request,
-        document_index=document_index,
-        db_session=db_session,
         retrieval_metrics_callback=retrieval_metrics_callback,
         rerank_metrics_callback=rerank_metrics_callback,
     )
-    applied_time_cutoff = retrieval_request.filters.time_cutoff
-    recency_bias_multiplier = retrieval_request.recency_bias_multiplier
-    run_llm_chunk_filter = not retrieval_request.skip_llm_chunk_filter
 
     # First fetch and return the top chunks so the user can immediately see some results
-    top_chunks = cast(list[InferenceChunk], next(documents_generator))
+    top_chunks = search_pipeline.reranked_docs
 
     top_docs = chunks_to_search_docs(top_chunks)
     fake_saved_docs = [SavedSearchDoc.from_search_doc(doc) for doc in top_docs]
 
@@ -278,24 +259,17 @@ def stream_answer_objects(
     initial_response = QADocsResponse(
         rephrased_query=rephrased_query,
         top_documents=fake_saved_docs,
-        predicted_flow=predicted_flow,
-        predicted_search=predicted_search_type,
-        applied_source_filters=retrieval_request.filters.source_type,
-        applied_time_cutoff=applied_time_cutoff,
-        recency_bias_multiplier=recency_bias_multiplier,
+        predicted_flow=search_pipeline.predicted_flow,
+        predicted_search=search_pipeline.predicted_search_type,
+        applied_source_filters=search_pipeline.search_query.filters.source_type,
+        applied_time_cutoff=search_pipeline.search_query.filters.time_cutoff,
+        recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
     )
     yield initial_response
 
-    # Get the final ordering of chunks for the LLM call
-    llm_chunk_selection = cast(list[bool], next(documents_generator))
-
    # Yield the list of LLM selected chunks for showing the LLM selected icons in the UI
     llm_relevance_filtering_response = LLMRelevanceFilterResponse(
-        relevant_chunk_indices=[
-            index for index, value in enumerate(llm_chunk_selection) if value
-        ]
-        if run_llm_chunk_filter
-        else []
+        relevant_chunk_indices=search_pipeline.relevant_chunk_indicies
    )
     yield llm_relevance_filtering_response
 
@@ -317,7 +291,7 @@ def stream_answer_objects(
 
     llm_chunks_indices = get_chunks_for_qa(
         chunks=top_chunks,
-        llm_chunk_selection=llm_chunk_selection,
+        llm_chunk_selection=search_pipeline.chunk_relevance_list,
         token_limit=chunk_token_limit,
     )
     llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
backend/danswer/search/enums.py (new file, 30 lines)
@@ -0,0 +1,30 @@
+"""NOTE: this needs to be separate from models.py because of circular imports.
+Both search/models.py and db/models.py import enums from this file AND
+search/models.py imports from db/models.py."""
+from enum import Enum
+
+
+class OptionalSearchSetting(str, Enum):
+    ALWAYS = "always"
+    NEVER = "never"
+    # Determine whether to run search based on history and latest query
+    AUTO = "auto"
+
+
+class RecencyBiasSetting(str, Enum):
+    FAVOR_RECENT = "favor_recent"  # 2x decay rate
+    BASE_DECAY = "base_decay"
+    NO_DECAY = "no_decay"
+    # Determine based on query if to use base_decay or favor_recent
+    AUTO = "auto"
+
+
+class SearchType(str, Enum):
+    KEYWORD = "keyword"
+    SEMANTIC = "semantic"
+    HYBRID = "hybrid"
+
+
+class QueryFlow(str, Enum):
+    SEARCH = "search"
+    QUESTION_ANSWER = "question-answer"
@@ -1,46 +1,24 @@
 from datetime import datetime
-from enum import Enum
 from typing import Any
 
 from pydantic import BaseModel
 
 from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
+from danswer.configs.chat_configs import HYBRID_ALPHA
 from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
 from danswer.configs.chat_configs import NUM_RETURNED_HITS
 from danswer.configs.constants import DocumentSource
 from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
+from danswer.db.models import Persona
+from danswer.search.enums import OptionalSearchSetting
+from danswer.search.enums import SearchType
 
 
 MAX_METRICS_CONTENT = (
     200  # Just need enough characters to identify where in the doc the chunk is
 )
 
 
-class OptionalSearchSetting(str, Enum):
-    ALWAYS = "always"
-    NEVER = "never"
-    # Determine whether to run search based on history and latest query
-    AUTO = "auto"
-
-
-class RecencyBiasSetting(str, Enum):
-    FAVOR_RECENT = "favor_recent"  # 2x decay rate
-    BASE_DECAY = "base_decay"
-    NO_DECAY = "no_decay"
-    # Determine based on query if to use base_decay or favor_recent
-    AUTO = "auto"
-
-
-class SearchType(str, Enum):
-    KEYWORD = "keyword"
-    SEMANTIC = "semantic"
-    HYBRID = "hybrid"
-
-
-class QueryFlow(str, Enum):
-    SEARCH = "search"
-    QUESTION_ANSWER = "question-answer"
-
-
 class Tag(BaseModel):
     tag_key: str
     tag_value: str
@@ -64,6 +42,28 @@ class ChunkMetric(BaseModel):
     score: float
 
 
+class SearchRequest(BaseModel):
+    """Input to the SearchPipeline."""
+
+    query: str
+    search_type: SearchType = SearchType.HYBRID
+
+    human_selected_filters: BaseFilters | None = None
+    enable_auto_detect_filters: bool | None = None
+    persona: Persona | None = None
+
+    # if None, no offset / limit
+    offset: int | None = None
+    limit: int | None = None
+
+    recency_bias_multiplier: float = 1.0
+    hybrid_alpha: float = HYBRID_ALPHA
+    skip_rerank: bool = True
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
 class SearchQuery(BaseModel):
     query: str
     filters: IndexFilters
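A small usage sketch of the new SearchRequest model, assuming only the defaults defined in the hunk above; the example query string is arbitrary.

from danswer.search.models import SearchRequest

# Only the query is required; the defaults give hybrid search,
# no offset/limit, and reranking skipped.
request = SearchRequest(query="quarterly revenue report")
assert request.search_type.value == "hybrid"
assert request.offset is None and request.limit is None
assert request.skip_rerank is True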
backend/danswer/search/pipeline.py (new file, 152 lines)
@@ -0,0 +1,152 @@
+from collections.abc import Callable
+from typing import cast
+
+from sqlalchemy.orm import Session
+
+from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
+from danswer.db.embedding_model import get_current_db_embedding_model
+from danswer.db.models import User
+from danswer.document_index.factory import get_default_document_index
+from danswer.indexing.models import InferenceChunk
+from danswer.search.enums import QueryFlow
+from danswer.search.enums import SearchType
+from danswer.search.models import RerankMetricsContainer
+from danswer.search.models import RetrievalMetricsContainer
+from danswer.search.models import SearchQuery
+from danswer.search.models import SearchRequest
+from danswer.search.postprocessing.postprocessing import search_postprocessing
+from danswer.search.preprocessing.preprocessing import retrieval_preprocessing
+from danswer.search.retrieval.search_runner import retrieve_chunks
+
+
+class SearchPipeline:
+    def __init__(
+        self,
+        search_request: SearchRequest,
+        user: User | None,
+        db_session: Session,
+        bypass_acl: bool = False,  # NOTE: VERY DANGEROUS, USE WITH CAUTION
+        retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
+        | None = None,
+        rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
+    ):
+        self.search_request = search_request
+        self.user = user
+        self.db_session = db_session
+        self.bypass_acl = bypass_acl
+        self.retrieval_metrics_callback = retrieval_metrics_callback
+        self.rerank_metrics_callback = rerank_metrics_callback
+
+        self.embedding_model = get_current_db_embedding_model(db_session)
+        self.document_index = get_default_document_index(
+            primary_index_name=self.embedding_model.index_name,
+            secondary_index_name=None,
+        )
+
+        self._search_query: SearchQuery | None = None
+        self._predicted_search_type: SearchType | None = None
+        self._predicted_flow: QueryFlow | None = None
+
+        self._retrieved_docs: list[InferenceChunk] | None = None
+        self._reranked_docs: list[InferenceChunk] | None = None
+        self._relevant_chunk_indicies: list[int] | None = None
+
+    """Pre-processing"""
+
+    def _run_preprocessing(self) -> None:
+        (
+            final_search_query,
+            predicted_search_type,
+            predicted_flow,
+        ) = retrieval_preprocessing(
+            search_request=self.search_request,
+            user=self.user,
+            db_session=self.db_session,
+            bypass_acl=self.bypass_acl,
+        )
+        self._predicted_search_type = predicted_search_type
+        self._predicted_flow = predicted_flow
+        self._search_query = final_search_query
+
+    @property
+    def search_query(self) -> SearchQuery:
+        if self._search_query is not None:
+            return self._search_query
+
+        self._run_preprocessing()
+        return cast(SearchQuery, self._search_query)
+
+    @property
+    def predicted_search_type(self) -> SearchType:
+        if self._predicted_search_type is not None:
+            return self._predicted_search_type
+
+        self._run_preprocessing()
+        return cast(SearchType, self._predicted_search_type)
+
+    @property
+    def predicted_flow(self) -> QueryFlow:
+        if self._predicted_flow is not None:
+            return self._predicted_flow
+
+        self._run_preprocessing()
+        return cast(QueryFlow, self._predicted_flow)
+
+    """Retrieval"""
+
+    @property
+    def retrieved_docs(self) -> list[InferenceChunk]:
+        if self._retrieved_docs is not None:
+            return self._retrieved_docs
+
+        self._retrieved_docs = retrieve_chunks(
+            query=self.search_query,
+            document_index=self.document_index,
+            db_session=self.db_session,
+            hybrid_alpha=self.search_request.hybrid_alpha,
+            multilingual_expansion_str=MULTILINGUAL_QUERY_EXPANSION,
+            retrieval_metrics_callback=self.retrieval_metrics_callback,
+        )
+
+        # self._retrieved_docs = chunks_to_search_docs(retrieved_chunks)
+        return cast(list[InferenceChunk], self._retrieved_docs)
+
+    """Post-Processing"""
+
+    def _run_postprocessing(self) -> None:
+        postprocessing_generator = search_postprocessing(
+            search_query=self.search_query,
+            retrieved_chunks=self.retrieved_docs,
+            rerank_metrics_callback=self.rerank_metrics_callback,
+        )
+        self._reranked_docs = cast(list[InferenceChunk], next(postprocessing_generator))
+
+        relevant_chunk_ids = cast(list[str], next(postprocessing_generator))
+        self._relevant_chunk_indicies = [
+            ind
+            for ind, chunk in enumerate(self._reranked_docs)
+            if chunk.unique_id in relevant_chunk_ids
+        ]
+
+    @property
+    def reranked_docs(self) -> list[InferenceChunk]:
+        if self._reranked_docs is not None:
+            return self._reranked_docs
+
+        self._run_postprocessing()
+        return cast(list[InferenceChunk], self._reranked_docs)
+
+    @property
+    def relevant_chunk_indicies(self) -> list[int]:
+        if self._relevant_chunk_indicies is not None:
+            return self._relevant_chunk_indicies
+
+        self._run_postprocessing()
+        return cast(list[int], self._relevant_chunk_indicies)
+
+    @property
+    def chunk_relevance_list(self) -> list[bool]:
+        return [
+            True if ind in self.relevant_chunk_indicies else False
+            for ind in range(len(self.reranked_docs))
+        ]
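A sketch of how the cached properties of SearchPipeline interact; describe_pipeline is a hypothetical helper, not part of this file.

from danswer.search.pipeline import SearchPipeline


def describe_pipeline(pipeline: SearchPipeline) -> None:
    # The first property access runs retrieval_preprocessing once; later accesses reuse it.
    flow = pipeline.predicted_flow
    search_query = pipeline.search_query
    # The first access here runs retrieval plus search_postprocessing once.
    docs = pipeline.reranked_docs
    relevance_mask = pipeline.chunk_relevance_list
    print(flow, search_query.search_type, len(docs), sum(relevance_mask))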
backend/danswer/search/postprocessing/postprocessing.py (new file, 222 lines)
@@ -0,0 +1,222 @@
+from collections.abc import Callable
+from collections.abc import Generator
+from typing import cast
+
+import numpy
+
+from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
+from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
+from danswer.document_index.document_index_utils import (
+    translate_boost_count_to_multiplier,
+)
+from danswer.indexing.models import InferenceChunk
+from danswer.search.models import ChunkMetric
+from danswer.search.models import MAX_METRICS_CONTENT
+from danswer.search.models import RerankMetricsContainer
+from danswer.search.models import SearchQuery
+from danswer.search.models import SearchType
+from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
+from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
+from danswer.utils.logger import setup_logger
+from danswer.utils.threadpool_concurrency import FunctionCall
+from danswer.utils.threadpool_concurrency import run_functions_in_parallel
+from danswer.utils.timing import log_function_time
+
+
+logger = setup_logger()
+
+
+def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
+    top_links = [
+        c.source_links[0] if c.source_links is not None else "No Link" for c in chunks
+    ]
+    logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
+
+
+def should_rerank(query: SearchQuery) -> bool:
+    # Don't re-rank for keyword search
+    return query.search_type != SearchType.KEYWORD and not query.skip_rerank
+
+
+def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool:
+    return not query.skip_llm_chunk_filter
+
+
+@log_function_time(print_only=True)
+def semantic_reranking(
+    query: str,
+    chunks: list[InferenceChunk],
+    model_min: int = CROSS_ENCODER_RANGE_MIN,
+    model_max: int = CROSS_ENCODER_RANGE_MAX,
+    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
+) -> tuple[list[InferenceChunk], list[int]]:
+    """Reranks chunks based on cross-encoder models. Additionally provides the original indices
+    of the chunks in their new sorted order.
+
+    Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
+    """
+    cross_encoders = CrossEncoderEnsembleModel()
+    passages = [chunk.content for chunk in chunks]
+    sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
+
+    sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
+
+    raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores))
+
+    cross_models_min = numpy.min(sim_scores)
+
+    shifted_sim_scores = sum(
+        [enc_n_scores - cross_models_min for enc_n_scores in sim_scores]
+    ) / len(sim_scores)
+
+    boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
+    recency_multiplier = [chunk.recency_bias for chunk in chunks]
+    boosted_sim_scores = shifted_sim_scores * boosts * recency_multiplier
+    normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / (
+        model_max - model_min
+    )
+    orig_indices = [i for i in range(len(normalized_b_s_scores))]
+    scored_results = list(
+        zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices)
+    )
+    scored_results.sort(key=lambda x: x[0], reverse=True)
+    ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip(
+        *scored_results
+    )
+
+    logger.debug(
+        f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}"
+    )
+
+    # Assign new chunk scores based on reranking
+    for ind, chunk in enumerate(ranked_chunks):
+        chunk.score = ranked_sim_scores[ind]
+
+    if rerank_metrics_callback is not None:
+        chunk_metrics = [
+            ChunkMetric(
+                document_id=chunk.document_id,
+                chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
+                first_link=chunk.source_links[0] if chunk.source_links else None,
+                score=chunk.score if chunk.score is not None else 0,
+            )
+            for chunk in ranked_chunks
+        ]
+
+        rerank_metrics_callback(
+            RerankMetricsContainer(
+                metrics=chunk_metrics, raw_similarity_scores=ranked_raw_scores  # type: ignore
+            )
+        )
+
+    return list(ranked_chunks), list(ranked_indices)
+
+
+def rerank_chunks(
+    query: SearchQuery,
+    chunks_to_rerank: list[InferenceChunk],
+    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
+) -> list[InferenceChunk]:
+    ranked_chunks, _ = semantic_reranking(
+        query=query.query,
+        chunks=chunks_to_rerank[: query.num_rerank],
+        rerank_metrics_callback=rerank_metrics_callback,
+    )
+    lower_chunks = chunks_to_rerank[query.num_rerank :]
+    # Scores from rerank cannot be meaningfully combined with scores without rerank
+    for lower_chunk in lower_chunks:
+        lower_chunk.score = None
+    ranked_chunks.extend(lower_chunks)
+    return ranked_chunks
+
+
+@log_function_time(print_only=True)
+def filter_chunks(
+    query: SearchQuery,
+    chunks_to_filter: list[InferenceChunk],
+) -> list[str]:
+    """Filters chunks based on whether the LLM thought they were relevant to the query.
+
+    Returns a list of the unique chunk IDs that were marked as relevant"""
+    chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks]
+    llm_chunk_selection = llm_batch_eval_chunks(
+        query=query.query,
+        chunk_contents=[chunk.content for chunk in chunks_to_filter],
+    )
+    return [
+        chunk.unique_id
+        for ind, chunk in enumerate(chunks_to_filter)
+        if llm_chunk_selection[ind]
+    ]
+
+
+def search_postprocessing(
+    search_query: SearchQuery,
+    retrieved_chunks: list[InferenceChunk],
+    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
+) -> Generator[list[InferenceChunk] | list[str], None, None]:
+    post_processing_tasks: list[FunctionCall] = []
+
+    rerank_task_id = None
+    if should_rerank(search_query):
+        post_processing_tasks.append(
+            FunctionCall(
+                rerank_chunks,
+                (
+                    search_query,
+                    retrieved_chunks,
+                    rerank_metrics_callback,
+                ),
+            )
+        )
+        rerank_task_id = post_processing_tasks[-1].result_id
+    else:
+        final_chunks = retrieved_chunks
+        # NOTE: if we don't rerank, we can return the chunks immediately
+        # since we know this is the final order
+        _log_top_chunk_links(search_query.search_type.value, final_chunks)
+        yield final_chunks
+        chunks_yielded = True
+
+    llm_filter_task_id = None
+    if should_apply_llm_based_relevance_filter(search_query):
+        post_processing_tasks.append(
+            FunctionCall(
+                filter_chunks,
+                (search_query, retrieved_chunks[: search_query.max_llm_filter_chunks]),
+            )
+        )
+        llm_filter_task_id = post_processing_tasks[-1].result_id
+
+    post_processing_results = (
+        run_functions_in_parallel(post_processing_tasks)
+        if post_processing_tasks
+        else {}
+    )
+    reranked_chunks = cast(
+        list[InferenceChunk] | None,
+        post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None,
+    )
+    if reranked_chunks:
+        if chunks_yielded:
+            logger.error(
+                "Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen."
+            )
+        else:
+            _log_top_chunk_links(search_query.search_type.value, reranked_chunks)
+            yield reranked_chunks
+
+    llm_chunk_selection = cast(
+        list[str] | None,
+        post_processing_results.get(str(llm_filter_task_id))
+        if llm_filter_task_id
+        else None,
+    )
+    if llm_chunk_selection is not None:
+        yield [
+            chunk.unique_id
+            for chunk in reranked_chunks or retrieved_chunks
+            if chunk.unique_id in llm_chunk_selection
+        ]
+    else:
+        yield []
@@ -1,10 +1,10 @@
 from typing import TYPE_CHECKING
 
-from danswer.search.models import QueryFlow
+from danswer.search.enums import QueryFlow
 from danswer.search.models import SearchType
+from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
 from danswer.search.search_nlp_models import get_default_tokenizer
 from danswer.search.search_nlp_models import IntentModel
-from danswer.search.search_runner import remove_stop_words_and_punctuation
 from danswer.server.query_and_chat.models import HelperResponse
 from danswer.utils.logger import setup_logger
 
@@ -5,19 +5,16 @@ from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
 from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION
 from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
 from danswer.configs.chat_configs import NUM_RETURNED_HITS
-from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
-from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
-from danswer.db.models import Persona
 from danswer.db.models import User
-from danswer.search.access_filters import build_access_filters_for_user
-from danswer.search.danswer_helper import query_intent
+from danswer.search.enums import QueryFlow
+from danswer.search.enums import RecencyBiasSetting
 from danswer.search.models import BaseFilters
 from danswer.search.models import IndexFilters
-from danswer.search.models import QueryFlow
-from danswer.search.models import RecencyBiasSetting
-from danswer.search.models import RetrievalDetails
 from danswer.search.models import SearchQuery
+from danswer.search.models import SearchRequest
 from danswer.search.models import SearchType
+from danswer.search.preprocessing.access_filters import build_access_filters_for_user
+from danswer.search.preprocessing.danswer_helper import query_intent
 from danswer.secondary_llm_flows.source_filter import extract_source_filter
 from danswer.secondary_llm_flows.time_filter import extract_time_filter
 from danswer.utils.logger import setup_logger
@@ -31,15 +28,12 @@ logger = setup_logger()
 
 @log_function_time(print_only=True)
 def retrieval_preprocessing(
-    query: str,
-    retrieval_details: RetrievalDetails,
-    persona: Persona,
+    search_request: SearchRequest,
     user: User | None,
     db_session: Session,
     bypass_acl: bool = False,
     include_query_intent: bool = True,
-    skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW,
-    skip_rerank_non_realtime: bool = not ENABLE_RERANKING_ASYNC_FLOW,
+    enable_auto_detect_filters: bool = False,
     disable_llm_filter_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION,
     disable_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
     base_recency_decay: float = BASE_RECENCY_DECAY,
@@ -50,8 +44,12 @@ def retrieval_preprocessing(
     Then any filters or settings as part of the query are used
     Then defaults to Persona settings if not specified by the query
     """
+    query = search_request.query
+    limit = search_request.limit
+    offset = search_request.offset
+    persona = search_request.persona
 
-    preset_filters = retrieval_details.filters or BaseFilters()
+    preset_filters = search_request.human_selected_filters or BaseFilters()
     if persona and persona.document_sets and preset_filters.document_set is None:
         preset_filters.document_set = [
             document_set.name for document_set in persona.document_sets
@@ -65,16 +63,20 @@ def retrieval_preprocessing(
     if disable_llm_filter_extraction:
         auto_detect_time_filter = False
         auto_detect_source_filter = False
-    elif retrieval_details.enable_auto_detect_filters is False:
+    elif enable_auto_detect_filters is False:
         logger.debug("Retrieval details disables auto detect filters")
         auto_detect_time_filter = False
         auto_detect_source_filter = False
-    elif persona.llm_filter_extraction is False:
+    elif persona and persona.llm_filter_extraction is False:
         logger.debug("Persona disables auto detect filters")
         auto_detect_time_filter = False
         auto_detect_source_filter = False
 
-    if time_filter is not None and persona.recency_bias != RecencyBiasSetting.AUTO:
+    if (
+        time_filter is not None
+        and persona
+        and persona.recency_bias != RecencyBiasSetting.AUTO
+    ):
         auto_detect_time_filter = False
         logger.debug("Not extract time filter - already provided")
     if source_filter is not None:
@@ -138,24 +140,18 @@ def retrieval_preprocessing(
         access_control_list=user_acl_filters,
     )
 
-    # Tranformer-based re-ranking to run at same time as LLM chunk relevance filter
-    # This one is only set globally, not via query or Persona settings
-    skip_reranking = (
-        skip_rerank_realtime
-        if retrieval_details.real_time
-        else skip_rerank_non_realtime
-    )
-
-    llm_chunk_filter = persona.llm_relevance_filter
+    llm_chunk_filter = False
+    if persona:
+        llm_chunk_filter = persona.llm_relevance_filter
     if disable_llm_chunk_filter:
         llm_chunk_filter = False
 
     # Decays at 1 / (1 + (multiplier * num years))
-    if persona.recency_bias == RecencyBiasSetting.NO_DECAY:
+    if persona and persona.recency_bias == RecencyBiasSetting.NO_DECAY:
         recency_bias_multiplier = 0.0
-    elif persona.recency_bias == RecencyBiasSetting.BASE_DECAY:
+    elif persona and persona.recency_bias == RecencyBiasSetting.BASE_DECAY:
         recency_bias_multiplier = base_recency_decay
-    elif persona.recency_bias == RecencyBiasSetting.FAVOR_RECENT:
+    elif persona and persona.recency_bias == RecencyBiasSetting.FAVOR_RECENT:
         recency_bias_multiplier = base_recency_decay * favor_recent_decay_multiplier
     else:
         if predicted_favor_recent:
@@ -166,14 +162,12 @@ def retrieval_preprocessing(
     return (
         SearchQuery(
            query=query,
-            search_type=persona.search_type,
+            search_type=persona.search_type if persona else SearchType.HYBRID,
            filters=final_filters,
            recency_bias_multiplier=recency_bias_multiplier,
-            num_hits=retrieval_details.limit
-            if retrieval_details.limit is not None
-            else NUM_RETURNED_HITS,
-            offset=retrieval_details.offset or 0,
-            skip_rerank=skip_reranking,
+            num_hits=limit if limit is not None else NUM_RETURNED_HITS,
+            offset=offset or 0,
+            skip_rerank=search_request.skip_rerank,
            skip_llm_chunk_filter=not llm_chunk_filter,
         ),
         predicted_search_type,
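With the refactored signature, callers hand the whole SearchRequest to retrieval_preprocessing instead of separate query/persona/filter arguments, and persona, filters, offset and limit are read off the request. A hedged sketch of the call; the preprocess wrapper is hypothetical.

from sqlalchemy.orm import Session

from danswer.db.models import User
from danswer.search.models import SearchRequest
from danswer.search.preprocessing.preprocessing import retrieval_preprocessing


def preprocess(request: SearchRequest, user: User | None, db_session: Session):
    # Returns the resolved SearchQuery plus the predicted search type and query flow.
    search_query, predicted_search_type, predicted_flow = retrieval_preprocessing(
        search_request=request,
        user=user,
        db_session=db_session,
    )
    return search_query, predicted_search_type, predicted_flow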
backend/danswer/search/retrieval/search_runner.py (new file, 256 lines)
@@ -0,0 +1,256 @@
+import string
+from collections.abc import Callable
+
+from nltk.corpus import stopwords  # type:ignore
+from nltk.stem import WordNetLemmatizer  # type:ignore
+from nltk.tokenize import word_tokenize  # type:ignore
+from sqlalchemy.orm import Session
+
+from danswer.chat.models import LlmDoc
+from danswer.configs.app_configs import MODEL_SERVER_HOST
+from danswer.configs.app_configs import MODEL_SERVER_PORT
+from danswer.configs.chat_configs import HYBRID_ALPHA
+from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
+from danswer.db.embedding_model import get_current_db_embedding_model
+from danswer.document_index.interfaces import DocumentIndex
+from danswer.indexing.models import InferenceChunk
+from danswer.search.models import ChunkMetric
+from danswer.search.models import IndexFilters
+from danswer.search.models import MAX_METRICS_CONTENT
+from danswer.search.models import RetrievalMetricsContainer
+from danswer.search.models import SearchQuery
+from danswer.search.models import SearchType
+from danswer.search.search_nlp_models import EmbeddingModel
+from danswer.search.search_nlp_models import EmbedTextType
+from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
+from danswer.utils.logger import setup_logger
+from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
+from danswer.utils.timing import log_function_time
+
+
+logger = setup_logger()
+
+
+def lemmatize_text(text: str) -> list[str]:
+    lemmatizer = WordNetLemmatizer()
+    word_tokens = word_tokenize(text)
+    return [lemmatizer.lemmatize(word) for word in word_tokens]
+
+
+def remove_stop_words_and_punctuation(text: str) -> list[str]:
+    stop_words = set(stopwords.words("english"))
+    word_tokens = word_tokenize(text)
+    text_trimmed = [
+        word
+        for word in word_tokens
+        if (word.casefold() not in stop_words and word not in string.punctuation)
+    ]
+    return text_trimmed or word_tokens
+
+
+def query_processing(
+    query: str,
+) -> str:
+    query = " ".join(remove_stop_words_and_punctuation(query))
+    query = " ".join(lemmatize_text(query))
+    return query
+
+
+def combine_retrieval_results(
+    chunk_sets: list[list[InferenceChunk]],
+) -> list[InferenceChunk]:
+    all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set]
+
+    unique_chunks: dict[tuple[str, int], InferenceChunk] = {}
+    for chunk in all_chunks:
+        key = (chunk.document_id, chunk.chunk_id)
+        if key not in unique_chunks:
+            unique_chunks[key] = chunk
+            continue
+
+        stored_chunk_score = unique_chunks[key].score or 0
+        this_chunk_score = chunk.score or 0
+        if stored_chunk_score < this_chunk_score:
+            unique_chunks[key] = chunk
+
+    sorted_chunks = sorted(
+        unique_chunks.values(), key=lambda x: x.score or 0, reverse=True
+    )
+
+    return sorted_chunks
+
+
+@log_function_time(print_only=True)
+def doc_index_retrieval(
+    query: SearchQuery,
+    document_index: DocumentIndex,
+    db_session: Session,
+    hybrid_alpha: float = HYBRID_ALPHA,
+) -> list[InferenceChunk]:
+    if query.search_type == SearchType.KEYWORD:
+        top_chunks = document_index.keyword_retrieval(
+            query=query.query,
+            filters=query.filters,
+            time_decay_multiplier=query.recency_bias_multiplier,
+            num_to_retrieve=query.num_hits,
+        )
+    else:
+        db_embedding_model = get_current_db_embedding_model(db_session)
+
+        model = EmbeddingModel(
+            model_name=db_embedding_model.model_name,
+            query_prefix=db_embedding_model.query_prefix,
+            passage_prefix=db_embedding_model.passage_prefix,
+            normalize=db_embedding_model.normalize,
+            # The below are globally set, this flow always uses the indexing one
+            server_host=MODEL_SERVER_HOST,
+            server_port=MODEL_SERVER_PORT,
+        )
+
+        query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]
+
+        if query.search_type == SearchType.SEMANTIC:
+            top_chunks = document_index.semantic_retrieval(
+                query=query.query,
+                query_embedding=query_embedding,
+                filters=query.filters,
+                time_decay_multiplier=query.recency_bias_multiplier,
+                num_to_retrieve=query.num_hits,
+            )
+
+        elif query.search_type == SearchType.HYBRID:
+            top_chunks = document_index.hybrid_retrieval(
+                query=query.query,
+                query_embedding=query_embedding,
+                filters=query.filters,
+                time_decay_multiplier=query.recency_bias_multiplier,
+                num_to_retrieve=query.num_hits,
+                offset=query.offset,
+                hybrid_alpha=hybrid_alpha,
+            )
+
+        else:
+            raise RuntimeError("Invalid Search Flow")
+
+    return top_chunks
+
+
+def _simplify_text(text: str) -> str:
+    return "".join(
+        char for char in text if char not in string.punctuation and not char.isspace()
+    ).lower()
+
+
+def retrieve_chunks(
+    query: SearchQuery,
+    document_index: DocumentIndex,
+    db_session: Session,
+    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
+    multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
+    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
+    | None = None,
+) -> list[InferenceChunk]:
+    """Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
+    # Don't do query expansion on complex queries, rephrasings likely would not work well
+    if not multilingual_expansion_str or "\n" in query.query or "\r" in query.query:
+        top_chunks = doc_index_retrieval(
+            query=query,
+            document_index=document_index,
+            db_session=db_session,
+            hybrid_alpha=hybrid_alpha,
+        )
+    else:
+        simplified_queries = set()
+        run_queries: list[tuple[Callable, tuple]] = []
+
+        # Currently only uses query expansion on multilingual use cases
+        query_rephrases = multilingual_query_expansion(
+            query.query, multilingual_expansion_str
+        )
+        # Just to be extra sure, add the original query.
+        query_rephrases.append(query.query)
+        for rephrase in set(query_rephrases):
+            # Sometimes the model rephrases the query in the same language with minor changes
+            # Avoid doing an extra search with the minor changes as this biases the results
+            simplified_rephrase = _simplify_text(rephrase)
+            if simplified_rephrase in simplified_queries:
+                continue
+            simplified_queries.add(simplified_rephrase)
+
+            q_copy = query.copy(update={"query": rephrase}, deep=True)
+            run_queries.append(
+                (
+                    doc_index_retrieval,
+                    (q_copy, document_index, db_session, hybrid_alpha),
+                )
+            )
+        parallel_search_results = run_functions_tuples_in_parallel(run_queries)
+        top_chunks = combine_retrieval_results(parallel_search_results)
+
+    if not top_chunks:
+        logger.info(
+            f"{query.search_type.value.capitalize()} search returned no results "
+            f"with filters: {query.filters}"
+        )
+        return []
+
+    if retrieval_metrics_callback is not None:
+        chunk_metrics = [
+            ChunkMetric(
+                document_id=chunk.document_id,
+                chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
+                first_link=chunk.source_links[0] if chunk.source_links else None,
+                score=chunk.score if chunk.score is not None else 0,
+            )
+            for chunk in top_chunks
+        ]
+        retrieval_metrics_callback(
+            RetrievalMetricsContainer(
+                search_type=query.search_type, metrics=chunk_metrics
+            )
+        )
+
+    return top_chunks
+
+
+def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc:
+    if not inf_chunks:
+        raise ValueError("Cannot combine empty list of chunks")
+
+    # Use the first link of the document
+    first_chunk = inf_chunks[0]
+    chunk_texts = [chunk.content for chunk in inf_chunks]
+    return LlmDoc(
+        document_id=first_chunk.document_id,
+        content="\n".join(chunk_texts),
+        semantic_identifier=first_chunk.semantic_identifier,
+        source_type=first_chunk.source_type,
+        metadata=first_chunk.metadata,
+        updated_at=first_chunk.updated_at,
|
||||||
|
link=first_chunk.source_links[0] if first_chunk.source_links else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def inference_documents_from_ids(
|
||||||
|
doc_identifiers: list[tuple[str, int]],
|
||||||
|
document_index: DocumentIndex,
|
||||||
|
) -> list[LlmDoc]:
|
||||||
|
# Currently only fetches whole docs
|
||||||
|
doc_ids_set = set(doc_id for doc_id, chunk_id in doc_identifiers)
|
||||||
|
|
||||||
|
# No need for ACL here because the doc ids were validated beforehand
|
||||||
|
filters = IndexFilters(access_control_list=None)
|
||||||
|
|
||||||
|
functions_with_args: list[tuple[Callable, tuple]] = [
|
||||||
|
(document_index.id_based_retrieval, (doc_id, None, filters))
|
||||||
|
for doc_id in doc_ids_set
|
||||||
|
]
|
||||||
|
|
||||||
|
parallel_results = run_functions_tuples_in_parallel(
|
||||||
|
functions_with_args, allow_failures=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Any failures to retrieve would give a None, drop the Nones and empty lists
|
||||||
|
inference_chunks_sets = [res for res in parallel_results if res]
|
||||||
|
|
||||||
|
return [combine_inference_chunks(chunk_set) for chunk_set in inference_chunks_sets]
|
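A minimal sketch of how the deduplication in combine_retrieval_results above behaves. It uses a hypothetical stand-in for InferenceChunk (the real model has many more fields); only document_id, chunk_id, and score matter for the merge, and the function is assumed importable from the relocated retrieval module in this commit.

from dataclasses import dataclass


@dataclass
class FakeChunk:  # hypothetical stand-in for InferenceChunk, for illustration only
    document_id: str
    chunk_id: int
    score: float | None


# Two parallel retrievals (e.g. two query rephrasings) return overlapping chunks;
# the higher-scoring copy of each (document_id, chunk_id) wins and the merged
# list is sorted by score descending, with missing scores treated as 0.
run_a = [FakeChunk("doc-1", 0, 0.8), FakeChunk("doc-2", 3, 0.5)]
run_b = [FakeChunk("doc-1", 0, 0.9), FakeChunk("doc-3", 1, None)]

merged = combine_retrieval_results([run_a, run_b])
# -> doc-1/0 (score 0.9), doc-2/3 (score 0.5), doc-3/1 (score None, sorts last)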
@@ -1,645 +0,0 @@
import string
from collections.abc import Callable
from collections.abc import Iterator
from typing import cast

import numpy
from nltk.corpus import stopwords  # type:ignore
from nltk.stem import WordNetLemmatizer  # type:ignore
from nltk.tokenize import word_tokenize  # type:ignore
from sqlalchemy.orm import Session

from danswer.chat.models import LlmDoc
from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.document_index.document_index_utils import (
    translate_boost_count_to_multiplier,
)
from danswer.document_index.interfaces import DocumentIndex
from danswer.indexing.models import InferenceChunk
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchDoc
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.search.search_nlp_models import EmbedTextType
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from danswer.utils.timing import log_function_time


logger = setup_logger()


def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
    top_links = [
        c.source_links[0] if c.source_links is not None else "No Link" for c in chunks
    ]
    logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")


def lemmatize_text(text: str) -> list[str]:
    lemmatizer = WordNetLemmatizer()
    word_tokens = word_tokenize(text)
    return [lemmatizer.lemmatize(word) for word in word_tokens]


def remove_stop_words_and_punctuation(text: str) -> list[str]:
    stop_words = set(stopwords.words("english"))
    word_tokens = word_tokenize(text)
    text_trimmed = [
        word
        for word in word_tokens
        if (word.casefold() not in stop_words and word not in string.punctuation)
    ]
    return text_trimmed or word_tokens


def query_processing(
    query: str,
) -> str:
    query = " ".join(remove_stop_words_and_punctuation(query))
    query = " ".join(lemmatize_text(query))
    return query
def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
    search_docs = (
        [
            SearchDoc(
                document_id=chunk.document_id,
                chunk_ind=chunk.chunk_id,
                semantic_identifier=chunk.semantic_identifier or "Unknown",
                link=chunk.source_links.get(0) if chunk.source_links else None,
                blurb=chunk.blurb,
                source_type=chunk.source_type,
                boost=chunk.boost,
                hidden=chunk.hidden,
                metadata=chunk.metadata,
                score=chunk.score,
                match_highlights=chunk.match_highlights,
                updated_at=chunk.updated_at,
                primary_owners=chunk.primary_owners,
                secondary_owners=chunk.secondary_owners,
            )
            for chunk in chunks
        ]
        if chunks
        else []
    )
    return search_docs


def combine_retrieval_results(
    chunk_sets: list[list[InferenceChunk]],
) -> list[InferenceChunk]:
    all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set]

    unique_chunks: dict[tuple[str, int], InferenceChunk] = {}
    for chunk in all_chunks:
        key = (chunk.document_id, chunk.chunk_id)
        if key not in unique_chunks:
            unique_chunks[key] = chunk
            continue

        stored_chunk_score = unique_chunks[key].score or 0
        this_chunk_score = chunk.score or 0
        if stored_chunk_score < this_chunk_score:
            unique_chunks[key] = chunk

    sorted_chunks = sorted(
        unique_chunks.values(), key=lambda x: x.score or 0, reverse=True
    )

    return sorted_chunks


@log_function_time(print_only=True)
def doc_index_retrieval(
    query: SearchQuery,
    document_index: DocumentIndex,
    db_session: Session,
    hybrid_alpha: float = HYBRID_ALPHA,
) -> list[InferenceChunk]:
    if query.search_type == SearchType.KEYWORD:
        top_chunks = document_index.keyword_retrieval(
            query=query.query,
            filters=query.filters,
            time_decay_multiplier=query.recency_bias_multiplier,
            num_to_retrieve=query.num_hits,
        )
    else:
        db_embedding_model = get_current_db_embedding_model(db_session)

        model = EmbeddingModel(
            model_name=db_embedding_model.model_name,
            query_prefix=db_embedding_model.query_prefix,
            passage_prefix=db_embedding_model.passage_prefix,
            normalize=db_embedding_model.normalize,
            # The below are globally set, this flow always uses the indexing one
            server_host=MODEL_SERVER_HOST,
            server_port=MODEL_SERVER_PORT,
        )

        query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]

        if query.search_type == SearchType.SEMANTIC:
            top_chunks = document_index.semantic_retrieval(
                query=query.query,
                query_embedding=query_embedding,
                filters=query.filters,
                time_decay_multiplier=query.recency_bias_multiplier,
                num_to_retrieve=query.num_hits,
            )

        elif query.search_type == SearchType.HYBRID:
            top_chunks = document_index.hybrid_retrieval(
                query=query.query,
                query_embedding=query_embedding,
                filters=query.filters,
                time_decay_multiplier=query.recency_bias_multiplier,
                num_to_retrieve=query.num_hits,
                offset=query.offset,
                hybrid_alpha=hybrid_alpha,
            )

        else:
            raise RuntimeError("Invalid Search Flow")

    return top_chunks
@log_function_time(print_only=True)
def semantic_reranking(
    query: str,
    chunks: list[InferenceChunk],
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
    model_min: int = CROSS_ENCODER_RANGE_MIN,
    model_max: int = CROSS_ENCODER_RANGE_MAX,
) -> tuple[list[InferenceChunk], list[int]]:
    """Reranks chunks based on cross-encoder models. Additionally provides the original indices
    of the chunks in their new sorted order.

    Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
    """
    cross_encoders = CrossEncoderEnsembleModel()
    passages = [chunk.content for chunk in chunks]
    sim_scores_floats = cross_encoders.predict(query=query, passages=passages)

    sim_scores = [numpy.array(scores) for scores in sim_scores_floats]

    raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores))

    cross_models_min = numpy.min(sim_scores)

    shifted_sim_scores = sum(
        [enc_n_scores - cross_models_min for enc_n_scores in sim_scores]
    ) / len(sim_scores)

    boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
    recency_multiplier = [chunk.recency_bias for chunk in chunks]
    boosted_sim_scores = shifted_sim_scores * boosts * recency_multiplier
    normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / (
        model_max - model_min
    )
    orig_indices = [i for i in range(len(normalized_b_s_scores))]
    scored_results = list(
        zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices)
    )
    scored_results.sort(key=lambda x: x[0], reverse=True)
    ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip(
        *scored_results
    )

    logger.debug(
        f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}"
    )

    # Assign new chunk scores based on reranking
    for ind, chunk in enumerate(ranked_chunks):
        chunk.score = ranked_sim_scores[ind]

    if rerank_metrics_callback is not None:
        chunk_metrics = [
            ChunkMetric(
                document_id=chunk.document_id,
                chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
                first_link=chunk.source_links[0] if chunk.source_links else None,
                score=chunk.score if chunk.score is not None else 0,
            )
            for chunk in ranked_chunks
        ]

        rerank_metrics_callback(
            RerankMetricsContainer(
                metrics=chunk_metrics, raw_similarity_scores=ranked_raw_scores  # type: ignore
            )
        )

    return list(ranked_chunks), list(ranked_indices)


def apply_boost_legacy(
    chunks: list[InferenceChunk],
    norm_min: float = SIM_SCORE_RANGE_LOW,
    norm_max: float = SIM_SCORE_RANGE_HIGH,
) -> list[InferenceChunk]:
    scores = [chunk.score or 0 for chunk in chunks]
    boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]

    logger.debug(f"Raw similarity scores: {scores}")

    score_min = min(scores)
    score_max = max(scores)
    score_range = score_max - score_min

    if score_range != 0:
        boosted_scores = [
            ((score - score_min) / score_range) * boost
            for score, boost in zip(scores, boosts)
        ]
        unnormed_boosted_scores = [
            score * score_range + score_min for score in boosted_scores
        ]
    else:
        unnormed_boosted_scores = [
            score * boost for score, boost in zip(scores, boosts)
        ]

    norm_min = min(norm_min, min(scores))
    norm_max = max(norm_max, max(scores))
    # This should never be 0 unless user has done some weird/wrong settings
    norm_range = norm_max - norm_min

    # For score display purposes
    if norm_range != 0:
        re_normed_scores = [
            ((score - norm_min) / norm_range) for score in unnormed_boosted_scores
        ]
    else:
        re_normed_scores = unnormed_boosted_scores

    rescored_chunks = list(zip(re_normed_scores, chunks))
    rescored_chunks.sort(key=lambda x: x[0], reverse=True)
    sorted_boosted_scores, boost_sorted_chunks = zip(*rescored_chunks)

    final_chunks = list(boost_sorted_chunks)
    final_scores = list(sorted_boosted_scores)
    for ind, chunk in enumerate(final_chunks):
        chunk.score = final_scores[ind]

    logger.debug(f"Boost sorted similary scores: {list(final_scores)}")

    return final_chunks


def apply_boost(
    chunks: list[InferenceChunk],
    # Need the range of values to not be too spread out for applying boost
    # therefore norm across only the top few results
    norm_cutoff: int = NUM_RERANKED_RESULTS,
    norm_min: float = SIM_SCORE_RANGE_LOW,
    norm_max: float = SIM_SCORE_RANGE_HIGH,
) -> list[InferenceChunk]:
    scores = [chunk.score or 0.0 for chunk in chunks]
    logger.debug(f"Raw similarity scores: {scores}")

    boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
    recency_multiplier = [chunk.recency_bias for chunk in chunks]

    norm_min = min(norm_min, min(scores[:norm_cutoff]))
    norm_max = max(norm_max, max(scores[:norm_cutoff]))
    # This should never be 0 unless user has done some weird/wrong settings
    norm_range = norm_max - norm_min

    boosted_scores = [
        max(0, (score - norm_min) * boost * recency / norm_range)
        for score, boost, recency in zip(scores, boosts, recency_multiplier)
    ]

    rescored_chunks = list(zip(boosted_scores, chunks))
    rescored_chunks.sort(key=lambda x: x[0], reverse=True)
    sorted_boosted_scores, boost_sorted_chunks = zip(*rescored_chunks)

    final_chunks = list(boost_sorted_chunks)
    final_scores = list(sorted_boosted_scores)
    for ind, chunk in enumerate(final_chunks):
        chunk.score = final_scores[ind]

    logger.debug(
        f"Boosted + Time Weighted sorted similarity scores: {list(final_scores)}"
    )

    return final_chunks
def _simplify_text(text: str) -> str:
    return "".join(
        char for char in text if char not in string.punctuation and not char.isspace()
    ).lower()


def retrieve_chunks(
    query: SearchQuery,
    document_index: DocumentIndex,
    db_session: Session,
    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
    multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
) -> list[InferenceChunk]:
    """Returns a list of the best chunks from an initial keyword/semantic/hybrid search."""
    # Don't do query expansion on complex queries, rephrasings likely would not work well
    if not multilingual_expansion_str or "\n" in query.query or "\r" in query.query:
        top_chunks = doc_index_retrieval(
            query=query,
            document_index=document_index,
            db_session=db_session,
            hybrid_alpha=hybrid_alpha,
        )
    else:
        simplified_queries = set()
        run_queries: list[tuple[Callable, tuple]] = []

        # Currently only uses query expansion on multilingual use cases
        query_rephrases = multilingual_query_expansion(
            query.query, multilingual_expansion_str
        )
        # Just to be extra sure, add the original query.
        query_rephrases.append(query.query)
        for rephrase in set(query_rephrases):
            # Sometimes the model rephrases the query in the same language with minor changes
            # Avoid doing an extra search with the minor changes as this biases the results
            simplified_rephrase = _simplify_text(rephrase)
            if simplified_rephrase in simplified_queries:
                continue
            simplified_queries.add(simplified_rephrase)

            q_copy = query.copy(update={"query": rephrase}, deep=True)
            run_queries.append(
                (
                    doc_index_retrieval,
                    (q_copy, document_index, db_session, hybrid_alpha),
                )
            )
        parallel_search_results = run_functions_tuples_in_parallel(run_queries)
        top_chunks = combine_retrieval_results(parallel_search_results)

    if not top_chunks:
        logger.info(
            f"{query.search_type.value.capitalize()} search returned no results "
            f"with filters: {query.filters}"
        )
        return []

    if retrieval_metrics_callback is not None:
        chunk_metrics = [
            ChunkMetric(
                document_id=chunk.document_id,
                chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
                first_link=chunk.source_links[0] if chunk.source_links else None,
                score=chunk.score if chunk.score is not None else 0,
            )
            for chunk in top_chunks
        ]
        retrieval_metrics_callback(
            RetrievalMetricsContainer(
                search_type=query.search_type, metrics=chunk_metrics
            )
        )

    return top_chunks
def should_rerank(query: SearchQuery) -> bool:
    # Don't re-rank for keyword search
    return query.search_type != SearchType.KEYWORD and not query.skip_rerank


def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool:
    return not query.skip_llm_chunk_filter


def rerank_chunks(
    query: SearchQuery,
    chunks_to_rerank: list[InferenceChunk],
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> list[InferenceChunk]:
    ranked_chunks, _ = semantic_reranking(
        query=query.query,
        chunks=chunks_to_rerank[: query.num_rerank],
        rerank_metrics_callback=rerank_metrics_callback,
    )
    lower_chunks = chunks_to_rerank[query.num_rerank :]
    # Scores from rerank cannot be meaningfully combined with scores without rerank
    for lower_chunk in lower_chunks:
        lower_chunk.score = None
    ranked_chunks.extend(lower_chunks)
    return ranked_chunks


@log_function_time(print_only=True)
def filter_chunks(
    query: SearchQuery,
    chunks_to_filter: list[InferenceChunk],
) -> list[str]:
    """Filters chunks based on whether the LLM thought they were relevant to the query.

    Returns a list of the unique chunk IDs that were marked as relevant"""
    chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks]
    llm_chunk_selection = llm_batch_eval_chunks(
        query=query.query,
        chunk_contents=[chunk.content for chunk in chunks_to_filter],
    )
    return [
        chunk.unique_id
        for ind, chunk in enumerate(chunks_to_filter)
        if llm_chunk_selection[ind]
    ]


def full_chunk_search(
    query: SearchQuery,
    document_index: DocumentIndex,
    db_session: Session,
    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
    multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk], list[bool]]:
    """A utility which provides an easier interface than `full_chunk_search_generator`.
    Rather than returning the chunks and llm relevance filter results in two separate
    yields, just returns them both at once."""
    search_generator = full_chunk_search_generator(
        search_query=query,
        document_index=document_index,
        db_session=db_session,
        hybrid_alpha=hybrid_alpha,
        multilingual_expansion_str=multilingual_expansion_str,
        retrieval_metrics_callback=retrieval_metrics_callback,
        rerank_metrics_callback=rerank_metrics_callback,
    )
    top_chunks = cast(list[InferenceChunk], next(search_generator))
    llm_chunk_selection = cast(list[bool], next(search_generator))
    return top_chunks, llm_chunk_selection


def empty_search_generator() -> Iterator[list[InferenceChunk] | list[bool]]:
    yield cast(list[InferenceChunk], [])
    yield cast(list[bool], [])


def full_chunk_search_generator(
    search_query: SearchQuery,
    document_index: DocumentIndex,
    db_session: Session,
    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
    multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> Iterator[list[InferenceChunk] | list[bool]]:
    """Always yields twice. Once with the selected chunks and once with the LLM relevance filter result.
    If LLM filter results are turned off, returns a list of False
    """
    chunks_yielded = False

    retrieved_chunks = retrieve_chunks(
        query=search_query,
        document_index=document_index,
        db_session=db_session,
        hybrid_alpha=hybrid_alpha,
        multilingual_expansion_str=multilingual_expansion_str,
        retrieval_metrics_callback=retrieval_metrics_callback,
    )

    if not retrieved_chunks:
        yield cast(list[InferenceChunk], [])
        yield cast(list[bool], [])
        return

    post_processing_tasks: list[FunctionCall] = []

    rerank_task_id = None
    if should_rerank(search_query):
        post_processing_tasks.append(
            FunctionCall(
                rerank_chunks,
                (
                    search_query,
                    retrieved_chunks,
                    rerank_metrics_callback,
                ),
            )
        )
        rerank_task_id = post_processing_tasks[-1].result_id
    else:
        final_chunks = retrieved_chunks
        # NOTE: if we don't rerank, we can return the chunks immediately
        # since we know this is the final order
        _log_top_chunk_links(search_query.search_type.value, final_chunks)
        yield final_chunks
        chunks_yielded = True

    llm_filter_task_id = None
    if should_apply_llm_based_relevance_filter(search_query):
        post_processing_tasks.append(
            FunctionCall(
                filter_chunks,
                (search_query, retrieved_chunks[: search_query.max_llm_filter_chunks]),
            )
        )
        llm_filter_task_id = post_processing_tasks[-1].result_id

    post_processing_results = (
        run_functions_in_parallel(post_processing_tasks)
        if post_processing_tasks
        else {}
    )
    reranked_chunks = cast(
        list[InferenceChunk] | None,
        post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None,
    )
    if reranked_chunks:
        if chunks_yielded:
            logger.error(
                "Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen."
            )
        else:
            _log_top_chunk_links(search_query.search_type.value, reranked_chunks)
            yield reranked_chunks

    llm_chunk_selection = cast(
        list[str] | None,
        post_processing_results.get(str(llm_filter_task_id))
        if llm_filter_task_id
        else None,
    )
    if llm_chunk_selection is not None:
        yield [
            chunk.unique_id in llm_chunk_selection
            for chunk in reranked_chunks or retrieved_chunks
        ]
    else:
        yield [False for _ in reranked_chunks or retrieved_chunks]
def combine_inference_chunks(inf_chunks: list[InferenceChunk]) -> LlmDoc:
    if not inf_chunks:
        raise ValueError("Cannot combine empty list of chunks")

    # Use the first link of the document
    first_chunk = inf_chunks[0]
    chunk_texts = [chunk.content for chunk in inf_chunks]
    return LlmDoc(
        document_id=first_chunk.document_id,
        content="\n".join(chunk_texts),
        semantic_identifier=first_chunk.semantic_identifier,
        source_type=first_chunk.source_type,
        metadata=first_chunk.metadata,
        updated_at=first_chunk.updated_at,
        link=first_chunk.source_links[0] if first_chunk.source_links else None,
    )


def inference_documents_from_ids(
    doc_identifiers: list[tuple[str, int]],
    document_index: DocumentIndex,
) -> list[LlmDoc]:
    # Currently only fetches whole docs
    doc_ids_set = set(doc_id for doc_id, chunk_id in doc_identifiers)

    # No need for ACL here because the doc ids were validated beforehand
    filters = IndexFilters(access_control_list=None)

    functions_with_args: list[tuple[Callable, tuple]] = [
        (document_index.id_based_retrieval, (doc_id, None, filters))
        for doc_id in doc_ids_set
    ]

    parallel_results = run_functions_tuples_in_parallel(
        functions_with_args, allow_failures=True
    )

    # Any failures to retrieve would give a None, drop the Nones and empty lists
    inference_chunks_sets = [res for res in parallel_results if res]

    return [combine_inference_chunks(chunk_set) for chunk_set in inference_chunks_sets]
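For reference, the legacy apply_boost rescoring above reduces to a single expression per chunk: max(0, (score - norm_min) * boost * recency / norm_range). A small worked sketch with made-up numbers (a norm range taken from the top results, a user boost of 1.2, and a recency multiplier of 1.1):

# score=0.62, norm_min=0.5, norm_max=1.0 -> norm_range=0.5
# boosted = max(0, (0.62 - 0.5) * 1.2 * 1.1 / 0.5) = 0.3168
boosted = max(0, (0.62 - 0.5) * 1.2 * 1.1 / 0.5)
# a chunk scoring below norm_min is clamped to 0 rather than going negative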
backend/danswer/search/utils.py (new file, 29 lines)
@@ -0,0 +1,29 @@
from danswer.indexing.models import InferenceChunk
from danswer.search.models import SearchDoc


def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
    search_docs = (
        [
            SearchDoc(
                document_id=chunk.document_id,
                chunk_ind=chunk.chunk_id,
                semantic_identifier=chunk.semantic_identifier or "Unknown",
                link=chunk.source_links.get(0) if chunk.source_links else None,
                blurb=chunk.blurb,
                source_type=chunk.source_type,
                boost=chunk.boost,
                hidden=chunk.hidden,
                metadata=chunk.metadata,
                score=chunk.score,
                match_highlights=chunk.match_highlights,
                updated_at=chunk.updated_at,
                primary_owners=chunk.primary_owners,
                secondary_owners=chunk.secondary_owners,
            )
            for chunk in chunks
        ]
        if chunks
        else []
    )
    return search_docs
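A brief usage sketch for the relocated helper: callers that previously imported chunks_to_search_docs from danswer.search.search_runner now pull it from danswer.search.utils (paths as shown in this diff).

from danswer.search.utils import chunks_to_search_docs

# top_chunks: list[InferenceChunk] produced by the search pipeline
search_docs = chunks_to_search_docs(top_chunks)
# passing None or an empty list safely yields []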
@@ -11,8 +11,8 @@ from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
 from danswer.llm.utils import get_default_llm_token_encode
 from danswer.prompts.prompt_utils import build_doc_context_str
-from danswer.search.access_filters import build_access_filters_for_user
 from danswer.search.models import IndexFilters
+from danswer.search.preprocessing.access_filters import build_access_filters_for_user
 from danswer.server.documents.models import ChunkInfo
 from danswer.server.documents.models import DocumentInfo
@@ -4,7 +4,7 @@ from pydantic import BaseModel

 from danswer.db.models import Persona
 from danswer.db.models import StarterMessage
-from danswer.search.models import RecencyBiasSetting
+from danswer.search.enums import RecencyBiasSetting
 from danswer.server.features.document_set.models import DocumentSet
 from danswer.server.features.prompt.models import PromptSnapshot
@@ -6,13 +6,9 @@ from fastapi import Depends
 from pydantic import BaseModel
 from sqlalchemy.orm import Session

-from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.engine import get_session
-from danswer.document_index.factory import get_default_document_index
-from danswer.search.access_filters import build_access_filters_for_user
-from danswer.search.models import IndexFilters
-from danswer.search.models import SearchQuery
-from danswer.search.search_runner import full_chunk_search
+from danswer.search.models import SearchRequest
+from danswer.search.pipeline import SearchPipeline
 from danswer.server.danswer_api.ingestion import api_key_dep
 from danswer.utils.logger import setup_logger
@@ -70,27 +66,13 @@ def gpt_search(
     _: str | None = Depends(api_key_dep),
     db_session: Session = Depends(get_session),
 ) -> GptSearchResponse:
-    query = search_request.query
-
-    user_acl_filters = build_access_filters_for_user(None, db_session)
-    final_filters = IndexFilters(access_control_list=user_acl_filters)
-
-    search_query = SearchQuery(
-        query=query,
-        filters=final_filters,
-        recency_bias_multiplier=1.0,
-        skip_llm_chunk_filter=True,
-    )
-
-    embedding_model = get_current_db_embedding_model(db_session)
-
-    document_index = get_default_document_index(
-        primary_index_name=embedding_model.index_name, secondary_index_name=None
-    )
-
-    top_chunks, __ = full_chunk_search(
-        query=search_query, document_index=document_index, db_session=db_session
-    )
-
+    top_chunks = SearchPipeline(
+        search_request=SearchRequest(
+            query=search_request.query,
+        ),
+        user=None,
+        db_session=db_session,
+    ).reranked_docs
+
     return GptSearchResponse(
         matching_document_chunks=[
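Taken together, the call sites in this diff converge on one pattern; a sketch limited to the arguments and attributes the diff itself shows (the query string is illustrative and db_session assumes an open SQLAlchemy session):

pipeline = SearchPipeline(
    search_request=SearchRequest(query="how do I reset my password?"),
    user=None,
    db_session=db_session,
)
top_chunks = pipeline.reranked_docs
relevance_flags = pipeline.chunk_relevance_list  # used in the eval-script hunk further below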
@@ -15,11 +15,11 @@ from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.vespa.index import VespaIndex
 from danswer.one_shot_answer.answer_question import stream_search_answer
 from danswer.one_shot_answer.models import DirectQARequest
-from danswer.search.access_filters import build_access_filters_for_user
-from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.models import IndexFilters
 from danswer.search.models import SearchDoc
-from danswer.search.search_runner import chunks_to_search_docs
+from danswer.search.preprocessing.access_filters import build_access_filters_for_user
+from danswer.search.preprocessing.danswer_helper import recommend_search_flow
+from danswer.search.utils import chunks_to_search_docs
 from danswer.secondary_llm_flows.query_validation import get_query_answerability
 from danswer.secondary_llm_flows.query_validation import stream_query_answerability
 from danswer.server.query_and_chat.models import AdminSearchRequest
@@ -8,15 +8,12 @@ from typing import TextIO
 from sqlalchemy.orm import Session

 from danswer.chat.chat_utils import get_chunks_for_qa
-from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.engine import get_sqlalchemy_engine
-from danswer.document_index.factory import get_default_document_index
 from danswer.indexing.models import InferenceChunk
-from danswer.search.models import IndexFilters
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
-from danswer.search.models import SearchQuery
-from danswer.search.search_runner import full_chunk_search
+from danswer.search.models import SearchRequest
+from danswer.search.pipeline import SearchPipeline
 from danswer.utils.callbacks import MetricsHander
@@ -81,35 +78,22 @@ def get_search_results(
     RetrievalMetricsContainer | None,
     RerankMetricsContainer | None,
 ]:
-    filters = IndexFilters(
-        source_type=None,
-        document_set=None,
-        time_cutoff=None,
-        access_control_list=None,
-    )
-    search_query = SearchQuery(
-        query=query,
-        filters=filters,
-        recency_bias_multiplier=1.0,
-    )
-
     retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
     rerank_metrics = MetricsHander[RerankMetricsContainer]()

     with Session(get_sqlalchemy_engine()) as db_session:
-        embedding_model = get_current_db_embedding_model(db_session)
-
-        document_index = get_default_document_index(
-            primary_index_name=embedding_model.index_name, secondary_index_name=None
-        )
-
-        top_chunks, llm_chunk_selection = full_chunk_search(
-            query=search_query,
-            document_index=document_index,
-            db_session=db_session,
-            retrieval_metrics_callback=retrieval_metrics.record_metric,
-            rerank_metrics_callback=rerank_metrics.record_metric,
-        )
-
+        search_pipeline = SearchPipeline(
+            search_request=SearchRequest(
+                query=query,
+            ),
+            user=None,
+            db_session=db_session,
+            retrieval_metrics_callback=retrieval_metrics.record_metric,
+            rerank_metrics_callback=rerank_metrics.record_metric,
+        )
+
+        top_chunks = search_pipeline.reranked_docs
+        llm_chunk_selection = search_pipeline.chunk_relevance_list
+
         llm_chunks_indices = get_chunks_for_qa(
             chunks=top_chunks,