Add selected docs in UI + rework the backend flow a bit (#754)

Changes the flow so that the selected docs are sent over in a separate packet rather than as part of the initial packet for the streaming QA endpoint.
Chris Weaver 2023-11-21 19:46:12 -08:00 committed by GitHub
parent e78aefb408
commit c1e19d0d93
13 changed files with 389 additions and 181 deletions
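
For orientation, here is a rough sketch of how a client might consume the reworked stream. This is illustrative only: the endpoint URL is a placeholder, and the packet keys are summarized from the models touched in this diff (the initial document packet no longer carries the LLM-selected chunk indices; those now arrive in a separate packet keyed by relevant_chunk_indices).

# Hypothetical client-side sketch; the endpoint path/host below are placeholders, not the real route.
# The backend emits newline-delimited JSON via get_json_line, i.e. one packet per line.
import json

import requests

STREAMING_QA_URL = "http://localhost:8080/stream-qa"  # placeholder - substitute the real streaming QA endpoint

with requests.post(STREAMING_QA_URL, json={"query": "What is Danswer?"}, stream=True) as response:
    for line in response.iter_lines():
        if not line:
            continue
        packet = json.loads(line)
        if "top_documents" in packet:
            print("initial docs:", [d["document_id"] for d in packet["top_documents"]])
        elif "relevant_chunk_indices" in packet:
            # new in this commit: sent as its own packet instead of inside the initial docs packet
            print("LLM-selected doc indices:", packet["relevant_chunk_indices"])
        elif "answer_piece" in packet:
            print(packet["answer_piece"] or "", end="")
        elif "quotes" in packet:
            print("\nquotes:", packet["quotes"])
        elif "query_event_id" in packet:
            print("query event id:", packet["query_event_id"])
        elif "error" in packet:
            print("error:", packet["error"])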

View File

@@ -42,7 +42,7 @@ from danswer.search.models import IndexFilters
 from danswer.search.models import SearchQuery
 from danswer.search.models import SearchType
 from danswer.search.search_runner import chunks_to_search_docs
-from danswer.search.search_runner import search_chunks
+from danswer.search.search_runner import full_chunk_search
 from danswer.server.models import RetrievalDocs
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import extract_embedded_json
@@ -140,8 +140,9 @@ def danswer_chat_retrieval(
     )
     # Good Debug/Breakpoint
-    top_chunks, _ = search_chunks(
-        query=search_query, document_index=get_default_document_index()
+    top_chunks, _ = full_chunk_search(
+        query=search_query,
+        document_index=get_default_document_index(),
     )
     if not top_chunks:

View File

@@ -1,6 +1,7 @@
 from collections.abc import Callable
 from collections.abc import Iterator
 from functools import partial
+from typing import cast

 from sqlalchemy.orm import Session
@@ -15,15 +16,18 @@ from danswer.direct_qa.interfaces import StreamingError
 from danswer.direct_qa.models import LLMMetricsContainer
 from danswer.direct_qa.qa_utils import get_chunks_for_qa
 from danswer.document_index.factory import get_default_document_index
+from danswer.indexing.models import InferenceChunk
 from danswer.search.danswer_helper import query_intent
 from danswer.search.models import QueryFlow
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
 from danswer.search.search_runner import chunks_to_search_docs
 from danswer.search.search_runner import danswer_search
+from danswer.search.search_runner import danswer_search_generator
 from danswer.secondary_llm_flows.answer_validation import get_answer_validity
 from danswer.secondary_llm_flows.source_filter import extract_question_source_filters
 from danswer.secondary_llm_flows.time_filter import extract_question_time_filters
+from danswer.server.models import LLMRelevanceFilterResponse
 from danswer.server.models import QADocsResponse
 from danswer.server.models import QAResponse
 from danswer.server.models import QuestionRequest
@@ -206,24 +210,19 @@ def answer_qa_query_stream(
     question.favor_recent = favor_recent
     question.filters.source_type = source_filters

-    top_chunks, llm_chunk_selection, query_event_id = danswer_search(
+    search_generator = danswer_search_generator(
         question=question,
         user=user,
         db_session=db_session,
         document_index=get_default_document_index(),
     )

+    # first fetch and return to the UI the top chunks so the user can
+    # immediately see some results
+    top_chunks = cast(list[InferenceChunk], next(search_generator))
     top_docs = chunks_to_search_docs(top_chunks)
-    llm_chunks_indices = get_chunks_for_qa(
-        chunks=top_chunks,
-        llm_chunk_selection=llm_chunk_selection,
-        batch_offset=offset_count,
-    )
     initial_response = QADocsResponse(
         top_documents=top_docs,
-        llm_chunks_indices=llm_chunks_indices,
         # if generative AI is disabled, set flow as search so frontend
         # doesn't ask the user if they want to run QA over more documents
         predicted_flow=QueryFlow.SEARCH
@@ -233,13 +232,29 @@ def answer_qa_query_stream(
         time_cutoff=time_cutoff,
         favor_recent=favor_recent,
     ).dict()
     yield get_json_line(initial_response)

     if not top_chunks:
         logger.debug("No Documents Found")
         return

+    # next apply the LLM filtering
+    llm_chunk_selection = cast(list[bool], next(search_generator))
+    llm_chunks_indices = get_chunks_for_qa(
+        chunks=top_chunks,
+        llm_chunk_selection=llm_chunk_selection,
+        batch_offset=offset_count,
+    )
+    llm_relevance_filtering_response = LLMRelevanceFilterResponse(
+        relevant_chunk_indices=llm_chunks_indices
+    ).dict()
+    yield get_json_line(llm_relevance_filtering_response)
+
+    # finally get the query ID from the search generator for updating the
+    # row in Postgres. This is the end of the `search_generator` - any future
+    # calls to `next` will raise StopIteration
+    query_event_id = cast(int, next(search_generator))
+
     if disable_generative_answer:
         logger.debug("Skipping QA because generative AI is disabled")
         return

View File

@@ -94,6 +94,10 @@ class InferenceChunk(BaseChunk):
     # when the doc was last updated
     updated_at: datetime | None

+    @property
+    def unique_id(self) -> str:
+        return f"{self.document_id}__{self.chunk_id}"
+
     def __repr__(self) -> str:
         blurb_words = self.blurb.split()
         short_blurb = ""

View File

@@ -1,6 +1,6 @@
 from collections.abc import Callable
+from collections.abc import Generator
 from copy import deepcopy
-from typing import Any
 from typing import cast

 import numpy
@@ -50,6 +50,13 @@ from danswer.utils.timing import log_function_time
 logger = setup_logger()


+def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
+    top_links = [
+        c.source_links[0] if c.source_links is not None else "No Link" for c in chunks
+    ]
+    logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
+
+
 def lemmatize_text(text: str) -> list[str]:
     lemmatizer = WordNetLemmatizer()
     word_tokens = word_tokenize(text)
@@ -329,29 +336,15 @@ def apply_boost(
     return final_chunks


-def search_chunks(
+def retrieve_chunks(
     query: SearchQuery,
     document_index: DocumentIndex,
     hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
     multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION,
     retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
     | None = None,
-    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
-) -> tuple[list[InferenceChunk], list[bool]]:
-    """Returns a list of the best chunks from search/reranking and if the chunks are relevant via LLM.
-    For sake of speed, the system cannot rerank all retrieved chunks
-    Also pass the chunks through LLM to determine if they are relevant (binary for speed)
-    Only the first max_llm_filter_chunks
-    """
-
-    def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
-        top_links = [
-            c.source_links[0] if c.source_links is not None else "No Link"
-            for c in chunks
-        ]
-        logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
-
+) -> list[InferenceChunk]:
+    """Returns a list of the best chunks from an initial keyword/semantic/hybrid search."""
     # Don't do query expansion on complex queries, rephrasings likely would not work well
     if not multilingual_query_expansion or "\n" in query.query or "\r" in query.query:
         top_chunks = doc_index_retrieval(
@@ -377,7 +370,7 @@ def search_chunks(
             f"{query.search_type.value.capitalize()} search returned no results "
             f"with filters: {query.filters}"
         )
-        return [], []
+        return []

     if retrieval_metrics_callback is not None:
         chunk_metrics = [
@@ -393,67 +386,160 @@ def search_chunks(
             RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
         )

-    functions_to_run: list[FunctionCall] = []
-
-    # Keyword Search should not do reranking
-    if query.search_type == SearchType.KEYWORD or query.skip_rerank:
-        _log_top_chunk_links(query.search_type.value, top_chunks)
-        run_rerank_id: str | None = None
-    else:
-        run_rerank = FunctionCall(
-            semantic_reranking,
-            (query.query, top_chunks[: query.num_rerank]),
-            {"rerank_metrics_callback": rerank_metrics_callback},
-        )
-        functions_to_run.append(run_rerank)
-        run_rerank_id = run_rerank.result_id
-
-    run_llm_filter_id = None
-    if not query.skip_llm_chunk_filter:
-        run_llm_filter = FunctionCall(
-            llm_batch_eval_chunks,
-            (
-                query.query,
-                [chunk.content for chunk in top_chunks[: query.max_llm_filter_chunks]],
-            ),
-            {},
-        )
-        functions_to_run.append(run_llm_filter)
-        run_llm_filter_id = run_llm_filter.result_id
-
-    parallel_results: dict[str, Any] = {}
-    if functions_to_run:
-        parallel_results = run_functions_in_parallel(functions_to_run)
-
-    ranked_results = parallel_results.get(str(run_rerank_id))
-    if ranked_results is None:
-        ranked_chunks = top_chunks
-        sorted_indices = [i for i in range(len(top_chunks))]
-    else:
-        ranked_chunks, orig_indices = ranked_results
-        sorted_indices = orig_indices + list(range(len(orig_indices), len(top_chunks)))
-        lower_chunks = top_chunks[query.num_rerank :]
-        # Scores from rerank cannot be meaningfully combined with scores without rerank
-        for lower_chunk in lower_chunks:
-            lower_chunk.score = None
-        ranked_chunks.extend(lower_chunks)
-
-    llm_chunk_selection = parallel_results.get(str(run_llm_filter_id))
-    if llm_chunk_selection is None:
-        reranked_llm_chunk_selection = [True for _ in top_chunks]
-    else:
-        llm_chunk_selection.extend(
-            [False for _ in top_chunks[query.max_llm_filter_chunks :]]
-        )
-        reranked_llm_chunk_selection = [
-            llm_chunk_selection[ind] for ind in sorted_indices
-        ]
-
-    _log_top_chunk_links(query.search_type.value, ranked_chunks)
-
-    return ranked_chunks, reranked_llm_chunk_selection
+    return top_chunks
+
+
+def should_rerank(query: SearchQuery) -> bool:
+    # don't re-rank for keyword search
+    return query.search_type != SearchType.KEYWORD and not query.skip_rerank
+
+
+def should_apply_llm_based_relevance_filter(query: SearchQuery) -> bool:
+    return not query.skip_llm_chunk_filter
+
+
+def rerank_chunks(
+    query: SearchQuery,
+    chunks_to_rerank: list[InferenceChunk],
+    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
+) -> list[InferenceChunk]:
+    ranked_chunks, _ = semantic_reranking(
+        query=query.query,
+        chunks=chunks_to_rerank[: query.num_rerank],
+        rerank_metrics_callback=rerank_metrics_callback,
+    )
+    lower_chunks = chunks_to_rerank[query.num_rerank :]
+    # Scores from rerank cannot be meaningfully combined with scores without rerank
+    for lower_chunk in lower_chunks:
+        lower_chunk.score = None
+    ranked_chunks.extend(lower_chunks)
+    return ranked_chunks
+
+
+def filter_chunks(
+    query: SearchQuery,
+    chunks_to_filter: list[InferenceChunk],
+) -> list[str]:
+    """Filters chunks based on whether the LLM thought they were relevant to the query.
+    Returns a list of the unique chunk IDs that were marked as relevant"""
+    chunks_to_filter = chunks_to_filter[: query.max_llm_filter_chunks]
+    llm_chunk_selection = llm_batch_eval_chunks(
+        query=query.query,
+        chunk_contents=[chunk.content for chunk in chunks_to_filter],
+    )
+    return [
+        chunk.unique_id
+        for ind, chunk in enumerate(chunks_to_filter)
+        if llm_chunk_selection[ind]
+    ]
+
+
+def full_chunk_search(
+    query: SearchQuery,
+    document_index: DocumentIndex,
+    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
+    multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION,
+    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
+    | None = None,
+    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
+) -> tuple[list[InferenceChunk], list[bool]]:
+    """A utility which provides an easier interface than `full_chunk_search_generator`.
+    Rather than returning the chunks and llm relevance filter results in two separate
+    yields, just returns them both at once."""
+    search_generator = full_chunk_search_generator(
+        query=query,
+        document_index=document_index,
+        hybrid_alpha=hybrid_alpha,
+        multilingual_query_expansion=multilingual_query_expansion,
+        retrieval_metrics_callback=retrieval_metrics_callback,
+        rerank_metrics_callback=rerank_metrics_callback,
+    )
+    top_chunks = cast(list[InferenceChunk], next(search_generator))
+    llm_chunk_selection = cast(list[bool], next(search_generator))
+    return top_chunks, llm_chunk_selection
+
+
+def full_chunk_search_generator(
+    query: SearchQuery,
+    document_index: DocumentIndex,
+    hybrid_alpha: float = HYBRID_ALPHA,  # Only applicable to hybrid search
+    multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION,
+    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
+    | None = None,
+    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
+) -> Generator[list[InferenceChunk] | list[bool], None, None]:
+    """Always yields twice. Once with the selected chunks and once with the LLM relevance filter result."""
+    chunks_yielded = False
+
+    retrieved_chunks = retrieve_chunks(
+        query=query,
+        document_index=document_index,
+        hybrid_alpha=hybrid_alpha,
+        multilingual_query_expansion=multilingual_query_expansion,
+        retrieval_metrics_callback=retrieval_metrics_callback,
+    )
+
+    post_processing_tasks: list[FunctionCall] = []
+
+    rerank_task_id = None
+    if should_rerank(query):
+        post_processing_tasks.append(
+            FunctionCall(
+                rerank_chunks,
+                (
+                    query,
+                    retrieved_chunks,
+                    rerank_metrics_callback,
+                ),
+            )
+        )
+        rerank_task_id = post_processing_tasks[-1].result_id
+    else:
+        final_chunks = retrieved_chunks
+        # NOTE: if we don't rerank, we can return the chunks immediately
+        # since we know this is the final order
+        _log_top_chunk_links(query.search_type.value, final_chunks)
+        yield final_chunks
+        chunks_yielded = True
+
+    llm_filter_task_id = None
+    if should_apply_llm_based_relevance_filter(query):
+        post_processing_tasks.append(
+            FunctionCall(
+                filter_chunks,
+                (query, retrieved_chunks[: query.max_llm_filter_chunks]),
+            )
+        )
+        llm_filter_task_id = post_processing_tasks[-1].result_id
+
+    post_processing_results = run_functions_in_parallel(post_processing_tasks)
+    reranked_chunks = cast(
+        list[InferenceChunk] | None,
+        post_processing_results.get(str(rerank_task_id)) if rerank_task_id else None,
+    )
+    if reranked_chunks:
+        if chunks_yielded:
+            logger.error(
+                "Trying to yield re-ranked chunks, but chunks were already yielded. This should never happen."
+            )
+        else:
+            _log_top_chunk_links(query.search_type.value, reranked_chunks)
+            yield reranked_chunks
+
+    llm_chunk_selection = cast(
+        list[str] | None,
+        post_processing_results.get(str(llm_filter_task_id))
+        if llm_filter_task_id
+        else None,
+    )
+    if llm_chunk_selection:
+        yield [chunk.unique_id in llm_chunk_selection for chunk in retrieved_chunks]
+    else:
+        yield [True for _ in reranked_chunks or retrieved_chunks]


-def danswer_search(
+def danswer_search_generator(
     question: QuestionRequest,
     user: User | None,
     db_session: Session,
@@ -462,7 +548,11 @@ def danswer_search(
     retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
     | None = None,
     rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
-) -> tuple[list[InferenceChunk], list[bool], int]:
+) -> Generator[list[InferenceChunk] | list[bool] | int, None, None]:
+    """The main entry point for search. This fetches the relevant documents from Vespa
+    based on the provided query (applying permissions / filters), does any specified
+    post-processing, and returns the results. It also creates an entry in the query_event
+    table for this search event."""
     query_event_id = create_query_event(
         query=question.query,
         search_type=question.search_type,
@@ -490,20 +580,54 @@ def danswer_search(
         skip_llm_chunk_filter=skip_llm_chunk_filter,
     )

-    top_chunks, llm_chunk_selection = search_chunks(
+    search_generator = full_chunk_search_generator(
         query=search_query,
         document_index=document_index,
         retrieval_metrics_callback=retrieval_metrics_callback,
         rerank_metrics_callback=rerank_metrics_callback,
     )
+    top_chunks = cast(list[InferenceChunk], next(search_generator))
+    yield top_chunks

-    retrieved_ids = [doc.document_id for doc in top_chunks] if top_chunks else []
+    llm_chunk_selection = cast(list[bool], next(search_generator))
+    yield llm_chunk_selection

     update_query_event_retrieved_documents(
         db_session=db_session,
-        retrieved_document_ids=retrieved_ids,
+        retrieved_document_ids=[doc.document_id for doc in top_chunks]
+        if top_chunks
+        else [],
         query_id=query_event_id,
         user_id=None if user is None else user.id,
     )
+    yield query_event_id
+
+
+def danswer_search(
+    question: QuestionRequest,
+    user: User | None,
+    db_session: Session,
+    document_index: DocumentIndex,
+    skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
+    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
+    | None = None,
+    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
+) -> tuple[list[InferenceChunk], list[bool], int]:
+    """Returns a tuple of the top chunks, the LLM relevance filter results, and the query event ID.
+    Presents a simpler interface than the underlying `danswer_search_generator`, as callers no
+    longer need to worry about the order / have nicer typing. This should be used for flows which
+    do not require streaming."""
+    search_generator = danswer_search_generator(
+        question=question,
+        user=user,
+        db_session=db_session,
+        document_index=document_index,
+        skip_llm_chunk_filter=skip_llm_chunk_filter,
+        retrieval_metrics_callback=retrieval_metrics_callback,
+        rerank_metrics_callback=rerank_metrics_callback,
+    )
+    top_chunks = cast(list[InferenceChunk], next(search_generator))
+    llm_chunk_selection = cast(list[bool], next(search_generator))
+    query_event_id = cast(int, next(search_generator))
     return top_chunks, llm_chunk_selection, query_event_id
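
A minimal usage sketch of the new generator interface (a fragment, not part of this commit; it assumes `question`, `db_session`, and `document_index` are already set up by the caller):

# Illustrative only - consumption order of danswer_search_generator matters.
from danswer.search.search_runner import danswer_search_generator

search_generator = danswer_search_generator(
    question=question,
    user=None,
    db_session=db_session,
    document_index=document_index,
)
top_chunks = next(search_generator)           # 1st yield: retrieved (and possibly reranked) chunks
llm_chunk_selection = next(search_generator)  # 2nd yield: per-chunk relevance booleans
query_event_id = next(search_generator)       # 3rd yield: query_event row ID; the generator is now exhausted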

View File

@@ -222,7 +222,6 @@ class QAResponse(SearchResponse):

 # First chunk of info for streaming QA
 class QADocsResponse(RetrievalDocs):
-    llm_chunks_indices: list[int]
     predicted_flow: QueryFlow
     predicted_search: SearchType
     time_cutoff: datetime | None
@@ -236,6 +235,11 @@ class QADocsResponse(RetrievalDocs):
         return initial_dict


+# second chunk of info for streaming QA
+class LLMRelevanceFilterResponse(BaseModel):
+    relevant_chunk_indices: list[int]
+
+
 class CreateChatSessionID(BaseModel):
     chat_session_id: int
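
As a rough illustration of the resulting wire format (field values below are made up; the document fields come from RetrievalDocs and the enum serializations are assumptions), the docs packet and the relevance-filter packet are now two separate JSON lines:

# Illustrative only - shapes inferred from QADocsResponse / LLMRelevanceFilterResponse above.
import json

initial_packet = {
    "top_documents": [{"document_id": "doc-1"}, {"document_id": "doc-2"}],  # trimmed for brevity
    "predicted_flow": "question-answer",  # assumed enum serialization
    "predicted_search": "hybrid",         # assumed enum serialization
    "time_cutoff": None,
    "favor_recent": False,
}
relevance_packet = {"relevant_chunk_indices": [0]}

# Each packet is serialized with get_json_line, i.e. one JSON object per line.
print("\n".join(json.dumps(p) for p in (initial_packet, relevance_packet)))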

View File

@@ -5,6 +5,7 @@ interface HoverPopupProps {
   popupContent: string | JSX.Element;
   classNameModifications?: string;
   direction?: "left" | "bottom";
+  style?: "basic" | "dark";
 }

 export const HoverPopup = ({
@@ -12,6 +13,7 @@ export const HoverPopup = ({
   popupContent,
   classNameModifications,
   direction = "bottom",
+  style = "basic",
 }: HoverPopupProps) => {
   const [hovered, setHovered] = useState(false);
@@ -37,7 +39,10 @@ export const HoverPopup = ({
         <div className={`absolute ${popupDirectionClass} z-30`}>
           <div
             className={
-              `bg-gray-800 px-3 py-2 rounded shadow-lg ` +
+              `px-3 py-2 rounded ` +
+              (style === "dark"
+                ? "bg-dark-tremor-background-muted border border-gray-800"
+                : "bg-gray-800 shadow-lg") +
               (classNameModifications || "")
             }
           >

View File

@@ -3,8 +3,9 @@ import { DocumentFeedbackBlock } from "./DocumentFeedbackBlock";
 import { getSourceIcon } from "../source";
 import { useState } from "react";
 import { PopupSpec } from "../admin/connectors/Popup";
-import { timeAgo } from "@/lib/time";
+import { HoverPopup } from "@/components/HoverPopup";
 import { DocumentUpdatedAtBadge } from "./DocumentUpdatedAtBadge";
+import { FiCrosshair, FiInfo, FiRadio } from "react-icons/fi";

 export const buildDocumentSummaryDisplay = (
   matchHighlights: string[],
@@ -106,12 +107,14 @@ export const buildDocumentSummaryDisplay = (
 interface DocumentDisplayProps {
   document: DanswerDocument;
   queryEventId: number | null;
+  isSelected: boolean;
   setPopup: (popupSpec: PopupSpec | null) => void;
 }

 export const DocumentDisplay = ({
   document,
   queryEventId,
+  isSelected,
   setPopup,
 }: DocumentDisplayProps) => {
   const [isHovered, setIsHovered] = useState(false);
@@ -132,7 +135,31 @@ export const DocumentDisplay = ({
     >
       <div className="flex relative">
         {document.score !== null && (
-          <div className="absolute -left-10 top-2/4 -translate-y-2/4 w-10 flex">
+          <div
+            className={
+              "absolute top-2/4 -translate-y-2/4 flex " +
+              (isSelected ? "-left-14 w-14" : "-left-10 w-10")
+            }
+          >
+            {isSelected && (
+              <div className="w-4 h-4 my-auto mr-1 flex flex-col">
+                <HoverPopup
+                  mainContent={<FiRadio className="text-gray-500 my-auto" />}
+                  popupContent={
+                    <div className="text-xs text-gray-300 w-36 flex">
+                      <div className="flex mx-auto">
+                        <div className="w-3 h-3 flex flex-col my-auto mr-1">
+                          <FiInfo className="my-auto" />
+                        </div>
+                        <div className="my-auto">The AI liked this doc!</div>
+                      </div>
+                    </div>
+                  }
+                  direction="bottom"
+                  style="dark"
+                />
+              </div>
+            )}
             <div
               className={`
                 text-xs

View File

@@ -33,6 +33,17 @@ const removeDuplicateDocs = (documents: DanswerDocument[]) => {
   return output;
 };

+const getSelectedDocumentIds = (
+  documents: DanswerDocument[],
+  selectedIndices: number[]
+) => {
+  const selectedDocumentIds = new Set<string>();
+  selectedIndices.forEach((ind) => {
+    selectedDocumentIds.add(documents[ind].document_id);
+  });
+  return selectedDocumentIds;
+};
+
 interface SearchResultsDisplayProps {
   searchResponse: SearchResponse | null;
   validQuestionResponse: ValidQuestionResponse;
@@ -95,6 +106,11 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
     });
   }

+  const selectedDocumentIds = getSelectedDocumentIds(
+    documents || [],
+    searchResponse.selectedDocIndices || []
+  );
+
   const shouldDisplayQA =
     searchResponse.suggestedFlowType === FlowType.QUESTION_ANSWER ||
     defaultOverrides.forceDisplayQA;
@@ -175,6 +191,7 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
             key={document.document_id}
             document={document}
             queryEventId={queryEventId}
+            isSelected={selectedDocumentIds.has(document.document_id)}
             setPopup={setPopup}
           />
         ))}

View File

@@ -5,12 +5,10 @@ import { SearchBar } from "./SearchBar";
 import { SearchResultsDisplay } from "./SearchResultsDisplay";
 import { SourceSelector } from "./filtering/Filters";
 import { Connector, DocumentSet } from "@/lib/types";
-import { SearchTypeSelector } from "./SearchTypeSelector";
 import {
   DanswerDocument,
   Quote,
   SearchResponse,
-  Source,
   FlowType,
   SearchType,
   SearchDefaultOverrides,
@@ -18,7 +16,6 @@ import {
   ValidQuestionResponse,
 } from "@/lib/search/interfaces";
 import { searchRequestStreamed } from "@/lib/search/streamingQa";
-import Cookies from "js-cookie";
 import { SearchHelper } from "./SearchHelper";
 import { CancellationToken, cancellable } from "@/lib/search/cancellable";
 import { NEXT_PUBLIC_DISABLE_STREAMING } from "@/lib/constants";
@@ -77,6 +74,7 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
     documents: null,
     suggestedSearchType: null,
     suggestedFlowType: null,
+    selectedDocIndices: null,
     error: null,
     queryEventId: null,
   };
@@ -105,6 +103,11 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
       ...(prevState || initialSearchResponse),
       suggestedFlowType,
     }));
+  const updateSelectedDocIndices = (docIndices: number[]) =>
+    setSearchResponse((prevState) => ({
+      ...(prevState || initialSearchResponse),
+      selectedDocIndices: docIndices,
+    }));
   const updateError = (error: FlowType) =>
     setSearchResponse((prevState) => ({
       ...(prevState || initialSearchResponse),
@@ -159,6 +162,10 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
         cancellationToken: lastSearchCancellationToken.current,
         fn: updateSuggestedFlowType,
       }),
+      updateSelectedDocIndices: cancellable({
+        cancellationToken: lastSearchCancellationToken.current,
+        fn: updateSelectedDocIndices,
+      }),
       updateError: cancellable({
         cancellationToken: lastSearchCancellationToken.current,
         fn: updateError,
@@ -185,7 +192,7 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
   };

   return (
-    <div className="relative max-w-[2000px] xl:max-w-[1400px] mx-auto">
+    <div className="relative max-w-[2000px] xl:max-w-[1430px] mx-auto">
       <div className="absolute left-0 hidden 2xl:block w-64">
         {(connectors.length > 0 || documentSets.length > 0) && (
           <SourceSelector
@@ -195,7 +202,7 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
           />
         )}

-        <div className="mt-10 pr-2">
+        <div className="mt-10 pr-5">
           <SearchHelper
             isFetching={isFetching}
             searchResponse={searchResponse}

View File

@@ -13,10 +13,14 @@ export const SearchType = {
 };
 export type SearchType = (typeof SearchType)[keyof typeof SearchType];

-export interface AnswerPiece {
+export interface AnswerPiecePacket {
   answer_piece: string;
 }

+export interface ErrorMessagePacket {
+  error: string;
+}
+
 export interface Quote {
   quote: string;
   document_id: string;
@@ -26,6 +30,10 @@ export interface Quote {
   semantic_identifier: string;
 }

+export interface QuotesInfoPacket {
+  quotes: Quote[];
+}
+
 export interface DanswerDocument {
   document_id: string;
   link: string;
@@ -39,12 +47,29 @@ export interface DanswerDocument {
   updated_at: string | null;
 }

+export interface DocumentInfoPacket {
+  top_documents: DanswerDocument[];
+  predicted_flow: FlowType | null;
+  predicted_search: SearchType | null;
+  time_cutoff: string | null;
+  favor_recent: boolean;
+}
+
+export interface LLMRelevanceFilterPacket {
+  relevant_chunk_indices: number[];
+}
+
+export interface QueryEventIdPacket {
+  query_event_id: number;
+}
+
 export interface SearchResponse {
   suggestedSearchType: SearchType | null;
   suggestedFlowType: FlowType | null;
   answer: string | null;
   quotes: Quote[] | null;
   documents: DanswerDocument[] | null;
+  selectedDocIndices: number[] | null;
   error: string | null;
   queryEventId: number | null;
 }
@@ -73,6 +98,7 @@ export interface SearchRequestArgs {
   updateCurrentAnswer: (val: string) => void;
   updateQuotes: (quotes: Quote[]) => void;
   updateDocs: (documents: DanswerDocument[]) => void;
+  updateSelectedDocIndices: (docIndices: number[]) => void;
   updateSuggestedSearchType: (searchType: SearchType) => void;
   updateSuggestedFlowType: (flowType: FlowType) => void;
   updateError: (error: string) => void;

View File

@@ -21,6 +21,10 @@ export const searchRequest = async ({
   selectedSearchType,
   offset,
 }: SearchRequestArgs) => {
+  /*
+    NOTE: does not support the full functionality (AI selected answers). Should not be used if
+    at all possible - use `searchRequestStreamed` instead.
+  */
   let answer = "";
   let quotes: Quote[] | null = null;
   let relevantDocuments: DanswerDocument[] | null = null;

View File

@@ -1,57 +1,17 @@
 import {
+  AnswerPiecePacket,
   DanswerDocument,
+  DocumentInfoPacket,
+  ErrorMessagePacket,
+  LLMRelevanceFilterPacket,
+  QueryEventIdPacket,
   Quote,
+  QuotesInfoPacket,
   SearchRequestArgs,
-  SearchType,
 } from "./interfaces";
+import { processRawChunkString } from "./streamingUtils";
 import { buildFilters } from "./utils";

-const processSingleChunk = (
-  chunk: string,
-  currPartialChunk: string | null
-): [{ [key: string]: any } | null, string | null] => {
-  const completeChunk = (currPartialChunk || "") + chunk;
-  try {
-    // every complete chunk should be valid JSON
-    const chunkJson = JSON.parse(completeChunk);
-    return [chunkJson, null];
-  } catch (err) {
-    // if it's not valid JSON, then it's probably an incomplete chunk
-    return [null, completeChunk];
-  }
-};
-
-const processRawChunkString = (
-  rawChunkString: string,
-  previousPartialChunk: string | null
-): [any[], string | null] => {
-  /* This is required because, in practice, we see that nginx does not send over
-  each chunk one at a time even with buffering turned off. Instead,
-  chunks are sometimes in batches or are sometimes incomplete */
-  if (!rawChunkString) {
-    return [[], null];
-  }
-  const chunkSections = rawChunkString
-    .split("\n")
-    .filter((chunk) => chunk.length > 0);
-  let parsedChunkSections: any[] = [];
-  let currPartialChunk = previousPartialChunk;
-  chunkSections.forEach((chunk) => {
-    const [processedChunk, partialChunk] = processSingleChunk(
-      chunk,
-      currPartialChunk
-    );
-    if (processedChunk) {
-      parsedChunkSections.push(processedChunk);
-      currPartialChunk = null;
-    } else {
-      currPartialChunk = partialChunk;
-    }
-  });
-  return [parsedChunkSections, currPartialChunk];
-};
-
 export const searchRequestStreamed = async ({
   query,
   sources,
@@ -62,6 +22,7 @@ export const searchRequestStreamed = async ({
   updateDocs,
   updateSuggestedSearchType,
   updateSuggestedFlowType,
+  updateSelectedDocIndices,
   updateError,
   updateQueryEventId,
   offset,
@@ -87,7 +48,7 @@ export const searchRequestStreamed = async ({
   const reader = response.body?.getReader();
   const decoder = new TextDecoder("utf-8");

-  let previousPartialChunk = null;
+  let previousPartialChunk: string | null = null;
   while (true) {
     const rawChunk = await reader?.read();
     if (!rawChunk) {
@@ -99,47 +60,50 @@ export const searchRequestStreamed = async ({
     }

     // Process each chunk as it arrives
-    const [completedChunks, partialChunk] = processRawChunkString(
-      decoder.decode(value, { stream: true }),
-      previousPartialChunk
-    );
+    const [completedChunks, partialChunk] = processRawChunkString<
+      | AnswerPiecePacket
+      | ErrorMessagePacket
+      | QuotesInfoPacket
+      | DocumentInfoPacket
+      | LLMRelevanceFilterPacket
+      | QueryEventIdPacket
+    >(decoder.decode(value, { stream: true }), previousPartialChunk);
     if (!completedChunks.length && !partialChunk) {
       break;
     }
-    previousPartialChunk = partialChunk;
+    previousPartialChunk = partialChunk as string | null;

     completedChunks.forEach((chunk) => {
-      // TODO: clean up response / this logic
-      const answerChunk = chunk.answer_piece;
-      if (answerChunk) {
-        answer += answerChunk;
-        updateCurrentAnswer(answer);
-        return;
-      }
-
-      if (answerChunk === null) {
-        // set quotes as non-null to signify that the answer is finished and
-        // we're now looking for quotes
-        updateQuotes([]);
-        if (
-          answer &&
-          !answer.endsWith(".") &&
-          !answer.endsWith("?") &&
-          !answer.endsWith("!")
-        ) {
-          answer += ".";
+      // check for answer piece / end of answer
+      if (Object.hasOwn(chunk, "answer_piece")) {
+        const answerPiece = (chunk as AnswerPiecePacket).answer_piece;
+        if (answerPiece !== null) {
+          answer += (chunk as AnswerPiecePacket).answer_piece;
           updateCurrentAnswer(answer);
+        } else {
+          // set quotes as non-null to signify that the answer is finished and
+          // we're now looking for quotes
+          updateQuotes([]);
+          if (
+            answer &&
+            !answer.endsWith(".") &&
+            !answer.endsWith("?") &&
+            !answer.endsWith("!")
+          ) {
+            answer += ".";
+            updateCurrentAnswer(answer);
+          }
         }
         return;
       }

-      const errorMsg = chunk.error;
-      if (errorMsg) {
-        updateError(errorMsg);
+      if (Object.hasOwn(chunk, "error")) {
+        updateError((chunk as ErrorMessagePacket).error);
         return;
       }

       // These all come together
       if (Object.hasOwn(chunk, "top_documents")) {
+        chunk = chunk as DocumentInfoPacket;
         const topDocuments = chunk.top_documents as DanswerDocument[] | null;
         if (topDocuments) {
           relevantDocuments = topDocuments;
@@ -153,19 +117,29 @@ export const searchRequestStreamed = async ({
         if (chunk.predicted_search) {
           updateSuggestedSearchType(chunk.predicted_search);
         }
+        return;
+      }
+
+      if (Object.hasOwn(chunk, "relevant_chunk_indices")) {
+        const relevantChunkIndices = (chunk as LLMRelevanceFilterPacket)
+          .relevant_chunk_indices;
+        if (relevantChunkIndices) {
+          updateSelectedDocIndices(relevantChunkIndices);
+        }
         return;
       }

       // Check for quote section
-      if (chunk.quotes) {
-        quotes = chunk.quotes as Quote[];
+      if (Object.hasOwn(chunk, "quotes")) {
+        quotes = (chunk as QuotesInfoPacket).quotes;
         updateQuotes(quotes);
         return;
       }

       // check for query ID section
-      if (chunk.query_event_id) {
-        updateQueryEventId(chunk.query_event_id);
+      if (Object.hasOwn(chunk, "query_event_id")) {
+        updateQueryEventId((chunk as QueryEventIdPacket).query_event_id);
         return;
       }

View File

@@ -1,4 +1,4 @@
-import { AnswerPiece, ValidQuestionResponse } from "./interfaces";
+import { AnswerPiecePacket, ValidQuestionResponse } from "./interfaces";
 import { processRawChunkString } from "./streamingUtils";

 export interface QuestionValidationArgs {
@@ -45,7 +45,7 @@ export const questionValidationStreamed = async <T>({
       }

       const [completedChunks, partialChunk] = processRawChunkString<
-        AnswerPiece | ValidQuestionResponse
+        AnswerPiecePacket | ValidQuestionResponse
       >(decoder.decode(value, { stream: true }), previousPartialChunk);
       if (!completedChunks.length && !partialChunk) {
         break;
@@ -54,7 +54,7 @@ export const questionValidationStreamed = async <T>({
       completedChunks.forEach((chunk) => {
         if (Object.hasOwn(chunk, "answer_piece")) {
-          reasoning += (chunk as AnswerPiece).answer_piece;
+          reasoning += (chunk as AnswerPiecePacket).answer_piece;
           update({
             reasoning,
           });