From 806653dcb0ecddcc927b09b73518ebb0e356f2e0 Mon Sep 17 00:00:00 2001 From: Weves Date: Sat, 20 May 2023 18:18:54 -0700 Subject: [PATCH] Add timeout option to OpenAI models --- backend/danswer/configs/app_configs.py | 1 + backend/danswer/direct_qa/__init__.py | 5 +- backend/danswer/direct_qa/key_validation.py | 22 +- backend/danswer/direct_qa/question_answer.py | 242 ++++++++++-------- backend/danswer/server/search_backend.py | 34 ++- .../search/SearchResultsDisplay.tsx | 5 +- 6 files changed, 175 insertions(+), 134 deletions(-) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 6a098a8bd0be..8f887c37574c 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -98,6 +98,7 @@ KEYWORD_MAX_HITS = 5 QUOTE_ALLOWED_ERROR_PERCENT = ( 0.05 # 1 edit per 2 characters, currently unused due to fuzzy match being too slow ) +QA_TIMEOUT = 10 # 10 seconds ##### diff --git a/backend/danswer/direct_qa/__init__.py b/backend/danswer/direct_qa/__init__.py index 413ad823e10b..b40721dcf97b 100644 --- a/backend/danswer/direct_qa/__init__.py +++ b/backend/danswer/direct_qa/__init__.py @@ -1,12 +1,9 @@ from typing import Any -from danswer.configs.app_configs import OPENAI_API_KEY -from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY from danswer.configs.model_configs import INTERNAL_MODEL_VERSION from danswer.direct_qa.interfaces import QAModel from danswer.direct_qa.question_answer import OpenAIChatCompletionQA from danswer.direct_qa.question_answer import OpenAICompletionQA -from danswer.dynamic_configs import get_dynamic_config_store def get_default_backend_qa_model( @@ -17,4 +14,4 @@ def get_default_backend_qa_model( elif internal_model == "openai-chat-completion": return OpenAIChatCompletionQA(**kwargs) else: - raise ValueError("Wrong internal QA model set.") + raise ValueError("Unknown internal QA model set.") diff --git a/backend/danswer/direct_qa/key_validation.py b/backend/danswer/direct_qa/key_validation.py index 512924a7eb4b..ca21ce062cc3 100644 --- a/backend/danswer/direct_qa/key_validation.py +++ b/backend/danswer/direct_qa/key_validation.py @@ -1,21 +1,25 @@ -from danswer.configs.app_configs import OPENAI_API_KEY -from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY from danswer.direct_qa import get_default_backend_qa_model from danswer.direct_qa.question_answer import OpenAIQAModel -from danswer.dynamic_configs import get_dynamic_config_store from openai.error import AuthenticationError +from openai.error import Timeout def check_openai_api_key_is_valid(openai_api_key: str) -> bool: if not openai_api_key: return False - qa_model = get_default_backend_qa_model(api_key=openai_api_key) + qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=2) if not isinstance(qa_model, OpenAIQAModel): raise ValueError("Cannot check OpenAI API key validity for non-OpenAI QA model") - try: - qa_model.answer_question("Do not respond", []) - return True - except AuthenticationError: - return False + # try for up to 3 timeouts (e.g. 6 seconds in total) + for _ in range(3): + try: + qa_model.answer_question("Do not respond", []) + return True + except AuthenticationError: + return False + except Timeout: + pass + + return False diff --git a/backend/danswer/direct_qa/question_answer.py b/backend/danswer/direct_qa/question_answer.py index cfd8457e3a6a..dc4dea39713e 100644 --- a/backend/danswer/direct_qa/question_answer.py +++ b/backend/danswer/direct_qa/question_answer.py @@ -7,8 +7,10 @@ from functools import wraps from typing import Any from typing import cast from typing import Dict +from typing import Literal from typing import Optional from typing import Tuple +from typing import TypeVar from typing import Union import openai @@ -37,6 +39,7 @@ from danswer.utils.text_processing import clean_model_quote from danswer.utils.text_processing import shared_precompare_cleanup from danswer.utils.timing import log_function_time from openai.error import AuthenticationError +from openai.error import Timeout logger = setup_logger() @@ -187,6 +190,48 @@ def stream_answer_end(answer_so_far: str, next_token: str) -> bool: return False +F = TypeVar("F", bound=Callable) + +Answer = str +Quotes = dict[str, dict[str, str | int | None]] +ModelType = Literal["ChatCompletion", "Completion"] +PromptProcessor = Callable[[str, list[str]], str] + + +def _build_openai_settings(**kwargs: dict[str, Any]) -> dict[str, Any]: + """ + Utility to add in some common default values so they don't have to be set every time. + """ + return { + "temperature": 0, + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + **kwargs, + } + + +def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F: + @wraps(openai_call) + def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any: + try: + if not kwargs.get("stream"): + return openai_call(*args, **kwargs) + # if streamed, the call returns a generator + yield from openai_call(*args, **kwargs) + except AuthenticationError: + logger.exception("Failed to authenticate with OpenAI API") + raise + except Timeout: + logger.exception("OpenAI API timed out for query: %s", query) + raise + except Exception as e: + logger.exception("Unexpected error with OpenAI API for query: %s", query) + raise + + return cast(F, wrapped_call) + + # used to check if the QAModel is an OpenAI model class OpenAIQAModel(QAModel): pass @@ -199,11 +244,13 @@ class OpenAICompletionQA(OpenAIQAModel): model_version: str = OPENAI_MODEL_VERSION, max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS, api_key: str | None = None, + timeout: int | None = None, ) -> None: self.prompt_processor = prompt_processor self.model_version = model_version self.max_output_tokens = max_output_tokens self.api_key = api_key or get_openai_api_key() + self.timeout = timeout @log_function_time() def answer_question( @@ -213,28 +260,21 @@ class OpenAICompletionQA(OpenAIQAModel): filled_prompt = self.prompt_processor(query, top_contents) logger.debug(filled_prompt) - try: - response = openai.Completion.create( + openai_call = _handle_openai_exceptions_wrapper( + openai_call=openai.Completion.create, + query=query, + ) + response = openai_call( + **_build_openai_settings( api_key=self.api_key, prompt=filled_prompt, - temperature=0, - top_p=1, - frequency_penalty=0, - presence_penalty=0, model=self.model_version, max_tokens=self.max_output_tokens, - ) - model_output = response["choices"][0]["text"].strip() - logger.info( - "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "") - ) - except AuthenticationError: - logger.exception("Failed to authenticate with OpenAI API") - raise - except Exception as e: - logger.exception(e) - model_output = "Model Failure" - + request_timeout=self.timeout, + ), + ) + model_output = cast(str, response["choices"][0]["text"]).strip() + logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")) logger.debug(model_output) answer, quotes_dict = process_answer(model_output, context_docs) @@ -247,46 +287,41 @@ class OpenAICompletionQA(OpenAIQAModel): filled_prompt = self.prompt_processor(query, top_contents) logger.debug(filled_prompt) - try: - response = openai.Completion.create( + openai_call = _handle_openai_exceptions_wrapper( + openai_call=openai.Completion.create, + query=query, + ) + response = openai_call( + **_build_openai_settings( api_key=self.api_key, prompt=filled_prompt, - temperature=0, - top_p=1, - frequency_penalty=0, - presence_penalty=0, model=self.model_version, max_tokens=self.max_output_tokens, + request_timeout=self.timeout, stream=True, - ) + ), + ) + model_output: str = "" + found_answer_start = False + found_answer_end = False + # iterate through the stream of events + for event in response: + event_text = cast(str, event["choices"][0]["text"]) + model_previous = model_output + model_output += event_text - model_output: str = "" - found_answer_start = False - found_answer_end = False - # iterate through the stream of events - for event in response: - event_text = cast(str, event["choices"][0]["text"]) - model_previous = model_output - model_output += event_text + if not found_answer_start and '{"answer":"' in model_output.replace( + " ", "" + ).replace("\n", ""): + found_answer_start = True + continue - if not found_answer_start and '{"answer":"' in model_output.replace( - " ", "" - ).replace("\n", ""): - found_answer_start = True + if found_answer_start and not found_answer_end: + if stream_answer_end(model_previous, event_text): + found_answer_end = True + yield {"answer_finished": True} continue - - if found_answer_start and not found_answer_end: - if stream_answer_end(model_previous, event_text): - found_answer_end = True - yield {"answer_finished": True} - continue - yield {"answer_data": event_text} - except AuthenticationError: - logger.exception("Failed to authenticate with OpenAI API") - raise - except Exception as e: - logger.exception(e) - model_output = "Model Failure" + yield {"answer_data": event_text} logger.debug(model_output) @@ -304,6 +339,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel): ] = json_chat_processor, model_version: str = OPENAI_MODEL_VERSION, max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS, + timeout: int | None = None, reflexion_try_count: int = 0, api_key: str | None = None, ) -> None: @@ -312,6 +348,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel): self.max_output_tokens = max_output_tokens self.reflexion_try_count = reflexion_try_count self.api_key = api_key or get_openai_api_key() + self.timeout = timeout @log_function_time() def answer_question( @@ -322,30 +359,25 @@ class OpenAIChatCompletionQA(OpenAIQAModel): logger.debug(messages) model_output = "" for _ in range(self.reflexion_try_count + 1): - try: - response = openai.ChatCompletion.create( + openai_call = _handle_openai_exceptions_wrapper( + openai_call=openai.ChatCompletion.create, + query=query, + ) + response = openai_call( + **_build_openai_settings( api_key=self.api_key, messages=messages, - temperature=0, - top_p=1, - frequency_penalty=0, - presence_penalty=0, model=self.model_version, max_tokens=self.max_output_tokens, - ) - model_output = response["choices"][0]["message"]["content"].strip() - assistant_msg = {"content": model_output, "role": "assistant"} - messages.extend([assistant_msg, get_chat_reflexion_msg()]) - logger.info( - "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "") - ) - except AuthenticationError: - logger.exception("Failed to authenticate with OpenAI API") - raise - except Exception as e: - logger.exception(e) - logger.warning(f"Model failure for query: {query}") - return None, None + request_timeout=self.timeout, + ), + ) + model_output = response["choices"][0]["message"]["content"].strip() + assistant_msg = {"content": model_output, "role": "assistant"} + messages.extend([assistant_msg, get_chat_reflexion_msg()]) + logger.info( + "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "") + ) logger.debug(model_output) @@ -359,50 +391,46 @@ class OpenAIChatCompletionQA(OpenAIQAModel): messages = self.prompt_processor(query, top_contents) logger.debug(messages) - try: - response = openai.ChatCompletion.create( + openai_call = _handle_openai_exceptions_wrapper( + openai_call=openai.ChatCompletion.create, + query=query, + ) + response = openai_call( + **_build_openai_settings( api_key=self.api_key, messages=messages, - temperature=0, - top_p=1, - frequency_penalty=0, - presence_penalty=0, model=self.model_version, max_tokens=self.max_output_tokens, + request_timeout=self.timeout, stream=True, - ) + ), + ) + logger.info("Raw response: %s", response) + model_output: str = "" + found_answer_start = False + found_answer_end = False + for event in response: + event_dict = cast(str, event["choices"][0]["delta"]) + if ( + "content" not in event_dict + ): # could be a role message or empty termination + continue + event_text = event_dict["content"] + model_previous = model_output + model_output += event_text - model_output: str = "" - found_answer_start = False - found_answer_end = False - for event in response: - event_dict = event["choices"][0]["delta"] - if ( - "content" not in event_dict - ): # could be a role message or empty termination + if not found_answer_start and '{"answer":"' in model_output.replace( + " ", "" + ).replace("\n", ""): + found_answer_start = True + continue + + if found_answer_start and not found_answer_end: + if stream_answer_end(model_previous, event_text): + found_answer_end = True + yield {"answer_finished": True} continue - event_text = event_dict["content"] - model_previous = model_output - model_output += event_text - - if not found_answer_start and '{"answer":"' in model_output.replace( - " ", "" - ).replace("\n", ""): - found_answer_start = True - continue - - if found_answer_start and not found_answer_end: - if stream_answer_end(model_previous, event_text): - found_answer_end = True - yield {"answer_finished": True} - continue - yield {"answer_data": event_text} - except AuthenticationError: - logger.exception("Failed to authenticate with OpenAI API") - raise - except Exception as e: - logger.exception(e) - model_output = "Model Failure" + yield {"answer_data": event_text} logger.debug(model_output) diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py index da89a8d3c613..5c808917984c 100644 --- a/backend/danswer/server/search_backend.py +++ b/backend/danswer/server/search_backend.py @@ -6,6 +6,7 @@ from danswer.auth.users import current_active_user from danswer.auth.users import current_admin_user from danswer.configs.app_configs import KEYWORD_MAX_HITS from danswer.configs.app_configs import NUM_RERANKED_RESULTS +from danswer.configs.app_configs import QA_TIMEOUT from danswer.configs.constants import CONTENT from danswer.configs.constants import SOURCE_LINKS from danswer.datastores import create_datastore @@ -85,10 +86,14 @@ def direct_qa( for chunk in ranked_chunks ] - qa_model = get_default_backend_qa_model() - answer, quotes = qa_model.answer_question( - query, ranked_chunks[:NUM_RERANKED_RESULTS] - ) + qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT) + try: + answer, quotes = qa_model.answer_question( + query, ranked_chunks[:NUM_RERANKED_RESULTS] + ) + except Exception: + # exception is logged in the answer_question method, no need to re-log + answer, quotes = None, None logger.info(f"Total QA took {time.time() - start_time} seconds") @@ -126,14 +131,19 @@ def stream_direct_qa( top_docs_dict = {top_documents_key: [top_doc.json() for top_doc in top_docs]} yield get_json_line(top_docs_dict) - qa_model = get_default_backend_qa_model() - for response_dict in qa_model.answer_question_stream( - query, ranked_chunks[:NUM_RERANKED_RESULTS] - ): - if response_dict is None: - continue - logger.debug(response_dict) - yield get_json_line(response_dict) + qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT) + try: + for response_dict in qa_model.answer_question_stream( + query, ranked_chunks[:NUM_RERANKED_RESULTS] + ): + if response_dict is None: + continue + logger.debug(response_dict) + yield get_json_line(response_dict) + except Exception: + # exception is logged in the answer_question method, no need to re-log + pass + return return StreamingResponse(stream_qa_portions(), media_type="application/json") diff --git a/web/src/components/search/SearchResultsDisplay.tsx b/web/src/components/search/SearchResultsDisplay.tsx index 3da368d071f0..201389b8f555 100644 --- a/web/src/components/search/SearchResultsDisplay.tsx +++ b/web/src/components/search/SearchResultsDisplay.tsx @@ -32,7 +32,9 @@ export const SearchResultsDisplay: React.FC = ({ return null; } - if (isFetching) { + const { answer, quotes, documents } = searchResponse; + + if (isFetching && !answer) { return (
@@ -42,7 +44,6 @@ export const SearchResultsDisplay: React.FC = ({ ); } - const { answer, quotes, documents } = searchResponse; if (answer === null && documents === null && quotes === null) { return (