mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-26 20:08:38 +02:00
Fix OpenAI key validation
This commit is contained in:
@@ -8,12 +8,12 @@ def check_openai_api_key_is_valid(openai_api_key: str) -> bool:
|
|||||||
if not openai_api_key:
|
if not openai_api_key:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=2)
|
qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=5)
|
||||||
if not isinstance(qa_model, OpenAIQAModel):
|
if not isinstance(qa_model, OpenAIQAModel):
|
||||||
raise ValueError("Cannot check OpenAI API key validity for non-OpenAI QA model")
|
raise ValueError("Cannot check OpenAI API key validity for non-OpenAI QA model")
|
||||||
|
|
||||||
# try for up to 3 timeouts (e.g. 6 seconds in total)
|
# try for up to 2 timeouts (e.g. 10 seconds in total)
|
||||||
for _ in range(3):
|
for _ in range(2):
|
||||||
try:
|
try:
|
||||||
qa_model.answer_question("Do not respond", [])
|
qa_model.answer_question("Do not respond", [])
|
||||||
return True
|
return True
|
||||||
|
@@ -215,10 +215,14 @@ def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
|
|||||||
@wraps(openai_call)
|
@wraps(openai_call)
|
||||||
def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
|
def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
|
||||||
try:
|
try:
|
||||||
if not kwargs.get("stream"):
|
|
||||||
return openai_call(*args, **kwargs)
|
|
||||||
# if streamed, the call returns a generator
|
# if streamed, the call returns a generator
|
||||||
yield from openai_call(*args, **kwargs)
|
if kwargs.get("stream"):
|
||||||
|
|
||||||
|
def _generator():
|
||||||
|
yield from openai_call(*args, **kwargs)
|
||||||
|
|
||||||
|
return _generator()
|
||||||
|
return openai_call(*args, **kwargs)
|
||||||
except AuthenticationError:
|
except AuthenticationError:
|
||||||
logger.exception("Failed to authenticate with OpenAI API")
|
logger.exception("Failed to authenticate with OpenAI API")
|
||||||
raise
|
raise
|
||||||
@@ -372,7 +376,9 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
|
|||||||
request_timeout=self.timeout,
|
request_timeout=self.timeout,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
model_output = response["choices"][0]["message"]["content"].strip()
|
model_output = cast(
|
||||||
|
str, response["choices"][0]["message"]["content"]
|
||||||
|
).strip()
|
||||||
assistant_msg = {"content": model_output, "role": "assistant"}
|
assistant_msg = {"content": model_output, "role": "assistant"}
|
||||||
messages.extend([assistant_msg, get_chat_reflexion_msg()])
|
messages.extend([assistant_msg, get_chat_reflexion_msg()])
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -405,12 +411,11 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
|
|||||||
stream=True,
|
stream=True,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
logger.info("Raw response: %s", response)
|
|
||||||
model_output: str = ""
|
model_output: str = ""
|
||||||
found_answer_start = False
|
found_answer_start = False
|
||||||
found_answer_end = False
|
found_answer_end = False
|
||||||
for event in response:
|
for event in response:
|
||||||
event_dict = cast(str, event["choices"][0]["delta"])
|
event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
|
||||||
if (
|
if (
|
||||||
"content" not in event_dict
|
"content" not in event_dict
|
||||||
): # could be a role message or empty termination
|
): # could be a role message or empty termination
|
||||||
|
Reference in New Issue
Block a user