diff --git a/backend/danswer/direct_qa/key_validation.py b/backend/danswer/direct_qa/key_validation.py
index ca21ce062..5c6575817 100644
--- a/backend/danswer/direct_qa/key_validation.py
+++ b/backend/danswer/direct_qa/key_validation.py
@@ -8,12 +8,12 @@ def check_openai_api_key_is_valid(openai_api_key: str) -> bool:
     if not openai_api_key:
         return False
 
-    qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=2)
+    qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=5)
     if not isinstance(qa_model, OpenAIQAModel):
         raise ValueError("Cannot check OpenAI API key validity for non-OpenAI QA model")
 
-    # try for up to 3 timeouts (e.g. 6 seconds in total)
-    for _ in range(3):
+    # try for up to 2 timeouts (e.g. 10 seconds in total)
+    for _ in range(2):
         try:
             qa_model.answer_question("Do not respond", [])
             return True
diff --git a/backend/danswer/direct_qa/question_answer.py b/backend/danswer/direct_qa/question_answer.py
index dc4dea397..3d51c01fb 100644
--- a/backend/danswer/direct_qa/question_answer.py
+++ b/backend/danswer/direct_qa/question_answer.py
@@ -215,10 +215,14 @@ def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
     @wraps(openai_call)
     def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
         try:
-            if not kwargs.get("stream"):
-                return openai_call(*args, **kwargs)
             # if streamed, the call returns a generator
-            yield from openai_call(*args, **kwargs)
+            if kwargs.get("stream"):
+
+                def _generator():
+                    yield from openai_call(*args, **kwargs)
+
+                return _generator()
+            return openai_call(*args, **kwargs)
         except AuthenticationError:
             logger.exception("Failed to authenticate with OpenAI API")
             raise
@@ -372,7 +376,9 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
                     request_timeout=self.timeout,
                 ),
             )
-            model_output = response["choices"][0]["message"]["content"].strip()
+            model_output = cast(
+                str, response["choices"][0]["message"]["content"]
+            ).strip()
             assistant_msg = {"content": model_output, "role": "assistant"}
             messages.extend([assistant_msg, get_chat_reflexion_msg()])
             logger.info(
@@ -405,12 +411,11 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
                     stream=True,
                 ),
             )
-            logger.info("Raw response: %s", response)
             model_output: str = ""
             found_answer_start = False
             found_answer_end = False
             for event in response:
-                event_dict = cast(str, event["choices"][0]["delta"])
+                event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
                 if (
                     "content" not in event_dict
                 ):  # could be a role message or empty termination
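
The `wrapped_call` change above works around a Python subtlety: a function whose body contains `yield` (or `yield from`) anywhere is compiled as a generator function, so its `return value` statements never hand the value back to the caller in the non-streaming path. Below is a minimal, self-contained sketch of the before/after pattern; `broken_wrapper`, `fixed_wrapper`, and `fake_api` are illustrative stand-ins, not code from this repository.

```python
from functools import wraps
from typing import Any, Callable


def broken_wrapper(call: Callable[..., Any]) -> Callable[..., Any]:
    # Mirrors the old shape: `return` and `yield from` share one body.
    @wraps(call)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if not kwargs.get("stream"):
            # Because `yield from` appears below, `wrapped` is a generator
            # function, so this `return` only sets StopIteration.value.
            return call(*args, **kwargs)
        yield from call(*args, **kwargs)

    return wrapped


def fixed_wrapper(call: Callable[..., Any]) -> Callable[..., Any]:
    # Mirrors the new shape: the yield lives in an inner generator,
    # so `wrapped` itself stays a plain function.
    @wraps(call)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if kwargs.get("stream"):

            def _generator() -> Any:
                yield from call(*args, **kwargs)

            return _generator()
        return call(*args, **kwargs)

    return wrapped


def fake_api(*args: Any, **kwargs: Any) -> Any:
    # Stand-in for the OpenAI call: a dict normally, an iterator when streaming.
    return iter([{"chunk": 1}]) if kwargs.get("stream") else {"ok": True}


print(broken_wrapper(fake_api)(stream=False))      # a generator object, not {'ok': True}
print(fixed_wrapper(fake_api)(stream=False))       # {'ok': True}
print(list(fixed_wrapper(fake_api)(stream=True)))  # [{'chunk': 1}]
```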
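
The `cast` edits, by contrast, are typing-only: `typing.cast` has no runtime effect and simply tells the type checker what shape to expect. A rough illustration of the streamed-delta case, using a hand-built dict in place of a real OpenAI streaming chunk (the event shape here is an assumption for the example):

```python
from typing import Any, cast

# Hand-built stand-in for one streamed ChatCompletion chunk.
event = {"choices": [{"delta": {"role": "assistant", "content": "Hi"}}]}

# The delta is already a dict at runtime; cast() just records that for the type checker.
event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
if "content" in event_dict:
    print(event_dict["content"])  # -> Hi
```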