Fix OpenAI key validation

This commit is contained in:
Weves
2023-05-22 13:41:03 -07:00
committed by Chris Weaver
parent 8e9e284849
commit f4ef92e279
2 changed files with 14 additions and 9 deletions

View File

@@ -8,12 +8,12 @@ def check_openai_api_key_is_valid(openai_api_key: str) -> bool:
     if not openai_api_key:
         return False
-    qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=2)
+    qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=5)
     if not isinstance(qa_model, OpenAIQAModel):
         raise ValueError("Cannot check OpenAI API key validity for non-OpenAI QA model")
-    # try for up to 3 timeouts (e.g. 6 seconds in total)
-    for _ in range(3):
+    # try for up to 2 timeouts (e.g. 10 seconds in total)
+    for _ in range(2):
         try:
             qa_model.answer_question("Do not respond", [])
             return True

View File

@@ -215,10 +215,14 @@ def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
@wraps(openai_call) @wraps(openai_call)
def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any: def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
try: try:
if not kwargs.get("stream"):
return openai_call(*args, **kwargs)
# if streamed, the call returns a generator # if streamed, the call returns a generator
yield from openai_call(*args, **kwargs) if kwargs.get("stream"):
def _generator():
yield from openai_call(*args, **kwargs)
return _generator()
return openai_call(*args, **kwargs)
except AuthenticationError: except AuthenticationError:
logger.exception("Failed to authenticate with OpenAI API") logger.exception("Failed to authenticate with OpenAI API")
raise raise
@@ -372,7 +376,9 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
                 request_timeout=self.timeout,
             ),
         )
-        model_output = response["choices"][0]["message"]["content"].strip()
+        model_output = cast(
+            str, response["choices"][0]["message"]["content"]
+        ).strip()
         assistant_msg = {"content": model_output, "role": "assistant"}
         messages.extend([assistant_msg, get_chat_reflexion_msg()])
         logger.info(
@@ -405,12 +411,11 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
                 stream=True,
             ),
         )
-        logger.info("Raw response: %s", response)
        model_output: str = ""
        found_answer_start = False
        found_answer_end = False
        for event in response:
-            event_dict = cast(str, event["choices"][0]["delta"])
+            event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
            if (
                "content" not in event_dict
            ):  # could be a role message or empty termination