Better error message on GPT failures (#187)

* Better error message on GPT-call failures

* Add support for disabling Generative AI
Chris Weaver 2023-07-16 16:25:33 -07:00 committed by GitHub
parent 6c584f0650
commit 676538da61
14 changed files with 246 additions and 107 deletions

View File

@@ -12,6 +12,10 @@ APP_PORT = 8080
#####
BLURB_LENGTH = 200 # Characters. Blurbs will be truncated at the first punctuation after this many characters.
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400 # 1 day
# DISABLE_GENERATIVE_AI will turn off the question answering part of Danswer. Use this
# if you want to use Danswer as a search engine only and/or you are not comfortable sending
# anything to OpenAI. TODO: update this message once we support Azure / open source generative models.
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
#####
# Web Configs

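Note on semantics: only the literal string "true" (any casing) flips this switch; unset, empty, "1", or "yes" all leave generative AI enabled. A minimal standalone sketch of the parse rule (the helper name is illustrative, not part of the codebase):

import os

def generative_ai_disabled() -> bool:
    # Mirrors the config line above: only "true" (case-insensitive) disables QA.
    return os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"

for value in ("true", "TRUE", "1", "yes", ""):
    os.environ["DISABLE_GENERATIVE_AI"] = value
    print(repr(value), "->", generative_ai_disabled())
# "true"/"TRUE" -> True; "1", "yes", "" -> False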
View File

@@ -1,6 +1,7 @@
from typing import Any
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.llm import OpenAIChatCompletionQA
from danswer.direct_qa.llm import OpenAICompletionQA
@@ -14,4 +15,4 @@ def get_default_backend_qa_model(
    elif internal_model == "openai-chat-completion":
        return OpenAIChatCompletionQA(**kwargs)
    else:
        raise ValueError("Unknown internal QA model set.")
        raise UnknownModelError(internal_model)

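The hunk above shows only the tail of the factory. A sketch of the full dispatch for context, assuming the elided branch mirrors the elif shown (the "openai-completion" literal is an assumption, not confirmed by this diff):

from typing import Any

from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.llm import OpenAIChatCompletionQA
from danswer.direct_qa.llm import OpenAICompletionQA


def get_default_backend_qa_model(
    internal_model: str = INTERNAL_MODEL_VERSION, **kwargs: Any
) -> QAModel:
    if internal_model == "openai-completion":  # assumed branch, elided in the hunk
        return OpenAICompletionQA(**kwargs)
    elif internal_model == "openai-chat-completion":
        return OpenAIChatCompletionQA(**kwargs)
    else:
        # A typed error (instead of a bare ValueError) lets endpoints catch it
        # and forward a readable message to the frontend.
        raise UnknownModelError(internal_model)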
View File

@@ -1,29 +1,34 @@
import time
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.db.models import User
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.search.danswer_helper import query_intent
from danswer.search.keyword_search import retrieve_keyword_documents
from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
logger = setup_logger()


@log_function_time()
def answer_question(
    question: QuestionRequest, user: User | None, qa_model_timeout: int = QA_TIMEOUT
    question: QuestionRequest,
    user: User | None,
    qa_model_timeout: int = QA_TIMEOUT,
    disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
) -> QAResponse:
    start_time = time.time()
    query = question.query
    collection = question.collection
    filters = question.filters
@@ -55,7 +60,32 @@ def answer_question(
            predicted_search=predicted_search,
        )

    qa_model = get_default_backend_qa_model(timeout=qa_model_timeout)
    if disable_generative_answer:
        logger.debug("Skipping QA because generative AI is disabled")
        return QAResponse(
            answer=None,
            quotes=None,
            top_ranked_docs=chunks_to_search_docs(ranked_chunks),
            lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
            # set flow as search so frontend doesn't ask the user if they want
            # to run QA over more documents
            predicted_flow=QueryFlow.SEARCH,
            predicted_search=predicted_search,
        )

    try:
        qa_model = get_default_backend_qa_model(timeout=qa_model_timeout)
    except (UnknownModelError, OpenAIKeyMissing) as e:
        return QAResponse(
            answer=None,
            quotes=None,
            top_ranked_docs=chunks_to_search_docs(ranked_chunks),
            lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
            predicted_flow=predicted_flow,
            predicted_search=predicted_search,
            error_msg=str(e),
        )

    chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
    if chunk_offset >= len(ranked_chunks):
        raise ValueError("Chunks offset too large, should not retry this many times")
@@ -71,8 +101,6 @@ def answer_question(
        answer, quotes = None, None
        error_msg = f"Error occurred in call to LLM - {e}"

    logger.info(f"Total QA took {time.time() - start_time} seconds")

    return QAResponse(
        answer=answer,
        quotes=quotes,

View File

@@ -0,0 +1,8 @@
class OpenAIKeyMissing(Exception):
    def __init__(self, msg: str = "Unable to find an OpenAI Key") -> None:
        super().__init__(msg)


class UnknownModelError(Exception):
    def __init__(self, model_name: str) -> None:
        super().__init__(f"Unknown Internal QA model name: {model_name}")

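Both exceptions bake a human-readable message into the constructor, so call sites can simply forward str(e) to the client, as the endpoints elsewhere in this commit do. A minimal sketch of that pattern (describe_model_failure is a hypothetical helper, not part of the commit):

from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError


def describe_model_failure() -> str | None:
    """Return a user-facing error message, or None if the QA model loads."""
    try:
        get_default_backend_qa_model()
    except (UnknownModelError, OpenAIKeyMissing) as e:
        # str(e) is already presentable, e.g. "Unable to find an OpenAI Key".
        return str(e)
    return None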
View File

@@ -27,6 +27,7 @@ from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import OPENAI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import OPENAI_MODEL_VERSION
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import get_chat_reflexion_msg
@@ -278,7 +279,7 @@ class OpenAICompletionQA(OpenAIQAModel):
        try:
            self.api_key = api_key or get_openai_api_key()
        except ConfigNotFoundError:
            raise RuntimeError("No OpenAI Key available")
            raise OpenAIKeyMissing()

    @log_function_time()
    def answer_question(
@@ -391,7 +392,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
        try:
            self.api_key = api_key or get_openai_api_key()
        except ConfigNotFoundError:
            raise RuntimeError("No OpenAI Key available")
            raise OpenAIKeyMissing()

    @log_function_time()
    def answer_question(
@@ -482,6 +483,6 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
            logger.debug(model_output)

        answer, quotes_dict = process_answer(model_output, context_docs)
        _, quotes_dict = process_answer(model_output, context_docs)

        yield {} if quotes_dict is None else quotes_dict

View File

@@ -5,6 +5,7 @@ from typing import cast
from danswer.auth.schemas import UserRole
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
@@ -293,6 +294,20 @@ def connector_run_once(
def validate_existing_openai_api_key(
    _: User = Depends(current_admin_user),
) -> None:
    # OpenAI key is only used for generative QA, so no need to validate this
    # if it's turned off
    if DISABLE_GENERATIVE_AI:
        return

    # always check if key exists
    try:
        openai_api_key = get_openai_api_key()
    except ConfigNotFoundError:
        raise HTTPException(status_code=404, detail="Key not found")
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    # don't call OpenAI every single time, only validate every so often
    check_key_time = "openai_api_key_last_check_time"
    kv_store = get_dynamic_config_store()
    curr_time = datetime.now()
@@ -308,12 +323,10 @@ def validate_existing_openai_api_key(
        get_dynamic_config_store().store(check_key_time, curr_time.timestamp())

    try:
        openai_api_key = get_openai_api_key()
        is_valid = check_openai_api_key_is_valid(openai_api_key)
    except ConfigNotFoundError:
        raise HTTPException(status_code=404, detail="Key not found")
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except ValueError:
        # this is the case where they aren't using an OpenAI-based model
        is_valid = True

    if not is_valid:
        raise HTTPException(status_code=400, detail="Invalid API key provided")

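The elided middle of this hunk is the existing throttle: the key is only validated against OpenAI at most once per GENERATIVE_MODEL_ACCESS_CHECK_FREQ seconds, with the last-check timestamp kept in the dynamic config store. A self-contained sketch of that shape, using a plain dict as a stand-in for the store (names here are illustrative):

import time

_kv_store: dict[str, float] = {}  # stand-in for danswer's dynamic config store

CHECK_FREQ = 86400  # GENERATIVE_MODEL_ACCESS_CHECK_FREQ: 1 day


def should_revalidate(key: str = "openai_api_key_last_check_time") -> bool:
    """True if enough time has passed to warrant a live OpenAI call."""
    last_check = _kv_store.get(key)
    if last_check is not None and time.time() - last_check < CHECK_FREQ:
        return False  # validated recently, skip the network round-trip
    _kv_store[key] = time.time()
    return True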
View File

@@ -1,8 +1,8 @@
import time
from collections.abc import Generator
from danswer.auth.users import current_user
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.datastores.qdrant.store import QdrantIndex
@@ -10,10 +10,13 @@ from danswer.datastores.typesense.store import TypesenseIndex
from danswer.db.models import User
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.answer_question import answer_question
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.llm import get_json_line
from danswer.search.danswer_helper import query_intent
from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.keyword_search import retrieve_keyword_documents
from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
@@ -22,6 +25,7 @@ from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchResponse
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
from fastapi import APIRouter
from fastapi import Depends
from fastapi.responses import StreamingResponse
@@ -102,9 +106,10 @@ def stream_direct_qa(
    logger.debug(f"Received QA query: {question.query}")
    logger.debug(f"Query filters: {question.filters}")

    def stream_qa_portions() -> Generator[str, None, None]:
        start_time = time.time()
    @log_generator_function_time()
    def stream_qa_portions(
        disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
    ) -> Generator[str, None, None]:
        query = question.query
        collection = question.collection
        filters = question.filters
@@ -142,13 +147,27 @@ def stream_direct_qa(
        initial_response_dict = {
            top_documents_key: [top_doc.json() for top_doc in top_docs],
            unranked_top_docs_key: [doc.json() for doc in unranked_top_docs],
            predicted_flow_key: predicted_flow,
            # if generative AI is disabled, set flow as search so frontend
            # doesn't ask the user if they want to run QA over more documents
            predicted_flow_key: QueryFlow.SEARCH
            if disable_generative_answer
            else predicted_flow,
            predicted_search_key: predicted_search,
        }
        logger.debug(send_packet_debug_msg.format(initial_response_dict))
        yield get_json_line(initial_response_dict)

        qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
        if disable_generative_answer:
            logger.debug("Skipping QA because generative AI is disabled")
            return

        try:
            qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
        except (UnknownModelError, OpenAIKeyMissing) as e:
            logger.exception("Unable to get QA model")
            yield get_json_line({"error": str(e)})
            return

        chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
        if chunk_offset >= len(ranked_chunks):
            raise ValueError(
@@ -165,10 +184,10 @@ def stream_direct_qa(
                    continue
                logger.debug(f"Sending packet: {response_dict}")
                yield get_json_line(response_dict)
        except Exception:
        except Exception as e:
            # exception is logged in the answer_question method, no need to re-log
            pass
            logger.info(f"Total QA took {time.time() - start_time} seconds")
            yield get_json_line({"error": str(e)})

        return

    return StreamingResponse(stream_qa_portions(), media_type="application/json")

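For reference, the endpoint streams newline-delimited JSON packets: first the document/flow metadata, then either incremental answer data, an {"error": ...} packet (new in this commit), or a final quotes dictionary. A rough client sketch; the route path, request body shape, and answer-chunk key are all assumptions, since they sit outside this diff:

import json

import requests


def stream_qa(base_url: str, query: str) -> None:
    resp = requests.post(
        f"{base_url}/stream-direct-qa",  # assumed route, not shown in the diff
        json={"query": query, "collection": "danswer_index", "filters": []},
        stream=True,
    )
    for line in resp.iter_lines():
        if not line:
            continue
        packet = json.loads(line)
        if packet.get("error"):
            print("QA failed:", packet["error"])  # e.g. missing OpenAI key
            break
        if "top_documents" in packet:
            print("retrieved docs:", len(packet["top_documents"] or []))
        elif "answer_data" in packet:  # key name assumed from frontend usage
            print(packet["answer_data"], end="")
        else:
            print("quotes:", packet)  # anything else is a quotes dictionary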
View File

@@ -1,5 +1,6 @@
import time
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import cast
from typing import TypeVar
@@ -9,6 +10,7 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()

F = TypeVar("F", bound=Callable)
FG = TypeVar("FG", bound=Callable[..., Generator])


def log_function_time(
@@ -34,3 +36,30 @@ def log_function_time(
        return cast(F, wrapped_func)

    return timing_wrapper


def log_generator_function_time(
    func_name: str | None = None,
) -> Callable[[FG], FG]:
    """Build a timing wrapper for a function which returns a generator.

    Logs how long the function took to run.

    Use like:

    @log_generator_function_time()
    def my_func():
        ...
        yield X
        ...
    """

    def timing_wrapper(func: FG) -> FG:
        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
            start_time = time.time()
            yield from func(*args, **kwargs)
            logger.info(
                f"{func_name or func.__name__} took {time.time() - start_time} seconds"
            )

        return cast(FG, wrapped_func)

    return timing_wrapper

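One subtlety worth noting: the wrapper logs only after yield from returns, i.e. once the generator is exhausted, so the reported time includes however long the consumer takes between items, not just the wrapped function's own work. A quick usage sketch:

import time

from danswer.utils.timing import log_generator_function_time


@log_generator_function_time()
def slow_numbers():
    for i in range(3):
        time.sleep(0.1)
        yield i


for n in slow_numbers():
    pass  # consuming the generator; timing is logged after the last item
# logs roughly: "slow_numbers took 0.3 seconds"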
View File

@@ -24,6 +24,7 @@ services:
- DISABLE_AUTH=${DISABLE_AUTH:-True}
- GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
- GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
volumes:
- local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage

View File

@@ -40,7 +40,7 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
return null;
}
const { answer, quotes, documents } = searchResponse;
const { answer, quotes, documents, error } = searchResponse;
if (isFetching && !answer) {
return (
@@ -67,73 +67,78 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
});
}
const shouldDisplayQA =
searchResponse.suggestedFlowType === FlowType.QUESTION_ANSWER ||
defaultOverrides.forceDisplayQA;
console.log(shouldDisplayQA);
return (
<>
{answer &&
(searchResponse.suggestedFlowType !== FlowType.SEARCH ||
defaultOverrides.forceDisplayQA) && (
<div className="min-h-[14rem]">
<div className="p-4 border-2 rounded-md border-gray-700">
<div className="flex mb-1">
<h2 className="text font-bold my-auto">AI Answer</h2>
</div>
<p className="mb-4">{answer}</p>
{quotes !== null && (
<>
<h2 className="text-sm font-bold mb-2">Sources</h2>
{isFetching && dedupedQuotes.length === 0 ? (
<LoadingAnimation text="Finding quotes" size="text-sm" />
) : (
<div className="flex">
{dedupedQuotes.length > 0 ? (
dedupedQuotes.map((quoteInfo) => (
<a
key={quoteInfo.document_id}
className="p-2 ml-1 border border-gray-800 rounded-lg text-sm flex max-w-[280px] hover:bg-gray-800"
href={quoteInfo.link}
target="_blank"
rel="noopener noreferrer"
>
{getSourceIcon(quoteInfo.source_type, "20")}
<p className="truncate break-all ml-2">
{quoteInfo.semantic_identifier ||
quoteInfo.document_id}
</p>
</a>
))
) : (
<div className="flex">
<InfoIcon
size="20"
className="text-red-500 my-auto flex flex-shrink-0"
/>
<div className="text-red-500 text-sm my-auto ml-1">
Did not find any exact quotes to support the above
answer.
</div>
</div>
)}
</div>
)}
</>
)}
{answer && shouldDisplayQA && (
<div className="min-h-[14rem]">
<div className="p-4 border-2 rounded-md border-gray-700">
<div className="flex mb-1">
<h2 className="text font-bold my-auto">AI Answer</h2>
</div>
</div>
)}
<p className="mb-4">{answer}</p>
{(answer === null || answer === undefined) && !isFetching && (
<div className="flex">
<InfoIcon
size="20"
className="text-red-500 my-auto flex flex-shrink-0"
/>
<div className="text-red-500 text-xs my-auto ml-1">
GPT hurt itself in its confusion :(
{quotes !== null && (
<>
<h2 className="text-sm font-bold mb-2">Sources</h2>
{isFetching && dedupedQuotes.length === 0 ? (
<LoadingAnimation text="Finding quotes" size="text-sm" />
) : (
<div className="flex">
{dedupedQuotes.length > 0 ? (
dedupedQuotes.map((quoteInfo) => (
<a
key={quoteInfo.document_id}
className="p-2 ml-1 border border-gray-800 rounded-lg text-sm flex max-w-[280px] hover:bg-gray-800"
href={quoteInfo.link}
target="_blank"
rel="noopener noreferrer"
>
{getSourceIcon(quoteInfo.source_type, "20")}
<p className="truncate break-all ml-2">
{quoteInfo.semantic_identifier ||
quoteInfo.document_id}
</p>
</a>
))
) : (
<div className="flex">
<InfoIcon
size="20"
className="text-red-500 my-auto flex flex-shrink-0"
/>
<div className="text-red-500 text-sm my-auto ml-1">
Did not find any exact quotes to support the above
answer.
</div>
</div>
)}
</div>
)}
</>
)}
</div>
</div>
)}
{(answer === null || answer === undefined) &&
!isFetching &&
shouldDisplayQA && (
<div className="flex">
<InfoIcon
size="20"
className="text-red-500 my-auto flex flex-shrink-0"
/>
<div className="text-red-500 text-xs my-auto ml-1">
{error ?? "GPT hurt itself in its confusion :("}
</div>
</div>
)}
{documents && documents.length > 0 && (
<div className="mt-4">
<div className="font-bold border-b mb-4 pb-1 border-gray-800">

View File

@@ -64,6 +64,7 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
documents: null,
suggestedSearchType: null,
suggestedFlowType: null,
error: null,
};
const updateCurrentAnswer = (answer: string) =>
setSearchResponse((prevState) => ({
@@ -90,6 +91,11 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
...(prevState || initialSearchResponse),
suggestedFlowType,
}));
const updateError = (error: string) =>
setSearchResponse((prevState) => ({
...(prevState || initialSearchResponse),
error,
}));
let lastSearchCancellationToken = useRef<CancellationToken | null>(null);
const onSearch = async ({
@@ -131,6 +137,10 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
cancellationToken: lastSearchCancellationToken.current,
fn: updateSuggestedFlowType,
}),
updateError: cancellable({
cancellationToken: lastSearchCancellationToken.current,
fn: updateError,
}),
selectedSearchType: searchType ?? selectedSearchType,
offset: offset ?? defaultOverrides.offset,
});

View File

@@ -34,6 +34,7 @@ export interface SearchResponse {
answer: string | null;
quotes: Record<string, Quote> | null;
documents: DanswerDocument[] | null;
error: string | null;
}
export interface Source {
@@ -54,6 +55,7 @@ export interface SearchRequestArgs {
updateDocs: (documents: DanswerDocument[]) => void;
updateSuggestedSearchType: (searchType: SearchType) => void;
updateSuggestedFlowType: (flowType: FlowType) => void;
updateError: (error: string) => void;
selectedSearchType: SearchType | null;
offset: number | null;
}

View File

@@ -14,6 +14,7 @@ export const searchRequest = async ({
updateDocs,
updateSuggestedSearchType,
updateSuggestedFlowType,
updateError,
selectedSearchType,
offset,
}: SearchRequestArgs) => {
@@ -72,6 +73,7 @@
updateSuggestedSearchType(data.predicted_search);
updateSuggestedFlowType(data.predicted_flow);
updateError(data.error);
} catch (err) {
console.error("Fetch error:", err);
}

View File

@@ -59,6 +59,7 @@ export const searchRequestStreamed = async ({
updateDocs,
updateSuggestedSearchType,
updateSuggestedFlowType,
updateError,
selectedSearchType,
offset,
}: SearchRequestArgs) => {
@@ -121,7 +122,11 @@ export const searchRequestStreamed = async ({
if (answerChunk) {
answer += answerChunk;
updateCurrentAnswer(answer);
} else if (chunk.answer_finished) {
return;
}
const answerFinished = chunk.answer_finished;
if (answerFinished) {
// set quotes as non-null to signify that the answer is finished and
// we're now looking for quotes
updateQuotes({});
@@ -136,27 +141,38 @@
} else {
updateCurrentAnswer("");
}
} else {
if (Object.hasOwn(chunk, "top_documents")) {
const docs = chunk.top_documents as any[] | null;
if (docs) {
relevantDocuments = docs.map(
(doc) => JSON.parse(doc) as DanswerDocument
);
updateDocs(relevantDocuments);
}
if (chunk.predicted_flow) {
updateSuggestedFlowType(chunk.predicted_flow);
}
if (chunk.predicted_search) {
updateSuggestedSearchType(chunk.predicted_search);
}
} else {
quotes = chunk as Record<string, Quote>;
updateQuotes(quotes);
}
return;
}
const errorMsg = chunk.error;
if (errorMsg) {
updateError(errorMsg);
return;
}
// These all come together
if (Object.hasOwn(chunk, "top_documents")) {
const topDocuments = chunk.top_documents as any[] | null;
if (topDocuments) {
relevantDocuments = topDocuments.map(
(doc) => JSON.parse(doc) as DanswerDocument
);
updateDocs(relevantDocuments);
}
if (chunk.predicted_flow) {
updateSuggestedFlowType(chunk.predicted_flow);
}
if (chunk.predicted_search) {
updateSuggestedSearchType(chunk.predicted_search);
}
return;
}
// if it doesn't match any of the above, assume it is a quote
quotes = chunk as Record<string, Quote>;
updateQuotes(quotes);
});
}
} catch (err) {