Better error message on GPT failures (#187)

* Better error message on GPT-call failures

* Add support for disabling Generative AI
This commit is contained in:
Chris Weaver
2023-07-16 16:25:33 -07:00
committed by GitHub
parent 6c584f0650
commit 676538da61
14 changed files with 246 additions and 107 deletions

View File

@@ -12,6 +12,10 @@ APP_PORT = 8080
#####
BLURB_LENGTH = 200 # Characters. Blurbs will be truncated at the first punctuation after this many characters.
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400 # 1 day
# DISABLE_GENERATIVE_AI will turn off the question answering part of Danswer. Use this
# if you want to use Danswer as a search engine only and/or you are not comfortable sending
# anything to OpenAI. TODO: update this message once we support Azure / open source generative models.
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
#####
# Web Configs

View File

@@ -1,6 +1,7 @@
from typing import Any
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.llm import OpenAIChatCompletionQA
from danswer.direct_qa.llm import OpenAICompletionQA
@@ -14,4 +15,4 @@ def get_default_backend_qa_model(
elif internal_model == "openai-chat-completion":
return OpenAIChatCompletionQA(**kwargs)
else:
raise ValueError("Unknown internal QA model set.")
raise UnknownModelError(internal_model)

View File

@@ -1,29 +1,34 @@
import time
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.db.models import User
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.search.danswer_helper import query_intent
from danswer.search.keyword_search import retrieve_keyword_documents
from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
logger = setup_logger()
@log_function_time()
def answer_question(
question: QuestionRequest, user: User | None, qa_model_timeout: int = QA_TIMEOUT
question: QuestionRequest,
user: User | None,
qa_model_timeout: int = QA_TIMEOUT,
disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
) -> QAResponse:
start_time = time.time()
query = question.query
collection = question.collection
filters = question.filters
@@ -55,7 +60,32 @@ def answer_question(
predicted_search=predicted_search,
)
qa_model = get_default_backend_qa_model(timeout=qa_model_timeout)
if disable_generative_answer:
logger.debug("Skipping QA because generative AI is disabled")
return QAResponse(
answer=None,
quotes=None,
top_ranked_docs=chunks_to_search_docs(ranked_chunks),
lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
# set flow as search so frontend doesn't ask the user if they want
# to run QA over more documents
predicted_flow=QueryFlow.SEARCH,
predicted_search=predicted_search,
)
try:
qa_model = get_default_backend_qa_model(timeout=qa_model_timeout)
except (UnknownModelError, OpenAIKeyMissing) as e:
return QAResponse(
answer=None,
quotes=None,
top_ranked_docs=chunks_to_search_docs(ranked_chunks),
lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
predicted_flow=predicted_flow,
predicted_search=predicted_search,
error_msg=str(e),
)
chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
if chunk_offset >= len(ranked_chunks):
raise ValueError("Chunks offset too large, should not retry this many times")
@@ -71,8 +101,6 @@ def answer_question(
answer, quotes = None, None
error_msg = f"Error occurred in call to LLM - {e}"
logger.info(f"Total QA took {time.time() - start_time} seconds")
return QAResponse(
answer=answer,
quotes=quotes,

View File

@@ -0,0 +1,8 @@
class OpenAIKeyMissing(Exception):
    """Raised when no OpenAI API key is configured anywhere we looked."""

    # Default message shown when callers don't supply their own.
    DEFAULT_MESSAGE = "Unable to find an OpenAI Key"

    def __init__(self, msg: str = DEFAULT_MESSAGE) -> None:
        super().__init__(msg)
class UnknownModelError(Exception):
    """Raised when an unrecognized internal QA model name is requested."""

    def __init__(self, model_name: str) -> None:
        message = f"Unknown Internal QA model name: {model_name}"
        super().__init__(message)

View File

@@ -27,6 +27,7 @@ from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import OPENAI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import OPENAI_MODEL_VERSION
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import get_chat_reflexion_msg
@@ -278,7 +279,7 @@ class OpenAICompletionQA(OpenAIQAModel):
try:
self.api_key = api_key or get_openai_api_key()
except ConfigNotFoundError:
raise RuntimeError("No OpenAI Key available")
raise OpenAIKeyMissing()
@log_function_time()
def answer_question(
@@ -391,7 +392,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
try:
self.api_key = api_key or get_openai_api_key()
except ConfigNotFoundError:
raise RuntimeError("No OpenAI Key available")
raise OpenAIKeyMissing()
@log_function_time()
def answer_question(
@@ -482,6 +483,6 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
_, quotes_dict = process_answer(model_output, context_docs)
yield {} if quotes_dict is None else quotes_dict

View File

@@ -5,6 +5,7 @@ from typing import cast
from danswer.auth.schemas import UserRole
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
@@ -293,6 +294,20 @@ def connector_run_once(
def validate_existing_openai_api_key(
_: User = Depends(current_admin_user),
) -> None:
# OpenAI key is only used for generative QA, so no need to validate this
# if it's turned off
if DISABLE_GENERATIVE_AI:
return
# always check if key exists
try:
openai_api_key = get_openai_api_key()
except ConfigNotFoundError:
raise HTTPException(status_code=404, detail="Key not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# don't call OpenAI every single time, only validate every so often
check_key_time = "openai_api_key_last_check_time"
kv_store = get_dynamic_config_store()
curr_time = datetime.now()
@@ -308,12 +323,10 @@ def validate_existing_openai_api_key(
get_dynamic_config_store().store(check_key_time, curr_time.timestamp())
try:
openai_api_key = get_openai_api_key()
is_valid = check_openai_api_key_is_valid(openai_api_key)
except ConfigNotFoundError:
raise HTTPException(status_code=404, detail="Key not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except ValueError:
# this is the case where they aren't using an OpenAI-based model
is_valid = True
if not is_valid:
raise HTTPException(status_code=400, detail="Invalid API key provided")

View File

@@ -1,8 +1,8 @@
import time
from collections.abc import Generator
from danswer.auth.users import current_user
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.datastores.qdrant.store import QdrantIndex
@@ -10,10 +10,13 @@ from danswer.datastores.typesense.store import TypesenseIndex
from danswer.db.models import User
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.answer_question import answer_question
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.llm import get_json_line
from danswer.search.danswer_helper import query_intent
from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.keyword_search import retrieve_keyword_documents
from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
@@ -22,6 +25,7 @@ from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchResponse
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
from fastapi import APIRouter
from fastapi import Depends
from fastapi.responses import StreamingResponse
@@ -102,9 +106,10 @@ def stream_direct_qa(
logger.debug(f"Received QA query: {question.query}")
logger.debug(f"Query filters: {question.filters}")
def stream_qa_portions() -> Generator[str, None, None]:
start_time = time.time()
@log_generator_function_time()
def stream_qa_portions(
disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
) -> Generator[str, None, None]:
query = question.query
collection = question.collection
filters = question.filters
@@ -142,13 +147,27 @@ def stream_direct_qa(
initial_response_dict = {
top_documents_key: [top_doc.json() for top_doc in top_docs],
unranked_top_docs_key: [doc.json() for doc in unranked_top_docs],
predicted_flow_key: predicted_flow,
# if generative AI is disabled, set flow as search so frontend
# doesn't ask the user if they want to run QA over more documents
predicted_flow_key: QueryFlow.SEARCH
if disable_generative_answer
else predicted_flow,
predicted_search_key: predicted_search,
}
logger.debug(send_packet_debug_msg.format(initial_response_dict))
yield get_json_line(initial_response_dict)
qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
if disable_generative_answer:
logger.debug("Skipping QA because generative AI is disabled")
return
try:
qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
except (UnknownModelError, OpenAIKeyMissing) as e:
logger.exception("Unable to get QA model")
yield get_json_line({"error": str(e)})
return
chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
if chunk_offset >= len(ranked_chunks):
raise ValueError(
@@ -165,10 +184,10 @@ def stream_direct_qa(
continue
logger.debug(f"Sending packet: {response_dict}")
yield get_json_line(response_dict)
except Exception:
except Exception as e:
# exception is logged in the answer_question method, no need to re-log
pass
logger.info(f"Total QA took {time.time() - start_time} seconds")
yield get_json_line({"error": str(e)})
return
return StreamingResponse(stream_qa_portions(), media_type="application/json")

View File

@@ -1,5 +1,6 @@
import time
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import cast
from typing import TypeVar
@@ -9,6 +10,7 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
F = TypeVar("F", bound=Callable)
FG = TypeVar("FG", bound=Callable[..., Generator])
def log_function_time(
@@ -34,3 +36,30 @@ def log_function_time(
return cast(F, wrapped_func)
return timing_wrapper
def log_generator_function_time(
    func_name: str | None = None,
) -> Callable[[FG], FG]:
    """Build a timing wrapper for a function which returns a generator.

    Logs how long the function took to run (i.e. until the generator is
    fully consumed, not just created).

    :param func_name: optional display name for the log line; falls back to
        the wrapped function's ``__name__``.

    Use like:

    @log_generator_function_time()
    def my_func():
        ...
        yield X
        ...
    """

    def timing_wrapper(func: FG) -> FG:
        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
            start_time = time.time()
            # Delegate to the wrapped generator; timing covers the whole
            # iteration, since this line only finishes when `func` is exhausted.
            yield from func(*args, **kwargs)
            logger.info(
                f"{func_name or func.__name__} took {time.time() - start_time} seconds"
            )

        # Fix: cast to FG (the generator-callable TypeVar this decorator is
        # declared with), not F — the original cast was copy-pasted from
        # log_function_time and mislabeled the return type for type checkers.
        return cast(FG, wrapped_func)

    return timing_wrapper