mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-07 19:38:19 +02:00
Fix OpenAI key validation
This commit is contained in:
parent
8e9e284849
commit
f4ef92e279
@ -8,12 +8,12 @@ def check_openai_api_key_is_valid(openai_api_key: str) -> bool:
|
||||
if not openai_api_key:
|
||||
return False
|
||||
|
||||
qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=2)
|
||||
qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=5)
|
||||
if not isinstance(qa_model, OpenAIQAModel):
|
||||
raise ValueError("Cannot check OpenAI API key validity for non-OpenAI QA model")
|
||||
|
||||
# try for up to 3 timeouts (e.g. 6 seconds in total)
|
||||
for _ in range(3):
|
||||
# try for up to 2 timeouts (e.g. 10 seconds in total)
|
||||
for _ in range(2):
|
||||
try:
|
||||
qa_model.answer_question("Do not respond", [])
|
||||
return True
|
||||
|
@ -215,10 +215,14 @@ def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
|
||||
@wraps(openai_call)
|
||||
def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
|
||||
try:
|
||||
if not kwargs.get("stream"):
|
||||
return openai_call(*args, **kwargs)
|
||||
# if streamed, the call returns a generator
|
||||
yield from openai_call(*args, **kwargs)
|
||||
if kwargs.get("stream"):
|
||||
|
||||
def _generator():
|
||||
yield from openai_call(*args, **kwargs)
|
||||
|
||||
return _generator()
|
||||
return openai_call(*args, **kwargs)
|
||||
except AuthenticationError:
|
||||
logger.exception("Failed to authenticate with OpenAI API")
|
||||
raise
|
||||
@ -372,7 +376,9 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
|
||||
request_timeout=self.timeout,
|
||||
),
|
||||
)
|
||||
model_output = response["choices"][0]["message"]["content"].strip()
|
||||
model_output = cast(
|
||||
str, response["choices"][0]["message"]["content"]
|
||||
).strip()
|
||||
assistant_msg = {"content": model_output, "role": "assistant"}
|
||||
messages.extend([assistant_msg, get_chat_reflexion_msg()])
|
||||
logger.info(
|
||||
@ -405,12 +411,11 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
|
||||
stream=True,
|
||||
),
|
||||
)
|
||||
logger.info("Raw response: %s", response)
|
||||
model_output: str = ""
|
||||
found_answer_start = False
|
||||
found_answer_end = False
|
||||
for event in response:
|
||||
event_dict = cast(str, event["choices"][0]["delta"])
|
||||
event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
|
||||
if (
|
||||
"content" not in event_dict
|
||||
): # could be a role message or empty termination
|
||||
|
Loading…
x
Reference in New Issue
Block a user