Fix OpenAI key validation

This commit is contained in:
Weves 2023-05-22 13:41:03 -07:00 committed by Chris Weaver
parent 8e9e284849
commit f4ef92e279
2 changed files with 14 additions and 9 deletions

View File

@ -8,12 +8,12 @@ def check_openai_api_key_is_valid(openai_api_key: str) -> bool:
if not openai_api_key:
return False
qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=2)
qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=5)
if not isinstance(qa_model, OpenAIQAModel):
raise ValueError("Cannot check OpenAI API key validity for non-OpenAI QA model")
# try for up to 3 timeouts (e.g. 6 seconds in total)
for _ in range(3):
# try for up to 2 timeouts (e.g. 10 seconds in total)
for _ in range(2):
try:
qa_model.answer_question("Do not respond", [])
return True

View File

@ -215,10 +215,14 @@ def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
@wraps(openai_call)
def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
try:
if not kwargs.get("stream"):
return openai_call(*args, **kwargs)
# if streamed, the call returns a generator
yield from openai_call(*args, **kwargs)
if kwargs.get("stream"):
def _generator():
yield from openai_call(*args, **kwargs)
return _generator()
return openai_call(*args, **kwargs)
except AuthenticationError:
logger.exception("Failed to authenticate with OpenAI API")
raise
@ -372,7 +376,9 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
request_timeout=self.timeout,
),
)
model_output = response["choices"][0]["message"]["content"].strip()
model_output = cast(
str, response["choices"][0]["message"]["content"]
).strip()
assistant_msg = {"content": model_output, "role": "assistant"}
messages.extend([assistant_msg, get_chat_reflexion_msg()])
logger.info(
@ -405,12 +411,11 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
stream=True,
),
)
logger.info("Raw response: %s", response)
model_output: str = ""
found_answer_start = False
found_answer_end = False
for event in response:
event_dict = cast(str, event["choices"][0]["delta"])
event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
if (
"content" not in event_dict
): # could be a role message or empty termination