Add selected docs in UI + rework the backend flow a bit (#754)

Changes the streaming QA endpoint flow so that the LLM-selected docs are sent over in a separate packet rather than as part of the initial packet.
Chris Weaver
2023-11-21 19:46:12 -08:00
committed by GitHub
parent e78aefb408
commit c1e19d0d93
13 changed files with 389 additions and 181 deletions
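
For illustration, here is a minimal sketch of what a client of the streaming QA endpoint sees after this change, assuming the endpoint still emits newline-delimited JSON packets: the initial QADocsResponse packet carries only the retrieved documents, and the LLM-selected chunk indices now arrive later in a separate LLMRelevanceFilterResponse packet. The endpoint URL, request body, and use of `requests` below are assumptions for the sketch, not part of this commit.

import json

import requests  # assumed HTTP client, used only for this sketch

# Hypothetical endpoint URL and request body; the real QuestionRequest model
# carries more fields (filters, offset, etc.).
with requests.post(
    "http://localhost:8080/stream-qa",
    json={"query": "How do I add a connector?"},
    stream=True,
) as response:
    for raw_line in response.iter_lines():
        if not raw_line:
            continue
        packet = json.loads(raw_line)

        if "top_documents" in packet:
            # First packet: QADocsResponse. It no longer includes
            # llm_chunks_indices, so the docs can be rendered immediately.
            print(f"Retrieved {len(packet['top_documents'])} documents")
        elif "relevant_chunk_indices" in packet:
            # New second packet: LLMRelevanceFilterResponse, sent once the LLM
            # relevance filter finishes; the UI marks these docs as selected.
            print("LLM-selected chunk indices:", packet["relevant_chunk_indices"])
        else:
            # Later packets (answer tokens, quotes, errors) are unchanged.
            pass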

View File

@@ -42,7 +42,7 @@ from danswer.search.models import IndexFilters
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import search_chunks
from danswer.search.search_runner import full_chunk_search
from danswer.server.models import RetrievalDocs
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import extract_embedded_json
@@ -140,8 +140,9 @@ def danswer_chat_retrieval(
)
# Good Debug/Breakpoint
top_chunks, _ = search_chunks(
query=search_query, document_index=get_default_document_index()
top_chunks, _ = full_chunk_search(
query=search_query,
document_index=get_default_document_index(),
)
if not top_chunks:

View File

@@ -1,6 +1,7 @@
from collections.abc import Callable
from collections.abc import Iterator
from functools import partial
from typing import cast
from sqlalchemy.orm import Session
@@ -15,15 +16,18 @@ from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_utils import get_chunks_for_qa
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.search.danswer_helper import query_intent
from danswer.search.models import QueryFlow
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import danswer_search
from danswer.search.search_runner import danswer_search_generator
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.source_filter import extract_question_source_filters
from danswer.secondary_llm_flows.time_filter import extract_question_time_filters
from danswer.server.models import LLMRelevanceFilterResponse
from danswer.server.models import QADocsResponse
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
@@ -206,24 +210,19 @@ def answer_qa_query_stream(
question.favor_recent = favor_recent
question.filters.source_type = source_filters
top_chunks, llm_chunk_selection, query_event_id = danswer_search(
search_generator = danswer_search_generator(
question=question,
user=user,
db_session=db_session,
document_index=get_default_document_index(),
)
# first fetch and return to the UI the top chunks so the user can
# immediately see some results
top_chunks = cast(list[InferenceChunk], next(search_generator))
top_docs = chunks_to_search_docs(top_chunks)
llm_chunks_indices = get_chunks_for_qa(
chunks=top_chunks,
llm_chunk_selection=llm_chunk_selection,
batch_offset=offset_count,
)
initial_response = QADocsResponse(
top_documents=top_docs,
llm_chunks_indices=llm_chunks_indices,
# if generative AI is disabled, set flow as search so frontend
# doesn't ask the user if they want to run QA over more documents
predicted_flow=QueryFlow.SEARCH
@@ -233,13 +232,29 @@ def answer_qa_query_stream(
time_cutoff=time_cutoff,
favor_recent=favor_recent,
).dict()
yield get_json_line(initial_response)
if not top_chunks:
logger.debug("No Documents Found")
return
# next apply the LLM filtering
llm_chunk_selection = cast(list[bool], next(search_generator))
llm_chunks_indices = get_chunks_for_qa(
chunks=top_chunks,
llm_chunk_selection=llm_chunk_selection,
batch_offset=offset_count,
)
llm_relevance_filtering_response = LLMRelevanceFilterResponse(
relevant_chunk_indices=llm_chunks_indices
).dict()
yield get_json_line(llm_relevance_filtering_response)
# finally get the query ID from the search generator for updating the
# row in Postgres. This is the end of the `search_generator` - any future
# calls to `next` will raise StopIteration
query_event_id = cast(int, next(search_generator))
if disable_generative_answer:
logger.debug("Skipping QA because generative AI is disabled")
return

View File

@@ -94,6 +94,10 @@ class InferenceChunk(BaseChunk):
# when the doc was last updated
updated_at: datetime | None
@property
def unique_id(self) -> str:
return f"{self.document_id}__{self.chunk_id}"
def __repr__(self) -> str:
blurb_words = self.blurb.split()
short_blurb = ""
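
The new `unique_id` property gives each chunk a stable key of the form `<document_id>__<chunk_id>`. Later in this commit, `filter_chunks` returns these IDs and the search generator maps them back onto a per-chunk boolean mask. A toy illustration of that mapping, not part of the commit (the stub class and values are made up):

from dataclasses import dataclass


# Toy stand-in for InferenceChunk with only the fields unique_id needs.
@dataclass
class ChunkStub:
    document_id: str
    chunk_id: int

    @property
    def unique_id(self) -> str:
        return f"{self.document_id}__{self.chunk_id}"


chunks = [ChunkStub("doc-123", 0), ChunkStub("doc-123", 1)]

# Pretend the LLM relevance filter approved only the second chunk.
llm_selected_ids = {"doc-123__1"}

# This is the mapping performed when the boolean selection mask is yielded.
selection_mask = [chunk.unique_id in llm_selected_ids for chunk in chunks]
print(selection_mask)  # [False, True]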

View File

@@ -1,6 +1,6 @@
from collections.abc import Callable
from collections.abc import Generator
from copy import deepcopy
from typing import Any
from typing import cast
import numpy
@@ -50,6 +50,13 @@ from danswer.utils.timing import log_function_time
logger = setup_logger()
def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
top_links = [
c.source_links[0] if c.source_links is not None else "No Link" for c in chunks
]
logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
def lemmatize_text(text: str) -> list[str]:
lemmatizer = WordNetLemmatizer()
word_tokens = word_tokenize(text)
@@ -329,29 +336,15 @@ def apply_boost(
return final_chunks
def search_chunks(
def retrieve_chunks(
query: SearchQuery,
document_index: DocumentIndex,
hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search
multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk], list[bool]]:
"""Returns a list of the best chunks from search/reranking and if the chunks are relevant via LLM.
For sake of speed, the system cannot rerank all retrieved chunks
Also pass the chunks through LLM to determine if they are relevant (binary for speed)
Only the first max_llm_filter_chunks
"""
def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
top_links = [
c.source_links[0] if c.source_links is not None else "No Link"
for c in chunks
]
logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
) -> list[InferenceChunk]:
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
# Don't do query expansion on complex queries, rephrasings likely would not work well
if not multilingual_query_expansion or "\n" in query.query or "\r" in query.query:
top_chunks = doc_index_retrieval(
@@ -377,7 +370,7 @@ def search_chunks(
f"{query.search_type.value.capitalize()} search returned no results "
f"with filters: {query.filters}"
)
return [], []
return []
if retrieval_metrics_callback is not None:
chunk_metrics = [
@@ -393,67 +386,160 @@ def search_chunks(
RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
)
functions_to_run: list[FunctionCall] = []
return top_chunks
# Keyword Search should not do reranking
if query.search_type == SearchType.KEYWORD or query.skip_rerank:
_log_top_chunk_links(query.search_type.value, top_chunks)
run_rerank_id: str | None = None
else:
run_rerank = FunctionCall(
semantic_reranking,
(query.query, top_chunks[: query.num_rerank]),
{"rerank_metrics_callback": rerank_metrics_callback},
def should_rerank(query: SearchQuery) -> bool:
# don't re-rank for keyword search
return query.search_type != SearchType.KEYWORD and not query.skip_rerank
def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool:
return not query.skip_llm_chunk_filter
def rerank_chunks(
query: SearchQuery,
chunks_to_rerank: list[InferenceChunk],
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> list[InferenceChunk]:
ranked_chunks, _ = semantic_reranking(
query=query.query,
chunks=chunks_to_rerank[: query.num_rerank],
rerank_metrics_callback=rerank_metrics_callback,
)
lower_chunks = chunks_to_rerank[query.num_rerank :]
# Scores from rerank cannot be meaningfully combined with scores without rerank
for lower_chunk in lower_chunks:
lower_chunk.score = None
ranked_chunks.extend(lower_chunks)
return ranked_chunks
def filter_chunks(
query: SearchQuery,
chunks_to_filter: list[InferenceChunk],
) -> list[str]:
"""Filters chunks based on whether the LLM thought they were relevant to the query.
Returns a list of the unique chunk IDs that were marked as relevant"""
chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks]
llm_chunk_selection = llm_batch_eval_chunks(
query=query.query,
chunk_contents=[chunk.content for chunk in chunks_to_filter],
)
return [
chunk.unique_id
for ind, chunk in enumerate(chunks_to_filter)
if llm_chunk_selection[ind]
]
def full_chunk_search(
query: SearchQuery,
document_index: DocumentIndex,
hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search
multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk], list[bool]]:
"""A utility which provides an easier interface than `full_chunk_search_generator`.
Rather than returning the chunks and llm relevance filter results in two separate
yields, just returns them both at once."""
search_generator = full_chunk_search_generator(
query=query,
document_index=document_index,
hybrid_alpha=hybrid_alpha,
multilingual_query_expansion=multilingual_query_expansion,
retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback,
)
top_chunks = cast(list[InferenceChunk], next(search_generator))
llm_chunk_selection = cast(list[bool], next(search_generator))
return top_chunks, llm_chunk_selection
def full_chunk_search_generator(
query: SearchQuery,
document_index: DocumentIndex,
hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search
multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> Generator[list[InferenceChunk] | list[bool], None, None]:
"""Always yields twice. Once with the selected chunks and once with the LLM relevance filter result."""
chunks_yielded = False
retrieved_chunks = retrieve_chunks(
query=query,
document_index=document_index,
hybrid_alpha=hybrid_alpha,
multilingual_query_expansion=multilingual_query_expansion,
retrieval_metrics_callback=retrieval_metrics_callback,
)
post_processing_tasks: list[FunctionCall] = []
rerank_task_id = None
if should_rerank(query):
post_processing_tasks.append(
FunctionCall(
rerank_chunks,
(
query,
retrieved_chunks,
rerank_metrics_callback,
),
)
)
functions_to_run.append(run_rerank)
run_rerank_id = run_rerank.result_id
run_llm_filter_id = None
if not query.skip_llm_chunk_filter:
run_llm_filter = FunctionCall(
llm_batch_eval_chunks,
(
query.query,
[chunk.content for chunk in top_chunks[: query.max_llm_filter_chunks]],
),
{},
)
functions_to_run.append(run_llm_filter)
run_llm_filter_id = run_llm_filter.result_id
parallel_results: dict[str, Any] = {}
if functions_to_run:
parallel_results = run_functions_in_parallel(functions_to_run)
ranked_results = parallel_results.get(str(run_rerank_id))
if ranked_results is None:
ranked_chunks = top_chunks
sorted_indices = [i for i in range(len(top_chunks))]
rerank_task_id = post_processing_tasks[-1].result_id
else:
ranked_chunks, orig_indices = ranked_results
sorted_indices = orig_indices + list(range(len(orig_indices), len(top_chunks)))
lower_chunks = top_chunks[query.num_rerank :]
# Scores from rerank cannot be meaningfully combined with scores without rerank
for lower_chunk in lower_chunks:
lower_chunk.score = None
ranked_chunks.extend(lower_chunks)
final_chunks = retrieved_chunks
# NOTE: if we don't rerank, we can return the chunks immediately
# since we know this is the final order
_log_top_chunk_links(query.search_type.value, final_chunks)
yield final_chunks
chunks_yielded = True
llm_chunk_selection = parallel_results.get(str(run_llm_filter_id))
if llm_chunk_selection is None:
reranked_llm_chunk_selection = [True for _ in top_chunks]
else:
llm_chunk_selection.extend(
[False for _ in top_chunks[query.max_llm_filter_chunks :]]
llm_filter_task_id = None
if should_apply_llm_based_relevance_filter(query):
post_processing_tasks.append(
FunctionCall(
filter_chunks,
(query, retrieved_chunks[: query.max_llm_filter_chunks]),
)
)
reranked_llm_chunk_selection = [
llm_chunk_selection[ind] for ind in sorted_indices
]
_log_top_chunk_links(query.search_type.value, ranked_chunks)
llm_filter_task_id = post_processing_tasks[-1].result_id
return ranked_chunks, reranked_llm_chunk_selection
post_processing_results = run_functions_in_parallel(post_processing_tasks)
reranked_chunks = cast(
list[InferenceChunk] | None,
post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None,
)
if reranked_chunks:
if chunks_yielded:
logger.error(
"Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen."
)
else:
_log_top_chunk_links(query.search_type.value, reranked_chunks)
yield reranked_chunks
llm_chunk_selection = cast(
list[str] | None,
post_processing_results.get(str(llm_filter_task_id))
if llm_filter_task_id
else None,
)
if llm_chunk_selection:
yield [chunk.unique_id in llm_chunk_selection for chunk in retrieved_chunks]
else:
yield [True for _ in reranked_chunks or retrieved_chunks]
def danswer_search(
def danswer_search_generator(
question: QuestionRequest,
user: User | None,
db_session: Session,
@@ -462,7 +548,11 @@ def danswer_search(
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk], list[bool], int]:
) -> Generator[list[InferenceChunk] | list[bool] | int, None, None]:
"""The main entry point for search. This fetches the relevant documents from Vespa
based on the provided query (applying permissions / filters), does any specified
post-processing, and returns the results. It also creates an entry in the query_event table
for this search event."""
query_event_id = create_query_event(
query=question.query,
search_type=question.search_type,
@@ -490,20 +580,54 @@ def danswer_search(
skip_llm_chunk_filter=skip_llm_chunk_filter,
)
top_chunks, llm_chunk_selection = search_chunks(
search_generator = full_chunk_search_generator(
query=search_query,
document_index=document_index,
retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback,
)
top_chunks = cast(list[InferenceChunk], next(search_generator))
yield top_chunks
retrieved_ids = [doc.document_id for doc in top_chunks] if top_chunks else []
llm_chunk_selection = cast(list[bool], next(search_generator))
yield llm_chunk_selection
update_query_event_retrieved_documents(
db_session=db_session,
retrieved_document_ids=retrieved_ids,
retrieved_document_ids=[doc.document_id for doc in top_chunks]
if top_chunks
else [],
query_id=query_event_id,
user_id=None if user is None else user.id,
)
yield query_event_id
def danswer_search(
question: QuestionRequest,
user: User | None,
db_session: Session,
document_index: DocumentIndex,
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk], list[bool], int]:
"""Returns a tuple of the top chunks, the LLM relevance filter results, and the query event ID.
Presents a simpler interface than the underlying `danswer_search_generator`, as callers no
longer need to worry about the order / have nicer typing. This should be used for flows which
do not require streaming."""
search_generator = danswer_search_generator(
question=question,
user=user,
db_session=db_session,
document_index=document_index,
skip_llm_chunk_filter=skip_llm_chunk_filter,
retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback,
)
top_chunks = cast(list[InferenceChunk], next(search_generator))
llm_chunk_selection = cast(list[bool], next(search_generator))
query_event_id = cast(int, next(search_generator))
return top_chunks, llm_chunk_selection, query_event_id
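
To make the intent of the generator interface concrete, here is a hedged sketch of a hypothetical streaming caller (essentially what answer_qa_query_stream above does): each yield can be forwarded to the client as soon as it is ready, instead of waiting for retrieval, reranking, and LLM filtering to all finish. The wrapper function and the payload dicts below are illustrative only, not part of this commit.

from typing import cast

from danswer.indexing.models import InferenceChunk
from danswer.search.search_runner import danswer_search_generator


def stream_search_stages(question, user, db_session, document_index):
    """Illustrative only: forward each search stage to the client as it completes."""
    search_generator = danswer_search_generator(
        question=question,
        user=user,
        db_session=db_session,
        document_index=document_index,
    )

    # 1st yield: retrieved (and possibly reranked) chunks -> show docs right away
    top_chunks = cast(list[InferenceChunk], next(search_generator))
    yield {"top_document_ids": [chunk.document_id for chunk in top_chunks]}

    # 2nd yield: LLM relevance filter result, one bool per retrieved chunk
    llm_chunk_selection = cast(list[bool], next(search_generator))
    yield {"llm_selected": llm_chunk_selection}

    # 3rd (final) yield: ID of the query_event row created in Postgres
    query_event_id = cast(int, next(search_generator))
    yield {"query_event_id": query_event_id}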

View File

@@ -222,7 +222,6 @@ class QAResponse(SearchResponse):
# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs):
llm_chunks_indices: list[int]
predicted_flow: QueryFlow
predicted_search: SearchType
time_cutoff: datetime | None
@@ -236,6 +235,11 @@ class QADocsResponse(RetrievalDocs):
return initial_dict
# second chunk of info for streaming QA
class LLMRelevanceFilterResponse(BaseModel):
relevant_chunk_indices: list[int]
class CreateChatSessionID(BaseModel):
chat_session_id: int