diff --git a/backend/danswer/chat/chat_llm.py b/backend/danswer/chat/chat_llm.py index f3b4d7cd65..e067b73da0 100644 --- a/backend/danswer/chat/chat_llm.py +++ b/backend/danswer/chat/chat_llm.py @@ -42,7 +42,7 @@ from danswer.search.models import IndexFilters from danswer.search.models import SearchQuery from danswer.search.models import SearchType from danswer.search.search_runner import chunks_to_search_docs -from danswer.search.search_runner import search_chunks +from danswer.search.search_runner import full_chunk_search from danswer.server.models import RetrievalDocs from danswer.utils.logger import setup_logger from danswer.utils.text_processing import extract_embedded_json @@ -140,8 +140,9 @@ def danswer_chat_retrieval( ) # Good Debug/Breakpoint - top_chunks, _ = search_chunks( - query=search_query, document_index=get_default_document_index() + top_chunks, _ = full_chunk_search( + query=search_query, + document_index=get_default_document_index(), ) if not top_chunks: diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py index b8c02d0e15..c94d3225ce 100644 --- a/backend/danswer/direct_qa/answer_question.py +++ b/backend/danswer/direct_qa/answer_question.py @@ -1,6 +1,7 @@ from collections.abc import Callable from collections.abc import Iterator from functools import partial +from typing import cast from sqlalchemy.orm import Session @@ -15,15 +16,18 @@ from danswer.direct_qa.interfaces import StreamingError from danswer.direct_qa.models import LLMMetricsContainer from danswer.direct_qa.qa_utils import get_chunks_for_qa from danswer.document_index.factory import get_default_document_index +from danswer.indexing.models import InferenceChunk from danswer.search.danswer_helper import query_intent from danswer.search.models import QueryFlow from danswer.search.models import RerankMetricsContainer from danswer.search.models import RetrievalMetricsContainer from danswer.search.search_runner import chunks_to_search_docs from danswer.search.search_runner import danswer_search +from danswer.search.search_runner import danswer_search_generator from danswer.secondary_llm_flows.answer_validation import get_answer_validity from danswer.secondary_llm_flows.source_filter import extract_question_source_filters from danswer.secondary_llm_flows.time_filter import extract_question_time_filters +from danswer.server.models import LLMRelevanceFilterResponse from danswer.server.models import QADocsResponse from danswer.server.models import QAResponse from danswer.server.models import QuestionRequest @@ -206,24 +210,19 @@ def answer_qa_query_stream( question.favor_recent = favor_recent question.filters.source_type = source_filters - top_chunks, llm_chunk_selection, query_event_id = danswer_search( + search_generator = danswer_search_generator( question=question, user=user, db_session=db_session, document_index=get_default_document_index(), ) + # first fetch and return to the UI the top chunks so the user can + # immediately see some results + top_chunks = cast(list[InferenceChunk], next(search_generator)) top_docs = chunks_to_search_docs(top_chunks) - - llm_chunks_indices = get_chunks_for_qa( - chunks=top_chunks, - llm_chunk_selection=llm_chunk_selection, - batch_offset=offset_count, - ) - initial_response = QADocsResponse( top_documents=top_docs, - llm_chunks_indices=llm_chunks_indices, # if generative AI is disabled, set flow as search so frontend # doesn't ask the user if they want to run QA over more documents predicted_flow=QueryFlow.SEARCH @@ -233,13 
+232,29 @@ def answer_qa_query_stream( time_cutoff=time_cutoff, favor_recent=favor_recent, ).dict() - yield get_json_line(initial_response) if not top_chunks: logger.debug("No Documents Found") return + # next apply the LLM filtering + llm_chunk_selection = cast(list[bool], next(search_generator)) + llm_chunks_indices = get_chunks_for_qa( + chunks=top_chunks, + llm_chunk_selection=llm_chunk_selection, + batch_offset=offset_count, + ) + llm_relevance_filtering_response = LLMRelevanceFilterResponse( + relevant_chunk_indices=llm_chunks_indices + ).dict() + yield get_json_line(llm_relevance_filtering_response) + + # finally get the query ID from the search generator for updating the + # row in Postgres. This is the end of the `search_generator` - any future + # calls to `next` will raise StopIteration + query_event_id = cast(int, next(search_generator)) + if disable_generative_answer: logger.debug("Skipping QA because generative AI is disabled") return diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py index 934646b027..8c9bc0e60c 100644 --- a/backend/danswer/indexing/models.py +++ b/backend/danswer/indexing/models.py @@ -94,6 +94,10 @@ class InferenceChunk(BaseChunk): # when the doc was last updated updated_at: datetime | None + @property + def unique_id(self) -> str: + return f"{self.document_id}__{self.chunk_id}" + def __repr__(self) -> str: blurb_words = self.blurb.split() short_blurb = "" diff --git a/backend/danswer/search/search_runner.py b/backend/danswer/search/search_runner.py index 92e640c9d4..2a6097b544 100644 --- a/backend/danswer/search/search_runner.py +++ b/backend/danswer/search/search_runner.py @@ -1,6 +1,6 @@ from collections.abc import Callable +from collections.abc import Generator from copy import deepcopy -from typing import Any from typing import cast import numpy @@ -50,6 +50,13 @@ from danswer.utils.timing import log_function_time logger = setup_logger() +def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None: + top_links = [ + c.source_links[0] if c.source_links is not None else "No Link" for c in chunks + ] + logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}") + + def lemmatize_text(text: str) -> list[str]: lemmatizer = WordNetLemmatizer() word_tokens = word_tokenize(text) @@ -329,29 +336,15 @@ def apply_boost( return final_chunks -def search_chunks( +def retrieve_chunks( query: SearchQuery, document_index: DocumentIndex, hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION, retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] | None = None, - rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, -) -> tuple[list[InferenceChunk], list[bool]]: - """Returns a list of the best chunks from search/reranking and if the chunks are relevant via LLM. 
- For sake of speed, the system cannot rerank all retrieved chunks - Also pass the chunks through LLM to determine if they are relevant (binary for speed) - - Only the first max_llm_filter_chunks - """ - - def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None: - top_links = [ - c.source_links[0] if c.source_links is not None else "No Link" - for c in chunks - ] - logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}") - +) -> list[InferenceChunk]: + """Returns a list of the best chunks from an initial keyword/semantic/ hybrid search.""" # Don't do query expansion on complex queries, rephrasings likely would not work well if not multilingual_query_expansion or "\n" in query.query or "\r" in query.query: top_chunks = doc_index_retrieval( @@ -377,7 +370,7 @@ def search_chunks( f"{query.search_type.value.capitalize()} search returned no results " f"with filters: {query.filters}" ) - return [], [] + return [] if retrieval_metrics_callback is not None: chunk_metrics = [ @@ -393,67 +386,160 @@ def search_chunks( RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics) ) - functions_to_run: list[FunctionCall] = [] + return top_chunks - # Keyword Search should not do reranking - if query.search_type == SearchType.KEYWORD or query.skip_rerank: - _log_top_chunk_links(query.search_type.value, top_chunks) - run_rerank_id: str | None = None - else: - run_rerank = FunctionCall( - semantic_reranking, - (query.query, top_chunks[: query.num_rerank]), - {"rerank_metrics_callback": rerank_metrics_callback}, + +def should_rerank(query: SearchQuery) -> bool: + # don't re-rank for keyword search + return query.search_type != SearchType.KEYWORD and not query.skip_rerank + + +def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool: + return not query.skip_llm_chunk_filter + + +def rerank_chunks( + query: SearchQuery, + chunks_to_rerank: list[InferenceChunk], + rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, +) -> list[InferenceChunk]: + ranked_chunks, _ = semantic_reranking( + query=query.query, + chunks=chunks_to_rerank[: query.num_rerank], + rerank_metrics_callback=rerank_metrics_callback, + ) + lower_chunks = chunks_to_rerank[query.num_rerank :] + # Scores from rerank cannot be meaningfully combined with scores without rerank + for lower_chunk in lower_chunks: + lower_chunk.score = None + ranked_chunks.extend(lower_chunks) + return ranked_chunks + + +def filter_chunks( + query: SearchQuery, + chunks_to_filter: list[InferenceChunk], +) -> list[str]: + """Filters chunks based on whether the LLM thought they were relevant to the query. 
+ + Returns a list of the unique chunk IDs that were marked as relevant""" + chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks] + llm_chunk_selection = llm_batch_eval_chunks( + query=query.query, + chunk_contents=[chunk.content for chunk in chunks_to_filter], + ) + return [ + chunk.unique_id + for ind, chunk in enumerate(chunks_to_filter) + if llm_chunk_selection[ind] + ] + + +def full_chunk_search( + query: SearchQuery, + document_index: DocumentIndex, + hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search + multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION, + retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] + | None = None, + rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, +) -> tuple[list[InferenceChunk], list[bool]]: + """A utility which provides an easier interface than `full_chunk_search_generator`. + Rather than returning the chunks and llm relevance filter results in two separate + yields, just returns them both at once.""" + search_generator = full_chunk_search_generator( + query=query, + document_index=document_index, + hybrid_alpha=hybrid_alpha, + multilingual_query_expansion=multilingual_query_expansion, + retrieval_metrics_callback=retrieval_metrics_callback, + rerank_metrics_callback=rerank_metrics_callback, + ) + top_chunks = cast(list[InferenceChunk], next(search_generator)) + llm_chunk_selection = cast(list[bool], next(search_generator)) + return top_chunks, llm_chunk_selection + + +def full_chunk_search_generator( + query: SearchQuery, + document_index: DocumentIndex, + hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search + multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION, + retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] + | None = None, + rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, +) -> Generator[list[InferenceChunk] | list[bool], None, None]: + """Always yields twice. 
Once with the selected chunks and once with the LLM relevance filter result.""" + chunks_yielded = False + + retrieved_chunks = retrieve_chunks( + query=query, + document_index=document_index, + hybrid_alpha=hybrid_alpha, + multilingual_query_expansion=multilingual_query_expansion, + retrieval_metrics_callback=retrieval_metrics_callback, + ) + + post_processing_tasks: list[FunctionCall] = [] + + rerank_task_id = None + if should_rerank(query): + post_processing_tasks.append( + FunctionCall( + rerank_chunks, + ( + query, + retrieved_chunks, + rerank_metrics_callback, + ), + ) ) - functions_to_run.append(run_rerank) - run_rerank_id = run_rerank.result_id - - run_llm_filter_id = None - if not query.skip_llm_chunk_filter: - run_llm_filter = FunctionCall( - llm_batch_eval_chunks, - ( - query.query, - [chunk.content for chunk in top_chunks[: query.max_llm_filter_chunks]], - ), - {}, - ) - functions_to_run.append(run_llm_filter) - run_llm_filter_id = run_llm_filter.result_id - - parallel_results: dict[str, Any] = {} - if functions_to_run: - parallel_results = run_functions_in_parallel(functions_to_run) - - ranked_results = parallel_results.get(str(run_rerank_id)) - if ranked_results is None: - ranked_chunks = top_chunks - sorted_indices = [i for i in range(len(top_chunks))] + rerank_task_id = post_processing_tasks[-1].result_id else: - ranked_chunks, orig_indices = ranked_results - sorted_indices = orig_indices + list(range(len(orig_indices), len(top_chunks))) - lower_chunks = top_chunks[query.num_rerank :] - # Scores from rerank cannot be meaningfully combined with scores without rerank - for lower_chunk in lower_chunks: - lower_chunk.score = None - ranked_chunks.extend(lower_chunks) + final_chunks = retrieved_chunks + # NOTE: if we don't rerank, we can return the chunks immediately + # since we know this is the final order + _log_top_chunk_links(query.search_type.value, final_chunks) + yield final_chunks + chunks_yielded = True - llm_chunk_selection = parallel_results.get(str(run_llm_filter_id)) - if llm_chunk_selection is None: - reranked_llm_chunk_selection = [True for _ in top_chunks] - else: - llm_chunk_selection.extend( - [False for _ in top_chunks[query.max_llm_filter_chunks :]] + llm_filter_task_id = None + if should_apply_llm_based_relevance_filter(query): + post_processing_tasks.append( + FunctionCall( + filter_chunks, + (query, retrieved_chunks[: query.max_llm_filter_chunks]), + ) ) - reranked_llm_chunk_selection = [ - llm_chunk_selection[ind] for ind in sorted_indices - ] - _log_top_chunk_links(query.search_type.value, ranked_chunks) + llm_filter_task_id = post_processing_tasks[-1].result_id - return ranked_chunks, reranked_llm_chunk_selection + post_processing_results = run_functions_in_parallel(post_processing_tasks) + reranked_chunks = cast( + list[InferenceChunk] | None, + post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None, + ) + if reranked_chunks: + if chunks_yielded: + logger.error( + "Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen." 
+ ) + else: + _log_top_chunk_links(query.search_type.value, reranked_chunks) + yield reranked_chunks + + llm_chunk_selection = cast( + list[str] | None, + post_processing_results.get(str(llm_filter_task_id)) + if llm_filter_task_id + else None, + ) + if llm_chunk_selection: + yield [chunk.unique_id in llm_chunk_selection for chunk in retrieved_chunks] + else: + yield [True for _ in reranked_chunks or retrieved_chunks] -def danswer_search( +def danswer_search_generator( question: QuestionRequest, user: User | None, db_session: Session, @@ -462,7 +548,11 @@ def danswer_search( retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, -) -> tuple[list[InferenceChunk], list[bool], int]: +) -> Generator[list[InferenceChunk] | list[bool] | int, None, None]: + """The main entry point for search. This fetches the relevant documents from Vespa + based on the provided query (applying permissions / filters), does any specified + post-processing, and returns the results. It also creates an entry in the query_event table + for this search event.""" query_event_id = create_query_event( query=question.query, search_type=question.search_type, @@ -490,20 +580,54 @@ def danswer_search( skip_llm_chunk_filter=skip_llm_chunk_filter, ) - top_chunks, llm_chunk_selection = search_chunks( + search_generator = full_chunk_search_generator( query=search_query, document_index=document_index, retrieval_metrics_callback=retrieval_metrics_callback, rerank_metrics_callback=rerank_metrics_callback, ) + top_chunks = cast(list[InferenceChunk], next(search_generator)) + yield top_chunks - retrieved_ids = [doc.document_id for doc in top_chunks] if top_chunks else [] + llm_chunk_selection = cast(list[bool], next(search_generator)) + yield llm_chunk_selection update_query_event_retrieved_documents( db_session=db_session, - retrieved_document_ids=retrieved_ids, + retrieved_document_ids=[doc.document_id for doc in top_chunks] + if top_chunks + else [], query_id=query_event_id, user_id=None if user is None else user.id, ) + yield query_event_id + +def danswer_search( + question: QuestionRequest, + user: User | None, + db_session: Session, + document_index: DocumentIndex, + skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER, + retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] + | None = None, + rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, +) -> tuple[list[InferenceChunk], list[bool], int]: + """Returns a tuple of the top chunks, the LLM relevance filter results, and the query event ID. + + Presents a simpler interface than the underlying `danswer_search_generator`, as callers no + longer need to worry about the order / have nicer typing.
This should be used for flows which + do not require streaming.""" + search_generator = danswer_search_generator( + question=question, + user=user, + db_session=db_session, + document_index=document_index, + skip_llm_chunk_filter=skip_llm_chunk_filter, + retrieval_metrics_callback=retrieval_metrics_callback, + rerank_metrics_callback=rerank_metrics_callback, + ) + top_chunks = cast(list[InferenceChunk], next(search_generator)) + llm_chunk_selection = cast(list[bool], next(search_generator)) + query_event_id = cast(int, next(search_generator)) return top_chunks, llm_chunk_selection, query_event_id diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index 30f3e41af9..57cb7386be 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -222,7 +222,6 @@ class QAResponse(SearchResponse): # First chunk of info for streaming QA class QADocsResponse(RetrievalDocs): - llm_chunks_indices: list[int] predicted_flow: QueryFlow predicted_search: SearchType time_cutoff: datetime | None @@ -236,6 +235,11 @@ class QADocsResponse(RetrievalDocs): return initial_dict +# second chunk of info for streaming QA +class LLMRelevanceFilterResponse(BaseModel): + relevant_chunk_indices: list[int] + + class CreateChatSessionID(BaseModel): chat_session_id: int diff --git a/web/src/components/HoverPopup.tsx b/web/src/components/HoverPopup.tsx index 8263d52d1a..886cde0238 100644 --- a/web/src/components/HoverPopup.tsx +++ b/web/src/components/HoverPopup.tsx @@ -5,6 +5,7 @@ interface HoverPopupProps { popupContent: string | JSX.Element; classNameModifications?: string; direction?: "left" | "bottom"; + style?: "basic" | "dark"; } export const HoverPopup = ({ @@ -12,6 +13,7 @@ export const HoverPopup = ({ popupContent, classNameModifications, direction = "bottom", + style = "basic", }: HoverPopupProps) => { const [hovered, setHovered] = useState(false); @@ -37,7 +39,10 @@ export const HoverPopup = ({
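// ---------------------------------------------------------------------------
// A minimal sketch of the streamed packet shapes implied by the backend changes
// above (QADocsResponse losing `llm_chunks_indices`, the new
// LLMRelevanceFilterResponse) and by the casts used in streamingQa.ts further
// down. Illustrative only: the real definitions live in
// web/src/lib/search/interfaces.ts, which this diff does not show in full, and
// the stand-in types below are assumptions.
type DanswerDocument = Record<string, unknown>; // stand-in; real type is in interfaces.ts
type Quote = Record<string, unknown>; // stand-in; real type is in interfaces.ts

interface AnswerPiecePacket {
  answer_piece: string | null; // null marks the end of the streamed answer
}

interface ErrorMessagePacket {
  error: string;
}

interface QuotesInfoPacket {
  quotes: Quote[];
}

// first info packet of the stream: retrieved documents plus predicted flow / search type
interface DocumentInfoPacket {
  top_documents: DanswerDocument[] | null;
  predicted_flow: string | null; // FlowType in the real code
  predicted_search: string | null; // SearchType in the real code
  time_cutoff: string | null; // ISO timestamp, if a time filter was applied
  favor_recent: boolean | null;
}

// second info packet: indices (into top_documents) that the LLM marked as relevant
interface LLMRelevanceFilterPacket {
  relevant_chunk_indices: number[];
}

// final packet: ID of the query_event row created for this search
interface QueryEventIdPacket {
  query_event_id: number;
}
// ---------------------------------------------------------------------------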
diff --git a/web/src/components/search/DocumentDisplay.tsx b/web/src/components/search/DocumentDisplay.tsx index ccf14f0c56..adc666cb76 100644 --- a/web/src/components/search/DocumentDisplay.tsx +++ b/web/src/components/search/DocumentDisplay.tsx @@ -3,8 +3,9 @@ import { DocumentFeedbackBlock } from "./DocumentFeedbackBlock"; import { getSourceIcon } from "../source"; import { useState } from "react"; import { PopupSpec } from "../admin/connectors/Popup"; -import { timeAgo } from "@/lib/time"; +import { HoverPopup } from "@/components/HoverPopup"; import { DocumentUpdatedAtBadge } from "./DocumentUpdatedAtBadge"; +import { FiCrosshair, FiInfo, FiRadio } from "react-icons/fi"; export const buildDocumentSummaryDisplay = ( matchHighlights: string[], @@ -106,12 +107,14 @@ export const buildDocumentSummaryDisplay = ( interface DocumentDisplayProps { document: DanswerDocument; queryEventId: number | null; + isSelected: boolean; setPopup: (popupSpec: PopupSpec | null) => void; } export const DocumentDisplay = ({ document, queryEventId, + isSelected, setPopup, }: DocumentDisplayProps) => { const [isHovered, setIsHovered] = useState(false); @@ -132,7 +135,31 @@ export const DocumentDisplay = ({ >
{document.score !== null && ( -
+
+ {isSelected && ( +
+ } + popupContent={ +
+
+
+ +
+
The AI liked this doc!
+
+
+ } + direction="bottom" + style="dark" + /> +
+ )}
{ return output; }; +const getSelectedDocumentIds = ( + documents: DanswerDocument[], + selectedIndices: number[] +) => { + const selectedDocumentIds = new Set(); + selectedIndices.forEach((ind) => { + selectedDocumentIds.add(documents[ind].document_id); + }); + return selectedDocumentIds; +}; + interface SearchResultsDisplayProps { searchResponse: SearchResponse | null; validQuestionResponse: ValidQuestionResponse; @@ -95,6 +106,11 @@ export const SearchResultsDisplay: React.FC = ({ }); } + const selectedDocumentIds = getSelectedDocumentIds( + documents || [], + searchResponse.selectedDocIndices || [] + ); + const shouldDisplayQA = searchResponse.suggestedFlowType === FlowType.QUESTION_ANSWER || defaultOverrides.forceDisplayQA; @@ -175,6 +191,7 @@ export const SearchResultsDisplay: React.FC = ({ key={document.document_id} document={document} queryEventId={queryEventId} + isSelected={selectedDocumentIds.has(document.document_id)} setPopup={setPopup} /> ))} diff --git a/web/src/components/search/SearchSection.tsx b/web/src/components/search/SearchSection.tsx index 370afd17d1..d70a5fbd43 100644 --- a/web/src/components/search/SearchSection.tsx +++ b/web/src/components/search/SearchSection.tsx @@ -5,12 +5,10 @@ import { SearchBar } from "./SearchBar"; import { SearchResultsDisplay } from "./SearchResultsDisplay"; import { SourceSelector } from "./filtering/Filters"; import { Connector, DocumentSet } from "@/lib/types"; -import { SearchTypeSelector } from "./SearchTypeSelector"; import { DanswerDocument, Quote, SearchResponse, - Source, FlowType, SearchType, SearchDefaultOverrides, @@ -18,7 +16,6 @@ import { ValidQuestionResponse, } from "@/lib/search/interfaces"; import { searchRequestStreamed } from "@/lib/search/streamingQa"; -import Cookies from "js-cookie"; import { SearchHelper } from "./SearchHelper"; import { CancellationToken, cancellable } from "@/lib/search/cancellable"; import { NEXT_PUBLIC_DISABLE_STREAMING } from "@/lib/constants"; @@ -77,6 +74,7 @@ export const SearchSection: React.FC = ({ documents: null, suggestedSearchType: null, suggestedFlowType: null, + selectedDocIndices: null, error: null, queryEventId: null, }; @@ -105,6 +103,11 @@ export const SearchSection: React.FC = ({ ...(prevState || initialSearchResponse), suggestedFlowType, })); + const updateSelectedDocIndices = (docIndices: number[]) => + setSearchResponse((prevState) => ({ + ...(prevState || initialSearchResponse), + selectedDocIndices: docIndices, + })); const updateError = (error: FlowType) => setSearchResponse((prevState) => ({ ...(prevState || initialSearchResponse), @@ -159,6 +162,10 @@ export const SearchSection: React.FC = ({ cancellationToken: lastSearchCancellationToken.current, fn: updateSuggestedFlowType, }), + updateSelectedDocIndices: cancellable({ + cancellationToken: lastSearchCancellationToken.current, + fn: updateSelectedDocIndices, + }), updateError: cancellable({ cancellationToken: lastSearchCancellationToken.current, fn: updateError, @@ -185,7 +192,7 @@ export const SearchSection: React.FC = ({ }; return ( -
+
{(connectors.length > 0 || documentSets.length > 0) && ( = ({ /> )} -
+
void; updateQuotes: (quotes: Quote[]) => void; updateDocs: (documents: DanswerDocument[]) => void; + updateSelectedDocIndices: (docIndices: number[]) => void; updateSuggestedSearchType: (searchType: SearchType) => void; updateSuggestedFlowType: (flowType: FlowType) => void; updateError: (error: string) => void; diff --git a/web/src/lib/search/qa.ts b/web/src/lib/search/qa.ts index 2f54f1cbf2..d61935a06b 100644 --- a/web/src/lib/search/qa.ts +++ b/web/src/lib/search/qa.ts @@ -21,6 +21,10 @@ export const searchRequest = async ({ selectedSearchType, offset, }: SearchRequestArgs) => { + /* + NOTE: does not support the full functionality (AI selected answers). Should not be used if + at all possible - use `searchRequestStreamed` instead. + */ let answer = ""; let quotes: Quote[] | null = null; let relevantDocuments: DanswerDocument[] | null = null; diff --git a/web/src/lib/search/streamingQa.ts b/web/src/lib/search/streamingQa.ts index cdb29dfe3c..183e25aa16 100644 --- a/web/src/lib/search/streamingQa.ts +++ b/web/src/lib/search/streamingQa.ts @@ -1,57 +1,17 @@ import { + AnswerPiecePacket, DanswerDocument, + DocumentInfoPacket, + ErrorMessagePacket, + LLMRelevanceFilterPacket, + QueryEventIdPacket, Quote, + QuotesInfoPacket, SearchRequestArgs, - SearchType, } from "./interfaces"; +import { processRawChunkString } from "./streamingUtils"; import { buildFilters } from "./utils"; -const processSingleChunk = ( - chunk: string, - currPartialChunk: string | null -): [{ [key: string]: any } | null, string | null] => { - const completeChunk = (currPartialChunk || "") + chunk; - try { - // every complete chunk should be valid JSON - const chunkJson = JSON.parse(completeChunk); - return [chunkJson, null]; - } catch (err) { - // if it's not valid JSON, then it's probably an incomplete chunk - return [null, completeChunk]; - } -}; - -const processRawChunkString = ( - rawChunkString: string, - previousPartialChunk: string | null -): [any[], string | null] => { - /* This is required because, in practice, we see that nginx does not send over - each chunk one at a time even with buffering turned off. 
Instead, - chunks are sometimes in batches or are sometimes incomplete */ - if (!rawChunkString) { - return [[], null]; - } - const chunkSections = rawChunkString - .split("\n") - .filter((chunk) => chunk.length > 0); - let parsedChunkSections: any[] = []; - let currPartialChunk = previousPartialChunk; - chunkSections.forEach((chunk) => { - const [processedChunk, partialChunk] = processSingleChunk( - chunk, - currPartialChunk - ); - if (processedChunk) { - parsedChunkSections.push(processedChunk); - currPartialChunk = null; - } else { - currPartialChunk = partialChunk; - } - }); - - return [parsedChunkSections, currPartialChunk]; -}; - export const searchRequestStreamed = async ({ query, sources, @@ -62,6 +22,7 @@ export const searchRequestStreamed = async ({ updateDocs, updateSuggestedSearchType, updateSuggestedFlowType, + updateSelectedDocIndices, updateError, updateQueryEventId, offset, @@ -87,7 +48,7 @@ export const searchRequestStreamed = async ({ const reader = response.body?.getReader(); const decoder = new TextDecoder("utf-8"); - let previousPartialChunk = null; + let previousPartialChunk: string | null = null; while (true) { const rawChunk = await reader?.read(); if (!rawChunk) { @@ -99,47 +60,50 @@ export const searchRequestStreamed = async ({ } // Process each chunk as it arrives - const [completedChunks, partialChunk] = processRawChunkString( - decoder.decode(value, { stream: true }), - previousPartialChunk - ); + const [completedChunks, partialChunk] = processRawChunkString< + | AnswerPiecePacket + | ErrorMessagePacket + | QuotesInfoPacket + | DocumentInfoPacket + | LLMRelevanceFilterPacket + | QueryEventIdPacket + >(decoder.decode(value, { stream: true }), previousPartialChunk); if (!completedChunks.length && !partialChunk) { break; } - previousPartialChunk = partialChunk; + previousPartialChunk = partialChunk as string | null; completedChunks.forEach((chunk) => { - // TODO: clean up response / this logic - const answerChunk = chunk.answer_piece; - if (answerChunk) { - answer += answerChunk; - updateCurrentAnswer(answer); - return; - } - - if (answerChunk === null) { - // set quotes as non-null to signify that the answer is finished and - // we're now looking for quotes - updateQuotes([]); - if ( - answer && - !answer.endsWith(".") && - !answer.endsWith("?") && - !answer.endsWith("!") - ) { - answer += "."; + // check for answer peice / end of answer + if (Object.hasOwn(chunk, "answer_piece")) { + const answerPiece = (chunk as AnswerPiecePacket).answer_piece; + if (answerPiece !== null) { + answer += (chunk as AnswerPiecePacket).answer_piece; updateCurrentAnswer(answer); + } else { + // set quotes as non-null to signify that the answer is finished and + // we're now looking for quotes + updateQuotes([]); + if ( + answer && + !answer.endsWith(".") && + !answer.endsWith("?") && + !answer.endsWith("!") + ) { + answer += "."; + updateCurrentAnswer(answer); + } } return; } - const errorMsg = chunk.error; - if (errorMsg) { - updateError(errorMsg); + if (Object.hasOwn(chunk, "error")) { + updateError((chunk as ErrorMessagePacket).error); return; } // These all come together if (Object.hasOwn(chunk, "top_documents")) { + chunk = chunk as DocumentInfoPacket; const topDocuments = chunk.top_documents as DanswerDocument[] | null; if (topDocuments) { relevantDocuments = topDocuments; @@ -153,19 +117,29 @@ export const searchRequestStreamed = async ({ if (chunk.predicted_search) { updateSuggestedSearchType(chunk.predicted_search); } + + return; + } + + if (Object.hasOwn(chunk, 
"relevant_chunk_indices")) { + const relevantChunkIndices = (chunk as LLMRelevanceFilterPacket) + .relevant_chunk_indices; + if (relevantChunkIndices) { + updateSelectedDocIndices(relevantChunkIndices); + } return; } // Check for quote section - if (chunk.quotes) { - quotes = chunk.quotes as Quote[]; + if (Object.hasOwn(chunk, "quotes")) { + quotes = (chunk as QuotesInfoPacket).quotes; updateQuotes(quotes); return; } // check for query ID section - if (chunk.query_event_id) { - updateQueryEventId(chunk.query_event_id); + if (Object.hasOwn(chunk, "query_event_id")) { + updateQueryEventId((chunk as QueryEventIdPacket).query_event_id); return; } diff --git a/web/src/lib/search/streamingQuestionValidation.ts b/web/src/lib/search/streamingQuestionValidation.ts index 2cd2ccad4a..bad5ab46aa 100644 --- a/web/src/lib/search/streamingQuestionValidation.ts +++ b/web/src/lib/search/streamingQuestionValidation.ts @@ -1,4 +1,4 @@ -import { AnswerPiece, ValidQuestionResponse } from "./interfaces"; +import { AnswerPiecePacket, ValidQuestionResponse } from "./interfaces"; import { processRawChunkString } from "./streamingUtils"; export interface QuestionValidationArgs { @@ -45,7 +45,7 @@ export const questionValidationStreamed = async ({ } const [completedChunks, partialChunk] = processRawChunkString< - AnswerPiece | ValidQuestionResponse + AnswerPiecePacket | ValidQuestionResponse >(decoder.decode(value, { stream: true }), previousPartialChunk); if (!completedChunks.length && !partialChunk) { break; @@ -54,7 +54,7 @@ export const questionValidationStreamed = async ({ completedChunks.forEach((chunk) => { if (Object.hasOwn(chunk, "answer_piece")) { - reasoning += (chunk as AnswerPiece).answer_piece; + reasoning += (chunk as AnswerPiecePacket).answer_piece; update({ reasoning, });