Add selected docs in UI + rework the backend flow a bit (#754)
Changes the flow so that the selected docs are sent over in a separate packet rather than as part of the initial packet for the streaming QA endpoint.
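For orientation, here is a minimal sketch of how a client might consume the reworked stream. The endpoint path, HTTP handling, and helper names are illustrative assumptions and are not part of this change; only the packet order and the field names (`top_documents`, `relevant_chunk_indices`) come from the diff below.

import json

import requests  # assumed client-side dependency, not part of this change

# Hypothetical endpoint path for the streaming QA flow
STREAM_QA_URL = "http://localhost:8080/stream-answer"


def consume_qa_stream(request_body: dict) -> None:
    """Read the newline-delimited JSON packets emitted by the streaming QA endpoint.

    After this change the packets arrive in this order:
      1. QADocsResponse             -> top_documents, predicted_flow, predicted_search, ...
      2. LLMRelevanceFilterResponse -> relevant_chunk_indices (the LLM-selected docs)
      3. whatever answer / error packets follow
    """
    with requests.post(STREAM_QA_URL, json=request_body, stream=True) as resp:
        for line in resp.iter_lines():
            if not line:
                continue
            packet = json.loads(line)
            if "top_documents" in packet:
                print("initial documents:", packet["top_documents"])
            elif "relevant_chunk_indices" in packet:
                print("LLM-selected document indices:", packet["relevant_chunk_indices"])
            else:
                print("other packet:", packet)
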
@@ -42,7 +42,7 @@ from danswer.search.models import IndexFilters
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import search_chunks
from danswer.search.search_runner import full_chunk_search
from danswer.server.models import RetrievalDocs
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import extract_embedded_json

@@ -140,8 +140,9 @@ def danswer_chat_retrieval(
    )

    # Good Debug/Breakpoint
    top_chunks, _ = search_chunks(
        query=search_query, document_index=get_default_document_index()
    top_chunks, _ = full_chunk_search(
        query=search_query,
        document_index=get_default_document_index(),
    )

    if not top_chunks:

@@ -1,6 +1,7 @@
from collections.abc import Callable
from collections.abc import Iterator
from functools import partial
from typing import cast

from sqlalchemy.orm import Session

@@ -15,15 +16,18 @@ from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_utils import get_chunks_for_qa
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.search.danswer_helper import query_intent
from danswer.search.models import QueryFlow
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import danswer_search
from danswer.search.search_runner import danswer_search_generator
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.source_filter import extract_question_source_filters
from danswer.secondary_llm_flows.time_filter import extract_question_time_filters
from danswer.server.models import LLMRelevanceFilterResponse
from danswer.server.models import QADocsResponse
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest

@@ -206,24 +210,19 @@ def answer_qa_query_stream(
    question.favor_recent = favor_recent
    question.filters.source_type = source_filters

    top_chunks, llm_chunk_selection, query_event_id = danswer_search(
    search_generator = danswer_search_generator(
        question=question,
        user=user,
        db_session=db_session,
        document_index=get_default_document_index(),
    )

    # first fetch and return to the UI the top chunks so the user can
    # immediately see some results
    top_chunks = cast(list[InferenceChunk], next(search_generator))
    top_docs = chunks_to_search_docs(top_chunks)

    llm_chunks_indices = get_chunks_for_qa(
        chunks=top_chunks,
        llm_chunk_selection=llm_chunk_selection,
        batch_offset=offset_count,
    )

    initial_response = QADocsResponse(
        top_documents=top_docs,
        llm_chunks_indices=llm_chunks_indices,
        # if generative AI is disabled, set flow as search so frontend
        # doesn't ask the user if they want to run QA over more documents
        predicted_flow=QueryFlow.SEARCH

@@ -233,13 +232,29 @@ def answer_qa_query_stream(
        time_cutoff=time_cutoff,
        favor_recent=favor_recent,
    ).dict()

    yield get_json_line(initial_response)

    if not top_chunks:
        logger.debug("No Documents Found")
        return

    # next apply the LLM filtering
    llm_chunk_selection = cast(list[bool], next(search_generator))
    llm_chunks_indices = get_chunks_for_qa(
        chunks=top_chunks,
        llm_chunk_selection=llm_chunk_selection,
        batch_offset=offset_count,
    )
    llm_relevance_filtering_response = LLMRelevanceFilterResponse(
        relevant_chunk_indices=llm_chunks_indices
    ).dict()
    yield get_json_line(llm_relevance_filtering_response)

    # finally get the query ID from the search generator for updating the
    # row in Postgres. This is the end of the `search_generator` - any future
    # calls to `next` will raise StopIteration
    query_event_id = cast(int, next(search_generator))

    if disable_generative_answer:
        logger.debug("Skipping QA because generative AI is disabled")
        return

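The hunk above consumes `danswer_search_generator` step by step; its contract (also captured by the non-streaming `danswer_search` wrapper further down in this diff) is three yields in a fixed order. A small self-contained sketch of draining it, assuming only the imports already present in this commit (the helper name is arbitrary):

from collections.abc import Generator
from typing import cast

from danswer.indexing.models import InferenceChunk


def drain_search_generator(
    search_generator: Generator[list[InferenceChunk] | list[bool] | int, None, None],
) -> tuple[list[InferenceChunk], list[bool], int]:
    # 1st yield: the retrieved (and possibly re-ranked) chunks
    top_chunks = cast(list[InferenceChunk], next(search_generator))
    # 2nd yield: one LLM relevance flag per chunk, aligned with top_chunks
    llm_chunk_selection = cast(list[bool], next(search_generator))
    # 3rd yield: the ID of the query_event row created for this search
    query_event_id = cast(int, next(search_generator))
    # the generator is now exhausted; further next() calls raise StopIteration
    return top_chunks, llm_chunk_selection, query_event_id
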
@@ -94,6 +94,10 @@ class InferenceChunk(BaseChunk):
    # when the doc was last updated
    updated_at: datetime | None

    @property
    def unique_id(self) -> str:
        return f"{self.document_id}__{self.chunk_id}"

    def __repr__(self) -> str:
        blurb_words = self.blurb.split()
        short_blurb = ""

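The new `unique_id` property (`<document_id>__<chunk_id>`) gives each chunk a stable key, which the LLM relevance filter added later in this diff uses to report the chunks it kept. A tiny sketch of that round trip with made-up IDs:

# Made-up chunk IDs, purely for illustration
retrieved_chunk_ids = ["doc-a__0", "doc-a__1", "doc-b__0"]
# What `filter_chunks` (added below) returns: the unique IDs the LLM marked relevant
relevant_ids = ["doc-a__1"]

# `full_chunk_search_generator` (added below) turns the IDs back into a positional mask
relevance_mask = [chunk_id in relevant_ids for chunk_id in retrieved_chunk_ids]
assert relevance_mask == [False, True, False]
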
@@ -1,6 +1,6 @@
from collections.abc import Callable
from collections.abc import Generator
from copy import deepcopy
from typing import Any
from typing import cast

import numpy

@@ -50,6 +50,13 @@ from danswer.utils.timing import log_function_time
logger = setup_logger()


def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
    top_links = [
        c.source_links[0] if c.source_links is not None else "No Link" for c in chunks
    ]
    logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")


def lemmatize_text(text: str) -> list[str]:
    lemmatizer = WordNetLemmatizer()
    word_tokens = word_tokenize(text)

@@ -329,29 +336,15 @@ def apply_boost(
    return final_chunks


def search_chunks(
def retrieve_chunks(
    query: SearchQuery,
    document_index: DocumentIndex,
    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
    multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk], list[bool]]:
    """Returns a list of the best chunks from search/reranking and if the chunks are relevant via LLM.
    For sake of speed, the system cannot rerank all retrieved chunks
    Also pass the chunks through LLM to determine if they are relevant (binary for speed)

    Only the first max_llm_filter_chunks
    """

    def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
        top_links = [
            c.source_links[0] if c.source_links is not None else "No Link"
            for c in chunks
        ]
        logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")

) -> list[InferenceChunk]:
    """Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
    # Don't do query expansion on complex queries, rephrasings likely would not work well
    if not multilingual_query_expansion or "\n" in query.query or "\r" in query.query:
        top_chunks = doc_index_retrieval(

@@ -377,7 +370,7 @@ def search_chunks(
            f"{query.search_type.value.capitalize()} search returned no results "
            f"with filters: {query.filters}"
        )
        return [], []
        return []

    if retrieval_metrics_callback is not None:
        chunk_metrics = [

@@ -393,67 +386,160 @@ def search_chunks(
            RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
        )

    functions_to_run: list[FunctionCall] = []
    return top_chunks

    # Keyword Search should not do reranking
    if query.search_type == SearchType.KEYWORD or query.skip_rerank:
        _log_top_chunk_links(query.search_type.value, top_chunks)
        run_rerank_id: str | None = None
    else:
        run_rerank = FunctionCall(
            semantic_reranking,
            (query.query, top_chunks[: query.num_rerank]),
            {"rerank_metrics_callback": rerank_metrics_callback},

def should_rerank(query: SearchQuery) -> bool:
    # don't re-rank for keyword search
    return query.search_type != SearchType.KEYWORD and not query.skip_rerank


def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool:
    return not query.skip_llm_chunk_filter


def rerank_chunks(
    query: SearchQuery,
    chunks_to_rerank: list[InferenceChunk],
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> list[InferenceChunk]:
    ranked_chunks, _ = semantic_reranking(
        query=query.query,
        chunks=chunks_to_rerank[: query.num_rerank],
        rerank_metrics_callback=rerank_metrics_callback,
    )
    lower_chunks = chunks_to_rerank[query.num_rerank :]
    # Scores from rerank cannot be meaningfully combined with scores without rerank
    for lower_chunk in lower_chunks:
        lower_chunk.score = None
    ranked_chunks.extend(lower_chunks)
    return ranked_chunks


def filter_chunks(
    query: SearchQuery,
    chunks_to_filter: list[InferenceChunk],
) -> list[str]:
    """Filters chunks based on whether the LLM thought they were relevant to the query.

    Returns a list of the unique chunk IDs that were marked as relevant"""
    chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks]
    llm_chunk_selection = llm_batch_eval_chunks(
        query=query.query,
        chunk_contents=[chunk.content for chunk in chunks_to_filter],
    )
    return [
        chunk.unique_id
        for ind, chunk in enumerate(chunks_to_filter)
        if llm_chunk_selection[ind]
    ]


def full_chunk_search(
    query: SearchQuery,
    document_index: DocumentIndex,
    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
    multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk], list[bool]]:
    """A utility which provides an easier interface than `full_chunk_search_generator`.
    Rather than returning the chunks and llm relevance filter results in two separate
    yields, just returns them both at once."""
    search_generator = full_chunk_search_generator(
        query=query,
        document_index=document_index,
        hybrid_alpha=hybrid_alpha,
        multilingual_query_expansion=multilingual_query_expansion,
        retrieval_metrics_callback=retrieval_metrics_callback,
        rerank_metrics_callback=rerank_metrics_callback,
    )
    top_chunks = cast(list[InferenceChunk], next(search_generator))
    llm_chunk_selection = cast(list[bool], next(search_generator))
    return top_chunks, llm_chunk_selection


def full_chunk_search_generator(
    query: SearchQuery,
    document_index: DocumentIndex,
    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
    multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> Generator[list[InferenceChunk] | list[bool], None, None]:
    """Always yields twice. Once with the selected chunks and once with the LLM relevance filter result."""
    chunks_yielded = False

    retrieved_chunks = retrieve_chunks(
        query=query,
        document_index=document_index,
        hybrid_alpha=hybrid_alpha,
        multilingual_query_expansion=multilingual_query_expansion,
        retrieval_metrics_callback=retrieval_metrics_callback,
    )

    post_processing_tasks: list[FunctionCall] = []

    rerank_task_id = None
    if should_rerank(query):
        post_processing_tasks.append(
            FunctionCall(
                rerank_chunks,
                (
                    query,
                    retrieved_chunks,
                    rerank_metrics_callback,
                ),
            )
        )
        functions_to_run.append(run_rerank)
        run_rerank_id = run_rerank.result_id

    run_llm_filter_id = None
    if not query.skip_llm_chunk_filter:
        run_llm_filter = FunctionCall(
            llm_batch_eval_chunks,
            (
                query.query,
                [chunk.content for chunk in top_chunks[: query.max_llm_filter_chunks]],
            ),
            {},
        )
        functions_to_run.append(run_llm_filter)
        run_llm_filter_id = run_llm_filter.result_id

    parallel_results: dict[str, Any] = {}
    if functions_to_run:
        parallel_results = run_functions_in_parallel(functions_to_run)

    ranked_results = parallel_results.get(str(run_rerank_id))
    if ranked_results is None:
        ranked_chunks = top_chunks
        sorted_indices = [i for i in range(len(top_chunks))]
        rerank_task_id = post_processing_tasks[-1].result_id
    else:
        ranked_chunks, orig_indices = ranked_results
        sorted_indices = orig_indices + list(range(len(orig_indices), len(top_chunks)))
        lower_chunks = top_chunks[query.num_rerank :]
        # Scores from rerank cannot be meaningfully combined with scores without rerank
        for lower_chunk in lower_chunks:
            lower_chunk.score = None
        ranked_chunks.extend(lower_chunks)
    final_chunks = retrieved_chunks
    # NOTE: if we don't rerank, we can return the chunks immediately
    # since we know this is the final order
    _log_top_chunk_links(query.search_type.value, final_chunks)
    yield final_chunks
    chunks_yielded = True

    llm_chunk_selection = parallel_results.get(str(run_llm_filter_id))
    if llm_chunk_selection is None:
        reranked_llm_chunk_selection = [True for _ in top_chunks]
    else:
        llm_chunk_selection.extend(
            [False for _ in top_chunks[query.max_llm_filter_chunks :]]
    llm_filter_task_id = None
    if should_apply_llm_based_relevance_filter(query):
        post_processing_tasks.append(
            FunctionCall(
                filter_chunks,
                (query, retrieved_chunks[: query.max_llm_filter_chunks]),
            )
        )
        reranked_llm_chunk_selection = [
            llm_chunk_selection[ind] for ind in sorted_indices
        ]
    _log_top_chunk_links(query.search_type.value, ranked_chunks)
        llm_filter_task_id = post_processing_tasks[-1].result_id

    return ranked_chunks, reranked_llm_chunk_selection
    post_processing_results = run_functions_in_parallel(post_processing_tasks)
    reranked_chunks = cast(
        list[InferenceChunk] | None,
        post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None,
    )
    if reranked_chunks:
        if chunks_yielded:
            logger.error(
                "Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen."
            )
        else:
            _log_top_chunk_links(query.search_type.value, reranked_chunks)
            yield reranked_chunks

    llm_chunk_selection = cast(
        list[str] | None,
        post_processing_results.get(str(llm_filter_task_id))
        if llm_filter_task_id
        else None,
    )
    if llm_chunk_selection:
        yield [chunk.unique_id in llm_chunk_selection for chunk in retrieved_chunks]
    else:
        yield [True for _ in reranked_chunks or retrieved_chunks]


def danswer_search(
def danswer_search_generator(
    question: QuestionRequest,
    user: User | None,
    db_session: Session,

@@ -462,7 +548,11 @@ def danswer_search(
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk], list[bool], int]:
) -> Generator[list[InferenceChunk] | list[bool] | int, None, None]:
    """The main entry point for search. This fetches the relevant documents from Vespa
    based on the provided query (applying permissions / filters), does any specified
    post-processing, and returns the results. It also create an entry in the query_event table
    for this search event."""
    query_event_id = create_query_event(
        query=question.query,
        search_type=question.search_type,

@@ -490,20 +580,54 @@ def danswer_search(
        skip_llm_chunk_filter=skip_llm_chunk_filter,
    )

    top_chunks, llm_chunk_selection = search_chunks(
    search_generator = full_chunk_search_generator(
        query=search_query,
        document_index=document_index,
        retrieval_metrics_callback=retrieval_metrics_callback,
        rerank_metrics_callback=rerank_metrics_callback,
    )
    top_chunks = cast(list[InferenceChunk], next(search_generator))
    yield top_chunks

    retrieved_ids = [doc.document_id for doc in top_chunks] if top_chunks else []
    llm_chunk_selection = cast(list[bool], next(search_generator))
    yield llm_chunk_selection

    update_query_event_retrieved_documents(
        db_session=db_session,
        retrieved_document_ids=retrieved_ids,
        retrieved_document_ids=[doc.document_id for doc in top_chunks]
        if top_chunks
        else [],
        query_id=query_event_id,
        user_id=None if user is None else user.id,
    )
    yield query_event_id


def danswer_search(
    question: QuestionRequest,
    user: User | None,
    db_session: Session,
    document_index: DocumentIndex,
    skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk], list[bool], int]:
    """Returns a tuple of the top chunks, the LLM relevance filter results, and the query event ID.

    Presents a simpler interface than the underlying `danswer_search_generator`, as callers no
    longer need to worry about the order / have nicer typing. This should be used for flows which
    do not require streaming."""
    search_generator = danswer_search_generator(
        question=question,
        user=user,
        db_session=db_session,
        document_index=document_index,
        skip_llm_chunk_filter=skip_llm_chunk_filter,
        retrieval_metrics_callback=retrieval_metrics_callback,
        rerank_metrics_callback=rerank_metrics_callback,
    )
    top_chunks = cast(list[InferenceChunk], next(search_generator))
    llm_chunk_selection = cast(list[bool], next(search_generator))
    query_event_id = cast(int, next(search_generator))
    return top_chunks, llm_chunk_selection, query_event_id

@@ -222,7 +222,6 @@ class QAResponse(SearchResponse):

# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs):
    llm_chunks_indices: list[int]
    predicted_flow: QueryFlow
    predicted_search: SearchType
    time_cutoff: datetime | None

@@ -236,6 +235,11 @@ class QADocsResponse(RetrievalDocs):
        return initial_dict


# second chunk of info for streaming QA
class LLMRelevanceFilterResponse(BaseModel):
    relevant_chunk_indices: list[int]


class CreateChatSessionID(BaseModel):
    chat_session_id: int

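Put together, the two models above describe the two packets the streaming QA endpoint now emits before the answer itself. A sketch of what the serialized payloads could look like; the field names come from this diff, while all values (and the shape of the document entries) are made up for illustration:

# First packet, shaped like QADocsResponse (all values illustrative)
initial_packet = {
    "top_documents": [],          # the retrieved documents, shown to the user immediately
    "predicted_flow": "search",   # serialized QueryFlow value (illustrative)
    "predicted_search": "semantic",  # serialized SearchType value (illustrative)
    "time_cutoff": None,
    "favor_recent": False,
}

# Second packet, shaped like LLMRelevanceFilterResponse (values illustrative)
llm_selection_packet = {
    "relevant_chunk_indices": [0, 2],
}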