diff --git a/backend/danswer/chat/chat_llm.py b/backend/danswer/chat/chat_llm.py
index f3b4d7cd65..e067b73da0 100644
--- a/backend/danswer/chat/chat_llm.py
+++ b/backend/danswer/chat/chat_llm.py
@@ -42,7 +42,7 @@ from danswer.search.models import IndexFilters
 from danswer.search.models import SearchQuery
 from danswer.search.models import SearchType
 from danswer.search.search_runner import chunks_to_search_docs
-from danswer.search.search_runner import search_chunks
+from danswer.search.search_runner import full_chunk_search
 from danswer.server.models import RetrievalDocs
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import extract_embedded_json
@@ -140,8 +140,9 @@ def danswer_chat_retrieval(
     )

     # Good Debug/Breakpoint
-    top_chunks, _ = search_chunks(
-        query=search_query, document_index=get_default_document_index()
+    top_chunks, _ = full_chunk_search(
+        query=search_query,
+        document_index=get_default_document_index(),
     )

     if not top_chunks:
diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py
index b8c02d0e15..c94d3225ce 100644
--- a/backend/danswer/direct_qa/answer_question.py
+++ b/backend/danswer/direct_qa/answer_question.py
@@ -1,6 +1,7 @@
 from collections.abc import Callable
 from collections.abc import Iterator
 from functools import partial
+from typing import cast

 from sqlalchemy.orm import Session

@@ -15,15 +16,18 @@ from danswer.direct_qa.interfaces import StreamingError
 from danswer.direct_qa.models import LLMMetricsContainer
 from danswer.direct_qa.qa_utils import get_chunks_for_qa
 from danswer.document_index.factory import get_default_document_index
+from danswer.indexing.models import InferenceChunk
 from danswer.search.danswer_helper import query_intent
 from danswer.search.models import QueryFlow
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
 from danswer.search.search_runner import chunks_to_search_docs
 from danswer.search.search_runner import danswer_search
+from danswer.search.search_runner import danswer_search_generator
 from danswer.secondary_llm_flows.answer_validation import get_answer_validity
 from danswer.secondary_llm_flows.source_filter import extract_question_source_filters
 from danswer.secondary_llm_flows.time_filter import extract_question_time_filters
+from danswer.server.models import LLMRelevanceFilterResponse
 from danswer.server.models import QADocsResponse
 from danswer.server.models import QAResponse
 from danswer.server.models import QuestionRequest
@@ -206,24 +210,19 @@ def answer_qa_query_stream(
     question.favor_recent = favor_recent
     question.filters.source_type = source_filters

-    top_chunks, llm_chunk_selection, query_event_id = danswer_search(
+    search_generator = danswer_search_generator(
         question=question,
         user=user,
         db_session=db_session,
         document_index=get_default_document_index(),
     )

+    # first fetch and return to the UI the top chunks so the user can
+    # immediately see some results
+    top_chunks = cast(list[InferenceChunk], next(search_generator))
     top_docs = chunks_to_search_docs(top_chunks)
-
-    llm_chunks_indices = get_chunks_for_qa(
-        chunks=top_chunks,
-        llm_chunk_selection=llm_chunk_selection,
-        batch_offset=offset_count,
-    )
-
     initial_response = QADocsResponse(
         top_documents=top_docs,
-        llm_chunks_indices=llm_chunks_indices,
         # if generative AI is disabled, set flow as search so frontend
         # doesn't ask the user if they want to run QA over more documents
         predicted_flow=QueryFlow.SEARCH
@@ -233,13 +232,29 @@ def answer_qa_query_stream(
         time_cutoff=time_cutoff,
         favor_recent=favor_recent,
     ).dict()
-
     yield get_json_line(initial_response)

     if not top_chunks:
         logger.debug("No Documents Found")
         return

+    # next apply the LLM filtering
+    llm_chunk_selection = cast(list[bool], next(search_generator))
+    llm_chunks_indices = get_chunks_for_qa(
+        chunks=top_chunks,
+        llm_chunk_selection=llm_chunk_selection,
+        batch_offset=offset_count,
+    )
+    llm_relevance_filtering_response = LLMRelevanceFilterResponse(
+        relevant_chunk_indices=llm_chunks_indices
+    ).dict()
+    yield get_json_line(llm_relevance_filtering_response)
+
+    # finally get the query ID from the search generator for updating the
+    # row in Postgres. This is the end of the `search_generator` - any future
+    # calls to `next` will raise StopIteration
+    query_event_id = cast(int, next(search_generator))
+
     if disable_generative_answer:
         logger.debug("Skipping QA because generative AI is disabled")
         return
diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py
index 934646b027..8c9bc0e60c 100644
--- a/backend/danswer/indexing/models.py
+++ b/backend/danswer/indexing/models.py
@@ -94,6 +94,10 @@ class InferenceChunk(BaseChunk):
     # when the doc was last updated
     updated_at: datetime | None

+    @property
+    def unique_id(self) -> str:
+        return f"{self.document_id}__{self.chunk_id}"
+
     def __repr__(self) -> str:
         blurb_words = self.blurb.split()
         short_blurb = ""
diff --git a/backend/danswer/search/search_runner.py b/backend/danswer/search/search_runner.py
index 92e640c9d4..2a6097b544 100644
--- a/backend/danswer/search/search_runner.py
+++ b/backend/danswer/search/search_runner.py
@@ -1,6 +1,6 @@
 from collections.abc import Callable
+from collections.abc import Generator
 from copy import deepcopy
-from typing import Any
 from typing import cast

 import numpy
@@ -50,6 +50,13 @@ from danswer.utils.timing import log_function_time
 logger = setup_logger()


+def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
+    top_links = [
+        c.source_links[0] if c.source_links is not None else "No Link" for c in chunks
+    ]
+    logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
+
+
 def lemmatize_text(text: str) -> list[str]:
     lemmatizer = WordNetLemmatizer()
     word_tokens = word_tokenize(text)
@@ -329,29 +336,15 @@ def apply_boost(
     return final_chunks


-def search_chunks(
+def retrieve_chunks(
     query: SearchQuery,
     document_index: DocumentIndex,
     hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
     multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION,
     retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
     | None = None,
-    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
-) -> tuple[list[InferenceChunk], list[bool]]:
-    """Returns a list of the best chunks from search/reranking and if the chunks are relevant via LLM.
-    For sake of speed, the system cannot rerank all retrieved chunks
-    Also pass the chunks through LLM to determine if they are relevant (binary for speed)
-
-    Only the first max_llm_filter_chunks
-    """
-
-    def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
-        top_links = [
-            c.source_links[0] if c.source_links is not None else "No Link"
-            for c in chunks
-        ]
-        logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
-
+) -> list[InferenceChunk]:
+    """Returns a list of the best chunks from an initial keyword/semantic/hybrid search."""
     # Don't do query expansion on complex queries, rephrasings likely would not work well
     if not multilingual_query_expansion or "\n" in query.query or "\r" in query.query:
         top_chunks = doc_index_retrieval(
@@ -377,7 +370,7 @@
             f"{query.search_type.value.capitalize()} search returned no results "
             f"with filters: {query.filters}"
         )
-        return [], []
+        return []

     if retrieval_metrics_callback is not None:
         chunk_metrics = [
@@ -393,67 +386,160 @@
             RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
         )

-    functions_to_run: list[FunctionCall] = []
+    return top_chunks

-    # Keyword Search should not do reranking
-    if query.search_type == SearchType.KEYWORD or query.skip_rerank:
-        _log_top_chunk_links(query.search_type.value, top_chunks)
-        run_rerank_id: str | None = None
-    else:
-        run_rerank = FunctionCall(
-            semantic_reranking,
-            (query.query, top_chunks[: query.num_rerank]),
-            {"rerank_metrics_callback": rerank_metrics_callback},
+
+def should_rerank(query: SearchQuery) -> bool:
+    # don't re-rank for keyword search
+    return query.search_type != SearchType.KEYWORD and not query.skip_rerank
+
+
+def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool:
+    return not query.skip_llm_chunk_filter
+
+
+def rerank_chunks(
+    query: SearchQuery,
+    chunks_to_rerank: list[InferenceChunk],
+    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
+) -> list[InferenceChunk]:
+    ranked_chunks, _ = semantic_reranking(
+        query=query.query,
+        chunks=chunks_to_rerank[: query.num_rerank],
+        rerank_metrics_callback=rerank_metrics_callback,
+    )
+    lower_chunks = chunks_to_rerank[query.num_rerank :]
+    # Scores from rerank cannot be meaningfully combined with scores without rerank
+    for lower_chunk in lower_chunks:
+        lower_chunk.score = None
+    ranked_chunks.extend(lower_chunks)
+    return ranked_chunks
+
+
+def filter_chunks(
+    query: SearchQuery,
+    chunks_to_filter: list[InferenceChunk],
+) -> list[str]:
+    """Filters chunks based on whether the LLM thought they were relevant to the query.
+
+    Returns a list of the unique chunk IDs that were marked as relevant"""
+    chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks]
+    llm_chunk_selection = llm_batch_eval_chunks(
+        query=query.query,
+        chunk_contents=[chunk.content for chunk in chunks_to_filter],
+    )
+    return [
+        chunk.unique_id
+        for ind, chunk in enumerate(chunks_to_filter)
+        if llm_chunk_selection[ind]
+    ]
+
+
+def full_chunk_search(
+    query: SearchQuery,
+    document_index: DocumentIndex,
+    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
+    multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION,
+    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
+    | None = None,
+    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
+) -> tuple[list[InferenceChunk], list[bool]]:
+    """A utility which provides an easier interface than `full_chunk_search_generator`.
+    Rather than returning the chunks and llm relevance filter results in two separate
+    yields, just returns them both at once."""
+    search_generator = full_chunk_search_generator(
+        query=query,
+        document_index=document_index,
+        hybrid_alpha=hybrid_alpha,
+        multilingual_query_expansion=multilingual_query_expansion,
+        retrieval_metrics_callback=retrieval_metrics_callback,
+        rerank_metrics_callback=rerank_metrics_callback,
+    )
+    top_chunks = cast(list[InferenceChunk], next(search_generator))
+    llm_chunk_selection = cast(list[bool], next(search_generator))
+    return top_chunks, llm_chunk_selection
+
+
+def full_chunk_search_generator(
+    query: SearchQuery,
+    document_index: DocumentIndex,
+    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
+    multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION,
+    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
+    | None = None,
+    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
+) -> Generator[list[InferenceChunk] | list[bool], None, None]:
+    """Always yields twice. Once with the selected chunks and once with the LLM relevance filter result."""
+    chunks_yielded = False
+
+    retrieved_chunks = retrieve_chunks(
+        query=query,
+        document_index=document_index,
+        hybrid_alpha=hybrid_alpha,
+        multilingual_query_expansion=multilingual_query_expansion,
+        retrieval_metrics_callback=retrieval_metrics_callback,
+    )
+
+    post_processing_tasks: list[FunctionCall] = []
+
+    rerank_task_id = None
+    if should_rerank(query):
+        post_processing_tasks.append(
+            FunctionCall(
+                rerank_chunks,
+                (
+                    query,
+                    retrieved_chunks,
+                    rerank_metrics_callback,
+                ),
+            )
         )
-        functions_to_run.append(run_rerank)
-        run_rerank_id = run_rerank.result_id
-
-    run_llm_filter_id = None
-    if not query.skip_llm_chunk_filter:
-        run_llm_filter = FunctionCall(
-            llm_batch_eval_chunks,
-            (
-                query.query,
-                [chunk.content for chunk in top_chunks[: query.max_llm_filter_chunks]],
-            ),
-            {},
-        )
-        functions_to_run.append(run_llm_filter)
-        run_llm_filter_id = run_llm_filter.result_id
-
-    parallel_results: dict[str, Any] = {}
-    if functions_to_run:
-        parallel_results = run_functions_in_parallel(functions_to_run)
-
-    ranked_results = parallel_results.get(str(run_rerank_id))
-    if ranked_results is None:
-        ranked_chunks = top_chunks
-        sorted_indices = [i for i in range(len(top_chunks))]
+        rerank_task_id = post_processing_tasks[-1].result_id
     else:
-        ranked_chunks, orig_indices = ranked_results
-        sorted_indices = orig_indices + list(range(len(orig_indices), len(top_chunks)))
-        lower_chunks = top_chunks[query.num_rerank :]
-        # Scores from rerank cannot be meaningfully combined with scores without rerank
-        for lower_chunk in lower_chunks:
-            lower_chunk.score = None
-        ranked_chunks.extend(lower_chunks)
+        final_chunks = retrieved_chunks
+        # NOTE: if we don't rerank, we can return the chunks immediately
+        # since we know this is the final order
+        _log_top_chunk_links(query.search_type.value, final_chunks)
+        yield final_chunks
+        chunks_yielded = True

-    llm_chunk_selection = parallel_results.get(str(run_llm_filter_id))
-    if llm_chunk_selection is None:
-        reranked_llm_chunk_selection = [True for _ in top_chunks]
-    else:
-        llm_chunk_selection.extend(
-            [False for _ in top_chunks[query.max_llm_filter_chunks :]]
+    llm_filter_task_id = None
+    if should_apply_llm_based_relevance_filter(query):
+        post_processing_tasks.append(
+            FunctionCall(
+                filter_chunks,
+                (query, retrieved_chunks[: query.max_llm_filter_chunks]),
+            )
         )
-        reranked_llm_chunk_selection = [
-            llm_chunk_selection[ind] for ind in sorted_indices
-        ]
-
-    _log_top_chunk_links(query.search_type.value, ranked_chunks)
+        llm_filter_task_id = post_processing_tasks[-1].result_id

-    return ranked_chunks, reranked_llm_chunk_selection
+    post_processing_results = run_functions_in_parallel(post_processing_tasks)
+    reranked_chunks = cast(
+        list[InferenceChunk] | None,
+        post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None,
+    )
+    if reranked_chunks:
+        if chunks_yielded:
+            logger.error(
+                "Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen."
+            )
+        else:
+            _log_top_chunk_links(query.search_type.value, reranked_chunks)
+            yield reranked_chunks
+
+    llm_chunk_selection = cast(
+        list[str] | None,
+        post_processing_results.get(str(llm_filter_task_id))
+        if llm_filter_task_id
+        else None,
+    )
+    if llm_chunk_selection:
+        yield [chunk.unique_id in llm_chunk_selection for chunk in retrieved_chunks]
+    else:
+        yield [True for _ in reranked_chunks or retrieved_chunks]


-def danswer_search(
+def danswer_search_generator(
     question: QuestionRequest,
     user: User | None,
     db_session: Session,
@@ -462,7 +548,11 @@
     retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
     | None = None,
     rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
-) -> tuple[list[InferenceChunk], list[bool], int]:
+) -> Generator[list[InferenceChunk] | list[bool] | int, None, None]:
+    """The main entry point for search. This fetches the relevant documents from Vespa
+    based on the provided query (applying permissions / filters), does any specified
+    post-processing, and returns the results. It also creates an entry in the query_event table
+    for this search event."""
     query_event_id = create_query_event(
         query=question.query,
         search_type=question.search_type,
@@ -490,20 +580,54 @@
         skip_llm_chunk_filter=skip_llm_chunk_filter,
     )

-    top_chunks, llm_chunk_selection = search_chunks(
+    search_generator = full_chunk_search_generator(
         query=search_query,
         document_index=document_index,
         retrieval_metrics_callback=retrieval_metrics_callback,
         rerank_metrics_callback=rerank_metrics_callback,
     )
+    top_chunks = cast(list[InferenceChunk], next(search_generator))
+    yield top_chunks

-    retrieved_ids = [doc.document_id for doc in top_chunks] if top_chunks else []
+    llm_chunk_selection = cast(list[bool], next(search_generator))
+    yield llm_chunk_selection

     update_query_event_retrieved_documents(
         db_session=db_session,
-        retrieved_document_ids=retrieved_ids,
+        retrieved_document_ids=[doc.document_id for doc in top_chunks]
+        if top_chunks
+        else [],
         query_id=query_event_id,
         user_id=None if user is None else user.id,
     )

+    yield query_event_id
+
+def danswer_search(
+    question: QuestionRequest,
+    user: User | None,
+    db_session: Session,
+    document_index: DocumentIndex,
+    skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
+    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
+    | None = None,
+    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
+) -> tuple[list[InferenceChunk], list[bool], int]:
+    """Returns a tuple of the top chunks, the LLM relevance filter results, and the query event ID.
+
+    Presents a simpler interface than the underlying `danswer_search_generator`, as callers no
+    longer need to worry about the yield order and get nicer typing. This should be used for flows
This should be used for flows which + do not require streaming.""" + search_generator = danswer_search_generator( + question=question, + user=user, + db_session=db_session, + document_index=document_index, + skip_llm_chunk_filter=skip_llm_chunk_filter, + retrieval_metrics_callback=retrieval_metrics_callback, + rerank_metrics_callback=rerank_metrics_callback, + ) + top_chunks = cast(list[InferenceChunk], next(search_generator)) + llm_chunk_selection = cast(list[bool], next(search_generator)) + query_event_id = cast(int, next(search_generator)) return top_chunks, llm_chunk_selection, query_event_id diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index 30f3e41af9..57cb7386be 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -222,7 +222,6 @@ class QAResponse(SearchResponse): # First chunk of info for streaming QA class QADocsResponse(RetrievalDocs): - llm_chunks_indices: list[int] predicted_flow: QueryFlow predicted_search: SearchType time_cutoff: datetime | None @@ -236,6 +235,11 @@ class QADocsResponse(RetrievalDocs): return initial_dict +# second chunk of info for streaming QA +class LLMRelevanceFilterResponse(BaseModel): + relevant_chunk_indices: list[int] + + class CreateChatSessionID(BaseModel): chat_session_id: int diff --git a/web/src/components/HoverPopup.tsx b/web/src/components/HoverPopup.tsx index 8263d52d1a..886cde0238 100644 --- a/web/src/components/HoverPopup.tsx +++ b/web/src/components/HoverPopup.tsx @@ -5,6 +5,7 @@ interface HoverPopupProps { popupContent: string | JSX.Element; classNameModifications?: string; direction?: "left" | "bottom"; + style?: "basic" | "dark"; } export const HoverPopup = ({ @@ -12,6 +13,7 @@ export const HoverPopup = ({ popupContent, classNameModifications, direction = "bottom", + style = "basic", }: HoverPopupProps) => { const [hovered, setHovered] = useState(false); @@ -37,7 +39,10 @@ export const HoverPopup = ({