Add timeout option to OpenAI models

Authored by Weves on 2023-05-20 18:18:54 -07:00; committed by Chris Weaver
parent 62e86efec3
commit 806653dcb0
6 changed files with 175 additions and 134 deletions

View File

@@ -98,6 +98,7 @@ KEYWORD_MAX_HITS = 5
 QUOTE_ALLOWED_ERROR_PERCENT = (
     0.05  # 1 edit per 2 characters, currently unused due to fuzzy match being too slow
 )
+QA_TIMEOUT = 10  # 10 seconds

 #####
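
For context: QA_TIMEOUT is consumed below as the `request_timeout` argument of the pre-1.0 `openai` SDK, which raises `openai.error.Timeout` when the deadline is exceeded. A minimal sketch of that behavior (assumes `openai<1.0`; the model name is illustrative, not taken from this commit):

import openai
from openai.error import Timeout

try:
    response = openai.Completion.create(
        model="text-davinci-003",  # illustrative model, not specified in this commit
        prompt="Say hello.",
        request_timeout=10,  # the same mechanism QA_TIMEOUT feeds into
    )
    print(response["choices"][0]["text"])
except Timeout:
    print("request exceeded the timeout")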

View File

@@ -1,12 +1,9 @@
 from typing import Any

-from danswer.configs.app_configs import OPENAI_API_KEY
-from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.question_answer import OpenAIChatCompletionQA
 from danswer.direct_qa.question_answer import OpenAICompletionQA
-from danswer.dynamic_configs import get_dynamic_config_store


 def get_default_backend_qa_model(
@@ -17,4 +14,4 @@ def get_default_backend_qa_model(
     elif internal_model == "openai-chat-completion":
         return OpenAIChatCompletionQA(**kwargs)
     else:
-        raise ValueError("Wrong internal QA model set.")
+        raise ValueError("Unknown internal QA model set.")
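
Usage sketch for the factory above: extra keyword arguments are forwarded to the selected model's constructor, which is how the new timeout reaches the OpenAI classes. The `internal_model` parameter name comes from the hunk header; the value shown is the one branch visible in this diff:

from danswer.direct_qa import get_default_backend_qa_model

qa_model = get_default_backend_qa_model(
    internal_model="openai-chat-completion",
    timeout=10,  # forwarded to OpenAIChatCompletionQA(timeout=...)
)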

View File

@@ -1,21 +1,25 @@
-from danswer.configs.app_configs import OPENAI_API_KEY
-from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
 from danswer.direct_qa import get_default_backend_qa_model
 from danswer.direct_qa.question_answer import OpenAIQAModel
 from danswer.dynamic_configs import get_dynamic_config_store
 from openai.error import AuthenticationError
+from openai.error import Timeout


 def check_openai_api_key_is_valid(openai_api_key: str) -> bool:
     if not openai_api_key:
         return False

-    qa_model = get_default_backend_qa_model(api_key=openai_api_key)
+    qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=2)
     if not isinstance(qa_model, OpenAIQAModel):
         raise ValueError("Cannot check OpenAI API key validity for non-OpenAI QA model")

-    try:
-        qa_model.answer_question("Do not respond", [])
-        return True
-    except AuthenticationError:
-        return False
+    # try for up to 3 timeouts (e.g. 6 seconds in total)
+    for _ in range(3):
+        try:
+            qa_model.answer_question("Do not respond", [])
+            return True
+        except AuthenticationError:
+            return False
+        except Timeout:
+            pass
+    return False
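
A hedged sketch of how this validator might be wired up (the module path is assumed, since the file name is not shown in this view). With timeout=2 and up to three attempts, a key check is bounded at roughly 6 seconds instead of hanging on a slow API:

from danswer.direct_qa.key_validation import check_openai_api_key_is_valid  # assumed path

def store_openai_key(openai_api_key: str) -> bool:
    # bounded at ~6s: three attempts with a 2-second request timeout each
    if not check_openai_api_key_is_valid(openai_api_key):
        return False
    # persist the key (e.g. via the dynamic config store) here
    return True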

View File

@@ -7,8 +7,10 @@ from functools import wraps
 from typing import Any
 from typing import cast
 from typing import Dict
+from typing import Literal
 from typing import Optional
 from typing import Tuple
+from typing import TypeVar
 from typing import Union

 import openai
@@ -37,6 +39,7 @@ from danswer.utils.text_processing import clean_model_quote
 from danswer.utils.text_processing import shared_precompare_cleanup
 from danswer.utils.timing import log_function_time
 from openai.error import AuthenticationError
+from openai.error import Timeout

 logger = setup_logger()
@@ -187,6 +190,48 @@ def stream_answer_end(answer_so_far: str, next_token: str) -> bool:
     return False


+F = TypeVar("F", bound=Callable)
+Answer = str
+Quotes = dict[str, dict[str, str | int | None]]
+ModelType = Literal["ChatCompletion", "Completion"]
+PromptProcessor = Callable[[str, list[str]], str]
+
+
+def _build_openai_settings(**kwargs: dict[str, Any]) -> dict[str, Any]:
+    """
+    Utility to add in some common default values so they don't have to be set every time.
+    """
+    return {
+        "temperature": 0,
+        "top_p": 1,
+        "frequency_penalty": 0,
+        "presence_penalty": 0,
+        **kwargs,
+    }
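
Because `**kwargs` is unpacked after the defaults, caller-supplied values override them while unspecified defaults survive. A quick sketch (the model name is illustrative):

settings = _build_openai_settings(model="text-davinci-003", temperature=0.7)
assert settings["temperature"] == 0.7  # caller's value wins over the default of 0
assert settings["top_p"] == 1  # untouched default is preserved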
+def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
+    @wraps(openai_call)
+    def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
+        try:
+            if not kwargs.get("stream"):
+                return openai_call(*args, **kwargs)
+
+            # if streamed, the call returns a generator
+            yield from openai_call(*args, **kwargs)
+        except AuthenticationError:
+            logger.exception("Failed to authenticate with OpenAI API")
+            raise
+        except Timeout:
+            logger.exception("OpenAI API timed out for query: %s", query)
+            raise
+        except Exception as e:
+            logger.exception("Unexpected error with OpenAI API for query: %s", query)
+            raise
+
+    return cast(F, wrapped_call)
+
+
+# used to check if the QAModel is an OpenAI model
+class OpenAIQAModel(QAModel):
+    pass
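
The wrapper is applied in two steps, mirroring how the hunks below use it: wrap the SDK function once (capturing the query for log messages), then invoke the wrapped callable with the merged settings. A sketch with illustrative values:

wrapped_create = _handle_openai_exceptions_wrapper(
    openai_call=openai.Completion.create,
    query="example query",  # only used in log output
)
response = wrapped_create(
    **_build_openai_settings(
        prompt="Say hello.",
        model="text-davinci-003",  # illustrative
        request_timeout=10,
    )
)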
@@ -199,11 +244,13 @@ class OpenAICompletionQA(OpenAIQAModel):
         model_version: str = OPENAI_MODEL_VERSION,
         max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
         api_key: str | None = None,
+        timeout: int | None = None,
     ) -> None:
         self.prompt_processor = prompt_processor
         self.model_version = model_version
         self.max_output_tokens = max_output_tokens
         self.api_key = api_key or get_openai_api_key()
+        self.timeout = timeout

     @log_function_time()
     def answer_question(
@@ -213,28 +260,21 @@ class OpenAICompletionQA(OpenAIQAModel):
         filled_prompt = self.prompt_processor(query, top_contents)
         logger.debug(filled_prompt)

-        try:
-            response = openai.Completion.create(
+        openai_call = _handle_openai_exceptions_wrapper(
+            openai_call=openai.Completion.create,
+            query=query,
+        )
+        response = openai_call(
+            **_build_openai_settings(
                 api_key=self.api_key,
                 prompt=filled_prompt,
-                temperature=0,
-                top_p=1,
-                frequency_penalty=0,
-                presence_penalty=0,
                 model=self.model_version,
                 max_tokens=self.max_output_tokens,
-            )
-            model_output = response["choices"][0]["text"].strip()
-            logger.info(
-                "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
-            )
-        except AuthenticationError:
-            logger.exception("Failed to authenticate with OpenAI API")
-            raise
-        except Exception as e:
-            logger.exception(e)
-            model_output = "Model Failure"
+                request_timeout=self.timeout,
+            ),
+        )
+        model_output = cast(str, response["choices"][0]["text"]).strip()
+        logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))

         logger.debug(model_output)

         answer, quotes_dict = process_answer(model_output, context_docs)
@@ -247,46 +287,41 @@ class OpenAICompletionQA(OpenAIQAModel):
         filled_prompt = self.prompt_processor(query, top_contents)
         logger.debug(filled_prompt)

-        try:
-            response = openai.Completion.create(
+        openai_call = _handle_openai_exceptions_wrapper(
+            openai_call=openai.Completion.create,
+            query=query,
+        )
+        response = openai_call(
+            **_build_openai_settings(
                 api_key=self.api_key,
                 prompt=filled_prompt,
-                temperature=0,
-                top_p=1,
-                frequency_penalty=0,
-                presence_penalty=0,
                 model=self.model_version,
                 max_tokens=self.max_output_tokens,
+                request_timeout=self.timeout,
                 stream=True,
-            )
+            ),
+        )

-            model_output: str = ""
-            found_answer_start = False
-            found_answer_end = False
-            # iterate through the stream of events
-            for event in response:
-                event_text = cast(str, event["choices"][0]["text"])
-                model_previous = model_output
-                model_output += event_text
+        model_output: str = ""
+        found_answer_start = False
+        found_answer_end = False
+        # iterate through the stream of events
+        for event in response:
+            event_text = cast(str, event["choices"][0]["text"])
+            model_previous = model_output
+            model_output += event_text

-                if not found_answer_start and '{"answer":"' in model_output.replace(
-                    " ", ""
-                ).replace("\n", ""):
-                    found_answer_start = True
-                    continue
+            if not found_answer_start and '{"answer":"' in model_output.replace(
+                " ", ""
+            ).replace("\n", ""):
+                found_answer_start = True
+                continue

-                if found_answer_start and not found_answer_end:
-                    if stream_answer_end(model_previous, event_text):
-                        found_answer_end = True
-                        yield {"answer_finished": True}
-                        continue
-                    yield {"answer_data": event_text}
-        except AuthenticationError:
-            logger.exception("Failed to authenticate with OpenAI API")
-            raise
-        except Exception as e:
-            logger.exception(e)
-            model_output = "Model Failure"
+            if found_answer_start and not found_answer_end:
+                if stream_answer_end(model_previous, event_text):
+                    found_answer_end = True
+                    yield {"answer_finished": True}
+                    continue
+                yield {"answer_data": event_text}

         logger.debug(model_output)
@@ -304,6 +339,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         ] = json_chat_processor,
         model_version: str = OPENAI_MODEL_VERSION,
         max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
+        timeout: int | None = None,
         reflexion_try_count: int = 0,
         api_key: str | None = None,
     ) -> None:
@@ -312,6 +348,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         self.max_output_tokens = max_output_tokens
         self.reflexion_try_count = reflexion_try_count
         self.api_key = api_key or get_openai_api_key()
+        self.timeout = timeout

     @log_function_time()
     def answer_question(
@@ -322,30 +359,25 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         logger.debug(messages)
         model_output = ""
         for _ in range(self.reflexion_try_count + 1):
-            try:
-                response = openai.ChatCompletion.create(
+            openai_call = _handle_openai_exceptions_wrapper(
+                openai_call=openai.ChatCompletion.create,
+                query=query,
+            )
+            response = openai_call(
+                **_build_openai_settings(
                     api_key=self.api_key,
                     messages=messages,
-                    temperature=0,
-                    top_p=1,
-                    frequency_penalty=0,
-                    presence_penalty=0,
                     model=self.model_version,
                     max_tokens=self.max_output_tokens,
-                )
-                model_output = response["choices"][0]["message"]["content"].strip()
-                assistant_msg = {"content": model_output, "role": "assistant"}
-                messages.extend([assistant_msg, get_chat_reflexion_msg()])
-                logger.info(
-                    "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
-                )
-            except AuthenticationError:
-                logger.exception("Failed to authenticate with OpenAI API")
-                raise
-            except Exception as e:
-                logger.exception(e)
-                logger.warning(f"Model failure for query: {query}")
-                return None, None
+                    request_timeout=self.timeout,
+                ),
+            )
+            model_output = response["choices"][0]["message"]["content"].strip()
+            assistant_msg = {"content": model_output, "role": "assistant"}
+            messages.extend([assistant_msg, get_chat_reflexion_msg()])
+            logger.info(
+                "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
+            )

         logger.debug(model_output)
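
The loop above implements a simple reflexion pass: after each round, the model's own answer plus a reflexion prompt are appended and the model is asked again. A generic sketch of the pattern, with `call_model` as a hypothetical stand-in for the wrapped ChatCompletion call:

messages = [{"role": "user", "content": "original question"}]
for _ in range(reflexion_try_count + 1):
    model_output = call_model(messages)  # hypothetical helper
    messages.append({"role": "assistant", "content": model_output})
    messages.append(get_chat_reflexion_msg())  # asks the model to reconsider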
@@ -359,50 +391,46 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         messages = self.prompt_processor(query, top_contents)
         logger.debug(messages)

-        try:
-            response = openai.ChatCompletion.create(
+        openai_call = _handle_openai_exceptions_wrapper(
+            openai_call=openai.ChatCompletion.create,
+            query=query,
+        )
+        response = openai_call(
+            **_build_openai_settings(
                 api_key=self.api_key,
                 messages=messages,
-                temperature=0,
-                top_p=1,
-                frequency_penalty=0,
-                presence_penalty=0,
                 model=self.model_version,
                 max_tokens=self.max_output_tokens,
+                request_timeout=self.timeout,
                 stream=True,
-            )
+            ),
+        )
+        logger.info("Raw response: %s", response)

-            model_output: str = ""
-            found_answer_start = False
-            found_answer_end = False
-            for event in response:
-                event_dict = cast(str, event["choices"][0]["delta"])
-                if (
-                    "content" not in event_dict
-                ):  # could be a role message or empty termination
-                    continue
-                event_text = event_dict["content"]
-                model_previous = model_output
-                model_output += event_text
+        model_output: str = ""
+        found_answer_start = False
+        found_answer_end = False
+        for event in response:
+            event_dict = event["choices"][0]["delta"]
+            if (
+                "content" not in event_dict
+            ):  # could be a role message or empty termination
+                continue
+            event_text = event_dict["content"]
+            model_previous = model_output
+            model_output += event_text

-                if not found_answer_start and '{"answer":"' in model_output.replace(
-                    " ", ""
-                ).replace("\n", ""):
-                    found_answer_start = True
-                    continue
+            if not found_answer_start and '{"answer":"' in model_output.replace(
+                " ", ""
+            ).replace("\n", ""):
+                found_answer_start = True
+                continue

-                if found_answer_start and not found_answer_end:
-                    if stream_answer_end(model_previous, event_text):
-                        found_answer_end = True
-                        yield {"answer_finished": True}
-                        continue
-                    yield {"answer_data": event_text}
-        except AuthenticationError:
-            logger.exception("Failed to authenticate with OpenAI API")
-            raise
-        except Exception as e:
-            logger.exception(e)
-            model_output = "Model Failure"
+            if found_answer_start and not found_answer_end:
+                if stream_answer_end(model_previous, event_text):
+                    found_answer_end = True
+                    yield {"answer_finished": True}
+                    continue
+                yield {"answer_data": event_text}

         logger.debug(model_output)
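
Both streaming methods yield the same packet shapes: token payloads under "answer_data" and an end-of-answer marker under "answer_finished". A consumer sketch (variable names are assumed for illustration):

answer_pieces: list[str] = []
for packet in qa_model.answer_question_stream(query, context_chunks):  # names assumed
    if packet is None:
        continue
    if packet.get("answer_finished"):
        break  # stop collecting once the answer portion of the stream ends
    if "answer_data" in packet:
        answer_pieces.append(packet["answer_data"])
answer = "".join(answer_pieces)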

View File

@@ -6,6 +6,7 @@ from danswer.auth.users import current_active_user
 from danswer.auth.users import current_admin_user
 from danswer.configs.app_configs import KEYWORD_MAX_HITS
 from danswer.configs.app_configs import NUM_RERANKED_RESULTS
+from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.configs.constants import CONTENT
 from danswer.configs.constants import SOURCE_LINKS
 from danswer.datastores import create_datastore
@@ -85,10 +86,14 @@ def direct_qa(
         for chunk in ranked_chunks
     ]

-    qa_model = get_default_backend_qa_model()
-    answer, quotes = qa_model.answer_question(
-        query, ranked_chunks[:NUM_RERANKED_RESULTS]
-    )
+    qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
+    try:
+        answer, quotes = qa_model.answer_question(
+            query, ranked_chunks[:NUM_RERANKED_RESULTS]
+        )
+    except Exception:
+        # exception is logged in the answer_question method, no need to re-log
+        answer, quotes = None, None

     logger.info(f"Total QA took {time.time() - start_time} seconds")
@@ -126,14 +131,19 @@ def stream_direct_qa(
         top_docs_dict = {top_documents_key: [top_doc.json() for top_doc in top_docs]}
         yield get_json_line(top_docs_dict)

-        qa_model = get_default_backend_qa_model()
-        for response_dict in qa_model.answer_question_stream(
-            query, ranked_chunks[:NUM_RERANKED_RESULTS]
-        ):
-            if response_dict is None:
-                continue
-            logger.debug(response_dict)
-            yield get_json_line(response_dict)
+        qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
+        try:
+            for response_dict in qa_model.answer_question_stream(
+                query, ranked_chunks[:NUM_RERANKED_RESULTS]
+            ):
+                if response_dict is None:
+                    continue
+                logger.debug(response_dict)
+                yield get_json_line(response_dict)
+        except Exception:
+            # exception is logged in the answer_question method, no need to re-log
+            pass

         return

     return StreamingResponse(stream_qa_portions(), media_type="application/json")
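
Since get_json_line suggests newline-delimited JSON, a client can parse the streamed response line by line. A hedged client-side sketch (host, route, and payload are assumptions, not taken from this commit):

import json
import requests

with requests.post(
    "http://localhost:8080/stream-direct-qa",  # hypothetical host/route
    json={"query": "What is Danswer?"},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        print(json.loads(line))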