Better error message on GPT failures (#187)

* Better error message on GPT-call failures

* Add support for disabling Generative AI
Chris Weaver
2023-07-16 16:25:33 -07:00
committed by GitHub
parent 6c584f0650
commit 676538da61
14 changed files with 246 additions and 107 deletions

View File

@@ -12,6 +12,10 @@ APP_PORT = 8080
 #####
 BLURB_LENGTH = 200  # Characters. Blurbs will be truncated at the first punctuation after this many characters.
 GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400  # 1 day
+# DISABLE_GENERATIVE_AI will turn off the question answering part of Danswer. Use this
+# if you want to use Danswer as a search engine only and/or you are not comfortable sending
+# anything to OpenAI. TODO: update this message once we support Azure / open source generative models.
+DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
 #####
 # Web Configs
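Note on the parsing above: only the literal string "true" (case-insensitive) enables the flag, so values such as "1" or "yes" leave generative QA on. A minimal standalone sketch of that rule, mirroring the expression in the config:

import os

for value in ("true", "TRUE", "1", "yes", ""):
    os.environ["DISABLE_GENERATIVE_AI"] = value
    disabled = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
    print(f"{value!r} -> disabled={disabled}")  # only "true"/"TRUE" disable QA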

View File

@@ -1,6 +1,7 @@
 from typing import Any

 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
+from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.llm import OpenAIChatCompletionQA
 from danswer.direct_qa.llm import OpenAICompletionQA

@@ -14,4 +15,4 @@ def get_default_backend_qa_model(
     elif internal_model == "openai-chat-completion":
         return OpenAIChatCompletionQA(**kwargs)
     else:
-        raise ValueError("Unknown internal QA model set.")
+        raise UnknownModelError(internal_model)
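With this change, a misconfigured INTERNAL_MODEL_VERSION surfaces as a typed, user-presentable exception instead of a bare ValueError. A minimal caller sketch (hypothetical, using only names from this diff) of the pattern the endpoints below adopt:

try:
    qa_model = get_default_backend_qa_model(timeout=10)
except UnknownModelError as e:
    qa_model = None
    error_msg = str(e)  # "Unknown Internal QA model name: <model>", safe to surface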

View File

@@ -1,29 +1,34 @@
-import time
 from danswer.chunking.models import InferenceChunk
+from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
 from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.datastores.qdrant.store import QdrantIndex
 from danswer.datastores.typesense.store import TypesenseIndex
 from danswer.db.models import User
 from danswer.direct_qa import get_default_backend_qa_model
+from danswer.direct_qa.exceptions import OpenAIKeyMissing
+from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.search.danswer_helper import query_intent
 from danswer.search.keyword_search import retrieve_keyword_documents
+from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
 from danswer.search.semantic_search import chunks_to_search_docs
 from danswer.search.semantic_search import retrieve_ranked_documents
 from danswer.server.models import QAResponse
 from danswer.server.models import QuestionRequest
 from danswer.utils.logger import setup_logger
+from danswer.utils.timing import log_function_time

 logger = setup_logger()


+@log_function_time()
 def answer_question(
-    question: QuestionRequest, user: User | None, qa_model_timeout: int = QA_TIMEOUT
+    question: QuestionRequest,
+    user: User | None,
+    qa_model_timeout: int = QA_TIMEOUT,
+    disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
 ) -> QAResponse:
-    start_time = time.time()
     query = question.query
     collection = question.collection
     filters = question.filters

@@ -55,7 +60,32 @@ def answer_question(
             predicted_search=predicted_search,
         )

-    qa_model = get_default_backend_qa_model(timeout=qa_model_timeout)
+    if disable_generative_answer:
+        logger.debug("Skipping QA because generative AI is disabled")
+        return QAResponse(
+            answer=None,
+            quotes=None,
+            top_ranked_docs=chunks_to_search_docs(ranked_chunks),
+            lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
+            # set flow as search so frontend doesn't ask the user if they want
+            # to run QA over more documents
+            predicted_flow=QueryFlow.SEARCH,
+            predicted_search=predicted_search,
+        )
+
+    try:
+        qa_model = get_default_backend_qa_model(timeout=qa_model_timeout)
+    except (UnknownModelError, OpenAIKeyMissing) as e:
+        return QAResponse(
+            answer=None,
+            quotes=None,
+            top_ranked_docs=chunks_to_search_docs(ranked_chunks),
+            lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
+            predicted_flow=predicted_flow,
+            predicted_search=predicted_search,
+            error_msg=str(e),
+        )
+
     chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
     if chunk_offset >= len(ranked_chunks):
         raise ValueError("Chunks offset too large, should not retry this many times")

@@ -71,8 +101,6 @@ def answer_question(
         answer, quotes = None, None
         error_msg = f"Error occurred in call to LLM - {e}"

-    logger.info(f"Total QA took {time.time() - start_time} seconds")
-
     return QAResponse(
         answer=answer,
         quotes=quotes,

View File

@@ -0,0 +1,8 @@
+class OpenAIKeyMissing(Exception):
+    def __init__(self, msg: str = "Unable to find an OpenAI Key") -> None:
+        super().__init__(msg)
+
+
+class UnknownModelError(Exception):
+    def __init__(self, model_name: str) -> None:
+        super().__init__(f"Unknown Internal QA model name: {model_name}")
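Both exceptions embed a human-readable default message, which is why the handlers in this diff can forward str(e) straight into QAResponse.error_msg or a streamed {"error": ...} packet. For example:

try:
    raise UnknownModelError("gpt-neo")  # hypothetical model name
except UnknownModelError as e:
    assert str(e) == "Unknown Internal QA model name: gpt-neo"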

View File

@@ -27,6 +27,7 @@ from danswer.configs.constants import SOURCE_LINK
 from danswer.configs.constants import SOURCE_TYPE
 from danswer.configs.model_configs import OPENAI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import OPENAI_MODEL_VERSION
+from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_prompts import ANSWER_PAT
 from danswer.direct_qa.qa_prompts import get_chat_reflexion_msg

@@ -278,7 +279,7 @@ class OpenAICompletionQA(OpenAIQAModel):
         try:
             self.api_key = api_key or get_openai_api_key()
         except ConfigNotFoundError:
-            raise RuntimeError("No OpenAI Key available")
+            raise OpenAIKeyMissing()

     @log_function_time()
     def answer_question(

@@ -391,7 +392,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         try:
             self.api_key = api_key or get_openai_api_key()
         except ConfigNotFoundError:
-            raise RuntimeError("No OpenAI Key available")
+            raise OpenAIKeyMissing()

     @log_function_time()
     def answer_question(

@@ -482,6 +483,6 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         logger.debug(model_output)

-        answer, quotes_dict = process_answer(model_output, context_docs)
+        _, quotes_dict = process_answer(model_output, context_docs)
         yield {} if quotes_dict is None else quotes_dict

View File

@@ -5,6 +5,7 @@ from typing import cast
 from danswer.auth.schemas import UserRole
 from danswer.auth.users import current_admin_user
 from danswer.auth.users import current_user
+from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
 from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
 from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY

@@ -293,6 +294,20 @@ def connector_run_once(
 def validate_existing_openai_api_key(
     _: User = Depends(current_admin_user),
 ) -> None:
+    # OpenAI key is only used for generative QA, so no need to validate this
+    # if it's turned off
+    if DISABLE_GENERATIVE_AI:
+        return
+
+    # always check if key exists
+    try:
+        openai_api_key = get_openai_api_key()
+    except ConfigNotFoundError:
+        raise HTTPException(status_code=404, detail="Key not found")
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=str(e))
+
+    # don't call OpenAI every single time, only validate every so often
     check_key_time = "openai_api_key_last_check_time"
     kv_store = get_dynamic_config_store()
     curr_time = datetime.now()

@@ -308,12 +323,10 @@ def validate_existing_openai_api_key(
     get_dynamic_config_store().store(check_key_time, curr_time.timestamp())

     try:
-        openai_api_key = get_openai_api_key()
         is_valid = check_openai_api_key_is_valid(openai_api_key)
-    except ConfigNotFoundError:
-        raise HTTPException(status_code=404, detail="Key not found")
-    except ValueError as e:
-        raise HTTPException(status_code=400, detail=str(e))
+    except ValueError:
+        # this is the case where they aren't using an OpenAI-based model
+        is_valid = True

     if not is_valid:
         raise HTTPException(status_code=400, detail="Invalid API key provided")
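The timestamp gate above means the real OpenAI round-trip happens at most once per GENERATIVE_MODEL_ACCESS_CHECK_FREQ seconds; in between, a stored timestamp short-circuits the check. A stripped-down sketch of that throttling pattern (an in-memory dict standing in for the dynamic config store):

import time

_store: dict[str, float] = {}  # stand-in for get_dynamic_config_store()
CHECK_FREQ = 86400  # mirrors GENERATIVE_MODEL_ACCESS_CHECK_FREQ

def should_revalidate(key: str = "openai_api_key_last_check_time") -> bool:
    now = time.time()
    last = _store.get(key)
    if last is not None and now - last < CHECK_FREQ:
        return False  # checked recently; skip the network call
    _store[key] = now
    return True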

View File

@@ -1,8 +1,8 @@
-import time
 from collections.abc import Generator

 from danswer.auth.users import current_user
 from danswer.chunking.models import InferenceChunk
+from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
 from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.datastores.qdrant.store import QdrantIndex

@@ -10,10 +10,13 @@ from danswer.datastores.typesense.store import TypesenseIndex
 from danswer.db.models import User
 from danswer.direct_qa import get_default_backend_qa_model
 from danswer.direct_qa.answer_question import answer_question
+from danswer.direct_qa.exceptions import OpenAIKeyMissing
+from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.llm import get_json_line
 from danswer.search.danswer_helper import query_intent
 from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.keyword_search import retrieve_keyword_documents
+from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
 from danswer.search.semantic_search import chunks_to_search_docs
 from danswer.search.semantic_search import retrieve_ranked_documents

@@ -22,6 +25,7 @@ from danswer.server.models import QAResponse
 from danswer.server.models import QuestionRequest
 from danswer.server.models import SearchResponse
 from danswer.utils.logger import setup_logger
+from danswer.utils.timing import log_generator_function_time
 from fastapi import APIRouter
 from fastapi import Depends
 from fastapi.responses import StreamingResponse

@@ -102,9 +106,10 @@ def stream_direct_qa(
     logger.debug(f"Received QA query: {question.query}")
     logger.debug(f"Query filters: {question.filters}")

-    def stream_qa_portions() -> Generator[str, None, None]:
-        start_time = time.time()
+    @log_generator_function_time()
+    def stream_qa_portions(
+        disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
+    ) -> Generator[str, None, None]:
         query = question.query
         collection = question.collection
         filters = question.filters

@@ -142,13 +147,27 @@ def stream_direct_qa(
         initial_response_dict = {
             top_documents_key: [top_doc.json() for top_doc in top_docs],
             unranked_top_docs_key: [doc.json() for doc in unranked_top_docs],
-            predicted_flow_key: predicted_flow,
+            # if generative AI is disabled, set flow as search so frontend
+            # doesn't ask the user if they want to run QA over more documents
+            predicted_flow_key: QueryFlow.SEARCH
+            if disable_generative_answer
+            else predicted_flow,
             predicted_search_key: predicted_search,
         }
         logger.debug(send_packet_debug_msg.format(initial_response_dict))
         yield get_json_line(initial_response_dict)

-        qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
+        if disable_generative_answer:
+            logger.debug("Skipping QA because generative AI is disabled")
+            return
+
+        try:
+            qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
+        except (UnknownModelError, OpenAIKeyMissing) as e:
+            logger.exception("Unable to get QA model")
+            yield get_json_line({"error": str(e)})
+            return
+
         chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
         if chunk_offset >= len(ranked_chunks):
             raise ValueError(

@@ -165,10 +184,10 @@ def stream_direct_qa(
                 continue
             logger.debug(f"Sending packet: {response_dict}")
             yield get_json_line(response_dict)
-        except Exception:
+        except Exception as e:
             # exception is logged in the answer_question method, no need to re-log
-            pass
-
-        logger.info(f"Total QA took {time.time() - start_time} seconds")
+            yield get_json_line({"error": str(e)})

         return

     return StreamingResponse(stream_qa_portions(), media_type="application/json")
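Because failures now arrive in-band as a JSON line instead of the stream silently ending, clients can distinguish "no answer" from "QA broke". A hypothetical Python consumer sketch (the endpoint path and request shape are assumptions, not taken from this diff):

import json
import requests

with requests.post(
    "http://localhost:8080/stream-direct-qa",  # hypothetical URL
    json={"query": "example question", "collection": "danswer_index", "filters": None},
    stream=True,
) as response:
    for line in response.iter_lines():
        if not line:
            continue
        packet = json.loads(line)
        if "error" in packet:
            print("QA failed:", packet["error"])
            break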

View File

@@ -1,5 +1,6 @@
 import time
 from collections.abc import Callable
+from collections.abc import Generator
 from typing import Any
 from typing import cast
 from typing import TypeVar

@@ -9,6 +10,7 @@ from danswer.utils.logger import setup_logger
 logger = setup_logger()

 F = TypeVar("F", bound=Callable)
+FG = TypeVar("FG", bound=Callable[..., Generator])


 def log_function_time(

@@ -34,3 +36,30 @@ def log_function_time(
         return cast(F, wrapped_func)

     return timing_wrapper
+
+
+def log_generator_function_time(
+    func_name: str | None = None,
+) -> Callable[[FG], FG]:
+    """Build a timing wrapper for a function which returns a generator.
+
+    Logs how long the function took to run.
+
+    Use like:
+
+    @log_generator_function_time()
+    def my_func():
+        ...
+        yield X
+        ...
+    """
+
+    def timing_wrapper(func: FG) -> FG:
+        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
+            start_time = time.time()
+            yield from func(*args, **kwargs)
+            logger.info(
+                f"{func_name or func.__name__} took {time.time() - start_time} seconds"
+            )
+
+        return cast(FG, wrapped_func)
+
+    return timing_wrapper
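One subtlety of the wrapper above: because it does "yield from" inside the timed region, the logged duration covers draining the generator, including any time the consumer spends between items, not just time inside the wrapped function. A small runnable sketch:

from collections.abc import Generator

@log_generator_function_time()
def count_up(n: int) -> Generator[int, None, None]:
    for i in range(n):
        yield i

for _ in count_up(3):
    pass  # the timing line is logged only once the generator is exhausted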

View File

@@ -24,6 +24,7 @@ services:
       - DISABLE_AUTH=${DISABLE_AUTH:-True}
       - GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
       - GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
+      - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
     volumes:
       - local_dynamic_storage:/home/storage
       - file_connector_tmp_storage:/home/file_connector_storage

View File

@@ -40,7 +40,7 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
     return null;
   }

-  const { answer, quotes, documents } = searchResponse;
+  const { answer, quotes, documents, error } = searchResponse;

   if (isFetching && !answer) {
     return (
@@ -67,73 +67,78 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
     });
   }

+  const shouldDisplayQA =
+    searchResponse.suggestedFlowType === FlowType.QUESTION_ANSWER ||
+    defaultOverrides.forceDisplayQA;
+  console.log(shouldDisplayQA);
+
   return (
     <>
-      {answer &&
-        (searchResponse.suggestedFlowType !== FlowType.SEARCH ||
-          defaultOverrides.forceDisplayQA) && (
-          <div className="min-h-[14rem]">
-            <div className="p-4 border-2 rounded-md border-gray-700">
-              <div className="flex mb-1">
-                <h2 className="text font-bold my-auto">AI Answer</h2>
-              </div>
-              <p className="mb-4">{answer}</p>
-
-              {quotes !== null && (
-                <>
-                  <h2 className="text-sm font-bold mb-2">Sources</h2>
-                  {isFetching && dedupedQuotes.length === 0 ? (
-                    <LoadingAnimation text="Finding quotes" size="text-sm" />
-                  ) : (
-                    <div className="flex">
-                      {dedupedQuotes.length > 0 ? (
-                        dedupedQuotes.map((quoteInfo) => (
-                          <a
-                            key={quoteInfo.document_id}
-                            className="p-2 ml-1 border border-gray-800 rounded-lg text-sm flex max-w-[280px] hover:bg-gray-800"
-                            href={quoteInfo.link}
-                            target="_blank"
-                            rel="noopener noreferrer"
-                          >
-                            {getSourceIcon(quoteInfo.source_type, "20")}
-                            <p className="truncate break-all ml-2">
-                              {quoteInfo.semantic_identifier ||
-                                quoteInfo.document_id}
-                            </p>
-                          </a>
-                        ))
-                      ) : (
-                        <div className="flex">
-                          <InfoIcon
-                            size="20"
-                            className="text-red-500 my-auto flex flex-shrink-0"
-                          />
-                          <div className="text-red-500 text-sm my-auto ml-1">
-                            Did not find any exact quotes to support the above
-                            answer.
-                          </div>
-                        </div>
-                      )}
-                    </div>
-                  )}
-                </>
-              )}
-            </div>
-          </div>
-        )}
-      {(answer === null || answer === undefined) && !isFetching && (
-        <div className="flex">
-          <InfoIcon
-            size="20"
-            className="text-red-500 my-auto flex flex-shrink-0"
-          />
-          <div className="text-red-500 text-xs my-auto ml-1">
-            GPT hurt itself in its confusion :(
-          </div>
-        </div>
-      )}
+      {answer && shouldDisplayQA && (
+        <div className="min-h-[14rem]">
+          <div className="p-4 border-2 rounded-md border-gray-700">
+            <div className="flex mb-1">
+              <h2 className="text font-bold my-auto">AI Answer</h2>
+            </div>
+            <p className="mb-4">{answer}</p>
+
+            {quotes !== null && (
+              <>
+                <h2 className="text-sm font-bold mb-2">Sources</h2>
+                {isFetching && dedupedQuotes.length === 0 ? (
+                  <LoadingAnimation text="Finding quotes" size="text-sm" />
+                ) : (
+                  <div className="flex">
+                    {dedupedQuotes.length > 0 ? (
+                      dedupedQuotes.map((quoteInfo) => (
+                        <a
+                          key={quoteInfo.document_id}
+                          className="p-2 ml-1 border border-gray-800 rounded-lg text-sm flex max-w-[280px] hover:bg-gray-800"
+                          href={quoteInfo.link}
+                          target="_blank"
+                          rel="noopener noreferrer"
+                        >
+                          {getSourceIcon(quoteInfo.source_type, "20")}
+                          <p className="truncate break-all ml-2">
+                            {quoteInfo.semantic_identifier ||
+                              quoteInfo.document_id}
+                          </p>
+                        </a>
+                      ))
+                    ) : (
+                      <div className="flex">
+                        <InfoIcon
+                          size="20"
+                          className="text-red-500 my-auto flex flex-shrink-0"
+                        />
+                        <div className="text-red-500 text-sm my-auto ml-1">
+                          Did not find any exact quotes to support the above
+                          answer.
+                        </div>
+                      </div>
+                    )}
+                  </div>
+                )}
+              </>
+            )}
+          </div>
+        </div>
+      )}
+      {(answer === null || answer === undefined) &&
+        !isFetching &&
+        shouldDisplayQA && (
+          <div className="flex">
+            <InfoIcon
+              size="20"
+              className="text-red-500 my-auto flex flex-shrink-0"
+            />
+            <div className="text-red-500 text-xs my-auto ml-1">
+              {error ?? "GPT hurt itself in its confusion :("}
+            </div>
+          </div>
+        )}
       {documents && documents.length > 0 && (
         <div className="mt-4">
           <div className="font-bold border-b mb-4 pb-1 border-gray-800">

View File

@@ -64,6 +64,7 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
     documents: null,
     suggestedSearchType: null,
     suggestedFlowType: null,
+    error: null,
   };

   const updateCurrentAnswer = (answer: string) =>
     setSearchResponse((prevState) => ({

@@ -90,6 +91,11 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
       ...(prevState || initialSearchResponse),
       suggestedFlowType,
     }));
+  const updateError = (error: string) =>
+    setSearchResponse((prevState) => ({
+      ...(prevState || initialSearchResponse),
+      error,
+    }));

   let lastSearchCancellationToken = useRef<CancellationToken | null>(null);
   const onSearch = async ({

@@ -131,6 +137,10 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
         cancellationToken: lastSearchCancellationToken.current,
         fn: updateSuggestedFlowType,
       }),
+      updateError: cancellable({
+        cancellationToken: lastSearchCancellationToken.current,
+        fn: updateError,
+      }),
       selectedSearchType: searchType ?? selectedSearchType,
       offset: offset ?? defaultOverrides.offset,
     });

View File

@@ -34,6 +34,7 @@ export interface SearchResponse {
   answer: string | null;
   quotes: Record<string, Quote> | null;
   documents: DanswerDocument[] | null;
+  error: string | null;
 }

 export interface Source {

@@ -54,6 +55,7 @@ export interface SearchRequestArgs {
   updateDocs: (documents: DanswerDocument[]) => void;
   updateSuggestedSearchType: (searchType: SearchType) => void;
   updateSuggestedFlowType: (flowType: FlowType) => void;
+  updateError: (error: string) => void;
   selectedSearchType: SearchType | null;
   offset: number | null;
 }

View File

@@ -14,6 +14,7 @@ export const searchRequest = async ({
   updateDocs,
   updateSuggestedSearchType,
   updateSuggestedFlowType,
+  updateError,
   selectedSearchType,
   offset,
 }: SearchRequestArgs) => {

@@ -72,6 +73,7 @@ export const searchRequest = async ({
     updateSuggestedSearchType(data.predicted_search);
     updateSuggestedFlowType(data.predicted_flow);
+    updateError(data.error);
   } catch (err) {
     console.error("Fetch error:", err);
   }

View File

@@ -59,6 +59,7 @@ export const searchRequestStreamed = async ({
   updateDocs,
   updateSuggestedSearchType,
   updateSuggestedFlowType,
+  updateError,
   selectedSearchType,
   offset,
 }: SearchRequestArgs) => {

@@ -121,7 +122,11 @@ export const searchRequestStreamed = async ({
         if (answerChunk) {
           answer += answerChunk;
           updateCurrentAnswer(answer);
-        } else if (chunk.answer_finished) {
+          return;
+        }
+
+        const answerFinished = chunk.answer_finished;
+        if (answerFinished) {
           // set quotes as non-null to signify that the answer is finished and
           // we're now looking for quotes
           updateQuotes({});

@@ -136,27 +141,38 @@ export const searchRequestStreamed = async ({
           } else {
             updateCurrentAnswer("");
           }
-        } else {
-          if (Object.hasOwn(chunk, "top_documents")) {
-            const docs = chunk.top_documents as any[] | null;
-            if (docs) {
-              relevantDocuments = docs.map(
-                (doc) => JSON.parse(doc) as DanswerDocument
-              );
-              updateDocs(relevantDocuments);
-            }
-
-            if (chunk.predicted_flow) {
-              updateSuggestedFlowType(chunk.predicted_flow);
-            }
-            if (chunk.predicted_search) {
-              updateSuggestedSearchType(chunk.predicted_search);
-            }
-          } else {
-            quotes = chunk as Record<string, Quote>;
-            updateQuotes(quotes);
-          }
-        }
+          return;
+        }
+
+        const errorMsg = chunk.error;
+        if (errorMsg) {
+          updateError(errorMsg);
+          return;
+        }
+
+        // These all come together
+        if (Object.hasOwn(chunk, "top_documents")) {
+          const topDocuments = chunk.top_documents as any[] | null;
+          if (topDocuments) {
+            relevantDocuments = topDocuments.map(
+              (doc) => JSON.parse(doc) as DanswerDocument
+            );
+            updateDocs(relevantDocuments);
+          }
+
+          if (chunk.predicted_flow) {
+            updateSuggestedFlowType(chunk.predicted_flow);
+          }
+          if (chunk.predicted_search) {
+            updateSuggestedSearchType(chunk.predicted_search);
+          }
+          return;
+        }
+
+        // if it doesn't match any of the above, assume it is a quote
+        quotes = chunk as Record<string, Quote>;
+        updateQuotes(quotes);
       });
     }
   } catch (err) {