mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 12:03:54 +02:00
Add timeout option to OpenAI models
This commit is contained in:
@@ -98,6 +98,7 @@ KEYWORD_MAX_HITS = 5
|
||||
QUOTE_ALLOWED_ERROR_PERCENT = (
|
||||
0.05 # 1 edit per 20 characters, currently unused due to fuzzy match being too slow
|
||||
)
|
||||
QA_TIMEOUT = 10 # 10 seconds
|
||||
|
||||
|
||||
#####
|
||||
|
@@ -1,12 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.app_configs import OPENAI_API_KEY
|
||||
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
|
||||
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
|
||||
from danswer.direct_qa.interfaces import QAModel
|
||||
from danswer.direct_qa.question_answer import OpenAIChatCompletionQA
|
||||
from danswer.direct_qa.question_answer import OpenAICompletionQA
|
||||
from danswer.dynamic_configs import get_dynamic_config_store
|
||||
|
||||
|
||||
def get_default_backend_qa_model(
|
||||
@@ -17,4 +14,4 @@ def get_default_backend_qa_model(
|
||||
elif internal_model == "openai-chat-completion":
|
||||
return OpenAIChatCompletionQA(**kwargs)
|
||||
else:
|
||||
raise ValueError("Wrong internal QA model set.")
|
||||
raise ValueError("Unknown internal QA model set.")
|
||||
|
@@ -1,21 +1,25 @@
|
||||
from danswer.configs.app_configs import OPENAI_API_KEY
|
||||
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
|
||||
from danswer.direct_qa import get_default_backend_qa_model
|
||||
from danswer.direct_qa.question_answer import OpenAIQAModel
|
||||
from danswer.dynamic_configs import get_dynamic_config_store
|
||||
from openai.error import AuthenticationError
|
||||
from openai.error import Timeout
|
||||
|
||||
|
||||
def check_openai_api_key_is_valid(openai_api_key: str) -> bool:
    """Return True if a test request made with ``openai_api_key`` succeeds.

    An empty key is rejected immediately, before any model is built. Otherwise
    a throwaway question is sent through the default QA model, which is
    constructed with ``timeout=2``. A timed-out attempt is retried, for at most
    three attempts in total.

    Returns:
        True as soon as one attempt completes; False if the key is empty, if
        OpenAI raises AuthenticationError, or if all three attempts time out.

    Raises:
        ValueError: if the default QA model is not an ``OpenAIQAModel``.
        Any exception other than AuthenticationError / Timeout raised by
        ``answer_question`` is not caught here and propagates to the caller.
    """
    if not openai_api_key:
        return False

    qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=2)
    if not isinstance(qa_model, OpenAIQAModel):
        raise ValueError("Cannot check OpenAI API key validity for non-OpenAI QA model")

    # try for up to 3 timeouts (e.g. 6 seconds in total)
    for _ in range(3):
        try:
            qa_model.answer_question("Do not respond", [])
            return True
        except AuthenticationError:
            # The API answered and rejected the key; retrying cannot help.
            return False
        except Timeout:
            # No verdict on the key yet; use the next attempt.
            continue

    return False
|
||||
|
@@ -7,8 +7,10 @@ from functools import wraps
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
|
||||
import openai
|
||||
@@ -37,6 +39,7 @@ from danswer.utils.text_processing import clean_model_quote
|
||||
from danswer.utils.text_processing import shared_precompare_cleanup
|
||||
from danswer.utils.timing import log_function_time
|
||||
from openai.error import AuthenticationError
|
||||
from openai.error import Timeout
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -187,6 +190,48 @@ def stream_answer_end(answer_so_far: str, next_token: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable)
|
||||
|
||||
Answer = str
|
||||
Quotes = dict[str, dict[str, str | int | None]]
|
||||
ModelType = Literal["ChatCompletion", "Completion"]
|
||||
PromptProcessor = Callable[[str, list[str]], str]
|
||||
|
||||
|
||||
def _build_openai_settings(**kwargs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Utility to add in some common default values so they don't have to be set every time.
|
||||
"""
|
||||
return {
|
||||
"temperature": 0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
    """Wrap an OpenAI API callable so that failures are logged before propagating.

    Args:
        openai_call: the callable to wrap (e.g. ``openai.Completion.create``).
        query: the user query, included in the log message for timeouts and
            unexpected errors.

    Returns:
        A callable with the same call signature as ``openai_call``.
        * Without a truthy ``stream`` keyword argument it calls ``openai_call``
          immediately and returns its result.
        * With ``stream=True`` it returns a generator; ``openai_call`` is only
          invoked once iteration starts, and its events are yielded through.

    Raises:
        Whatever ``openai_call`` raises. Exceptions are logged and re-raised,
        never swallowed.
    """

    def _log_failure(error: Exception) -> None:
        # Only ever called from inside an `except` block, so logger.exception
        # still attaches the active traceback.
        if isinstance(error, AuthenticationError):
            logger.exception("Failed to authenticate with OpenAI API")
        elif isinstance(error, Timeout):
            logger.exception("OpenAI API timed out for query: %s", query)
        else:
            logger.exception("Unexpected error with OpenAI API for query: %s", query)

    def _streamed_call(*args: Any, **kwargs: Any) -> Any:
        # if streamed, the call returns a generator
        try:
            yield from openai_call(*args, **kwargs)
        except Exception as error:
            _log_failure(error)
            raise

    @wraps(openai_call)
    def wrapped_call(*args: Any, **kwargs: Any) -> Any:
        # The `yield from` must live in its own function: any function whose
        # body contains a yield is a generator function, so keeping it here
        # would make the non-stream path hand back a generator instead of the
        # API response.
        if kwargs.get("stream"):
            return _streamed_call(*args, **kwargs)
        try:
            return openai_call(*args, **kwargs)
        except Exception as error:
            _log_failure(error)
            raise

    return cast(F, wrapped_call)
|
||||
|
||||
|
||||
class OpenAIQAModel(QAModel):
    """Marker base class used to check if the QAModel is an OpenAI model."""
|
||||
@@ -199,11 +244,13 @@ class OpenAICompletionQA(OpenAIQAModel):
|
||||
model_version: str = OPENAI_MODEL_VERSION,
|
||||
max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
|
||||
api_key: str | None = None,
|
||||
timeout: int | None = None,
|
||||
) -> None:
|
||||
self.prompt_processor = prompt_processor
|
||||
self.model_version = model_version
|
||||
self.max_output_tokens = max_output_tokens
|
||||
self.api_key = api_key or get_openai_api_key()
|
||||
self.timeout = timeout
|
||||
|
||||
@log_function_time()
|
||||
def answer_question(
|
||||
@@ -213,28 +260,21 @@ class OpenAICompletionQA(OpenAIQAModel):
|
||||
filled_prompt = self.prompt_processor(query, top_contents)
|
||||
logger.debug(filled_prompt)
|
||||
|
||||
try:
|
||||
response = openai.Completion.create(
|
||||
openai_call = _handle_openai_exceptions_wrapper(
|
||||
openai_call=openai.Completion.create,
|
||||
query=query,
|
||||
)
|
||||
response = openai_call(
|
||||
**_build_openai_settings(
|
||||
api_key=self.api_key,
|
||||
prompt=filled_prompt,
|
||||
temperature=0,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
model=self.model_version,
|
||||
max_tokens=self.max_output_tokens,
|
||||
)
|
||||
model_output = response["choices"][0]["text"].strip()
|
||||
logger.info(
|
||||
"OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
|
||||
)
|
||||
except AuthenticationError:
|
||||
logger.exception("Failed to authenticate with OpenAI API")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
model_output = "Model Failure"
|
||||
|
||||
request_timeout=self.timeout,
|
||||
),
|
||||
)
|
||||
model_output = cast(str, response["choices"][0]["text"]).strip()
|
||||
logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
|
||||
logger.debug(model_output)
|
||||
|
||||
answer, quotes_dict = process_answer(model_output, context_docs)
|
||||
@@ -247,46 +287,41 @@ class OpenAICompletionQA(OpenAIQAModel):
|
||||
filled_prompt = self.prompt_processor(query, top_contents)
|
||||
logger.debug(filled_prompt)
|
||||
|
||||
try:
|
||||
response = openai.Completion.create(
|
||||
openai_call = _handle_openai_exceptions_wrapper(
|
||||
openai_call=openai.Completion.create,
|
||||
query=query,
|
||||
)
|
||||
response = openai_call(
|
||||
**_build_openai_settings(
|
||||
api_key=self.api_key,
|
||||
prompt=filled_prompt,
|
||||
temperature=0,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
model=self.model_version,
|
||||
max_tokens=self.max_output_tokens,
|
||||
request_timeout=self.timeout,
|
||||
stream=True,
|
||||
)
|
||||
),
|
||||
)
|
||||
model_output: str = ""
|
||||
found_answer_start = False
|
||||
found_answer_end = False
|
||||
# iterate through the stream of events
|
||||
for event in response:
|
||||
event_text = cast(str, event["choices"][0]["text"])
|
||||
model_previous = model_output
|
||||
model_output += event_text
|
||||
|
||||
model_output: str = ""
|
||||
found_answer_start = False
|
||||
found_answer_end = False
|
||||
# iterate through the stream of events
|
||||
for event in response:
|
||||
event_text = cast(str, event["choices"][0]["text"])
|
||||
model_previous = model_output
|
||||
model_output += event_text
|
||||
if not found_answer_start and '{"answer":"' in model_output.replace(
|
||||
" ", ""
|
||||
).replace("\n", ""):
|
||||
found_answer_start = True
|
||||
continue
|
||||
|
||||
if not found_answer_start and '{"answer":"' in model_output.replace(
|
||||
" ", ""
|
||||
).replace("\n", ""):
|
||||
found_answer_start = True
|
||||
if found_answer_start and not found_answer_end:
|
||||
if stream_answer_end(model_previous, event_text):
|
||||
found_answer_end = True
|
||||
yield {"answer_finished": True}
|
||||
continue
|
||||
|
||||
if found_answer_start and not found_answer_end:
|
||||
if stream_answer_end(model_previous, event_text):
|
||||
found_answer_end = True
|
||||
yield {"answer_finished": True}
|
||||
continue
|
||||
yield {"answer_data": event_text}
|
||||
except AuthenticationError:
|
||||
logger.exception("Failed to authenticate with OpenAI API")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
model_output = "Model Failure"
|
||||
yield {"answer_data": event_text}
|
||||
|
||||
logger.debug(model_output)
|
||||
|
||||
@@ -304,6 +339,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
|
||||
] = json_chat_processor,
|
||||
model_version: str = OPENAI_MODEL_VERSION,
|
||||
max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
|
||||
timeout: int | None = None,
|
||||
reflexion_try_count: int = 0,
|
||||
api_key: str | None = None,
|
||||
) -> None:
|
||||
@@ -312,6 +348,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
|
||||
self.max_output_tokens = max_output_tokens
|
||||
self.reflexion_try_count = reflexion_try_count
|
||||
self.api_key = api_key or get_openai_api_key()
|
||||
self.timeout = timeout
|
||||
|
||||
@log_function_time()
|
||||
def answer_question(
|
||||
@@ -322,30 +359,25 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
|
||||
logger.debug(messages)
|
||||
model_output = ""
|
||||
for _ in range(self.reflexion_try_count + 1):
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
openai_call = _handle_openai_exceptions_wrapper(
|
||||
openai_call=openai.ChatCompletion.create,
|
||||
query=query,
|
||||
)
|
||||
response = openai_call(
|
||||
**_build_openai_settings(
|
||||
api_key=self.api_key,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
model=self.model_version,
|
||||
max_tokens=self.max_output_tokens,
|
||||
)
|
||||
model_output = response["choices"][0]["message"]["content"].strip()
|
||||
assistant_msg = {"content": model_output, "role": "assistant"}
|
||||
messages.extend([assistant_msg, get_chat_reflexion_msg()])
|
||||
logger.info(
|
||||
"OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
|
||||
)
|
||||
except AuthenticationError:
|
||||
logger.exception("Failed to authenticate with OpenAI API")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
logger.warning(f"Model failure for query: {query}")
|
||||
return None, None
|
||||
request_timeout=self.timeout,
|
||||
),
|
||||
)
|
||||
model_output = response["choices"][0]["message"]["content"].strip()
|
||||
assistant_msg = {"content": model_output, "role": "assistant"}
|
||||
messages.extend([assistant_msg, get_chat_reflexion_msg()])
|
||||
logger.info(
|
||||
"OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
|
||||
)
|
||||
|
||||
logger.debug(model_output)
|
||||
|
||||
@@ -359,50 +391,46 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
|
||||
messages = self.prompt_processor(query, top_contents)
|
||||
logger.debug(messages)
|
||||
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
openai_call = _handle_openai_exceptions_wrapper(
|
||||
openai_call=openai.ChatCompletion.create,
|
||||
query=query,
|
||||
)
|
||||
response = openai_call(
|
||||
**_build_openai_settings(
|
||||
api_key=self.api_key,
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
top_p=1,
|
||||
frequency_penalty=0,
|
||||
presence_penalty=0,
|
||||
model=self.model_version,
|
||||
max_tokens=self.max_output_tokens,
|
||||
request_timeout=self.timeout,
|
||||
stream=True,
|
||||
)
|
||||
),
|
||||
)
|
||||
logger.info("Raw response: %s", response)
|
||||
model_output: str = ""
|
||||
found_answer_start = False
|
||||
found_answer_end = False
|
||||
for event in response:
|
||||
event_dict = cast(str, event["choices"][0]["delta"])
|
||||
if (
|
||||
"content" not in event_dict
|
||||
): # could be a role message or empty termination
|
||||
continue
|
||||
event_text = event_dict["content"]
|
||||
model_previous = model_output
|
||||
model_output += event_text
|
||||
|
||||
model_output: str = ""
|
||||
found_answer_start = False
|
||||
found_answer_end = False
|
||||
for event in response:
|
||||
event_dict = event["choices"][0]["delta"]
|
||||
if (
|
||||
"content" not in event_dict
|
||||
): # could be a role message or empty termination
|
||||
if not found_answer_start and '{"answer":"' in model_output.replace(
|
||||
" ", ""
|
||||
).replace("\n", ""):
|
||||
found_answer_start = True
|
||||
continue
|
||||
|
||||
if found_answer_start and not found_answer_end:
|
||||
if stream_answer_end(model_previous, event_text):
|
||||
found_answer_end = True
|
||||
yield {"answer_finished": True}
|
||||
continue
|
||||
event_text = event_dict["content"]
|
||||
model_previous = model_output
|
||||
model_output += event_text
|
||||
|
||||
if not found_answer_start and '{"answer":"' in model_output.replace(
|
||||
" ", ""
|
||||
).replace("\n", ""):
|
||||
found_answer_start = True
|
||||
continue
|
||||
|
||||
if found_answer_start and not found_answer_end:
|
||||
if stream_answer_end(model_previous, event_text):
|
||||
found_answer_end = True
|
||||
yield {"answer_finished": True}
|
||||
continue
|
||||
yield {"answer_data": event_text}
|
||||
except AuthenticationError:
|
||||
logger.exception("Failed to authenticate with OpenAI API")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
model_output = "Model Failure"
|
||||
yield {"answer_data": event_text}
|
||||
|
||||
logger.debug(model_output)
|
||||
|
||||
|
@@ -6,6 +6,7 @@ from danswer.auth.users import current_active_user
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.configs.app_configs import KEYWORD_MAX_HITS
|
||||
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
|
||||
from danswer.configs.app_configs import QA_TIMEOUT
|
||||
from danswer.configs.constants import CONTENT
|
||||
from danswer.configs.constants import SOURCE_LINKS
|
||||
from danswer.datastores import create_datastore
|
||||
@@ -85,10 +86,14 @@ def direct_qa(
|
||||
for chunk in ranked_chunks
|
||||
]
|
||||
|
||||
qa_model = get_default_backend_qa_model()
|
||||
answer, quotes = qa_model.answer_question(
|
||||
query, ranked_chunks[:NUM_RERANKED_RESULTS]
|
||||
)
|
||||
qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
|
||||
try:
|
||||
answer, quotes = qa_model.answer_question(
|
||||
query, ranked_chunks[:NUM_RERANKED_RESULTS]
|
||||
)
|
||||
except Exception:
|
||||
# exception is logged in the answer_question method, no need to re-log
|
||||
answer, quotes = None, None
|
||||
|
||||
logger.info(f"Total QA took {time.time() - start_time} seconds")
|
||||
|
||||
@@ -126,14 +131,19 @@ def stream_direct_qa(
|
||||
top_docs_dict = {top_documents_key: [top_doc.json() for top_doc in top_docs]}
|
||||
yield get_json_line(top_docs_dict)
|
||||
|
||||
qa_model = get_default_backend_qa_model()
|
||||
for response_dict in qa_model.answer_question_stream(
|
||||
query, ranked_chunks[:NUM_RERANKED_RESULTS]
|
||||
):
|
||||
if response_dict is None:
|
||||
continue
|
||||
logger.debug(response_dict)
|
||||
yield get_json_line(response_dict)
|
||||
qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
|
||||
try:
|
||||
for response_dict in qa_model.answer_question_stream(
|
||||
query, ranked_chunks[:NUM_RERANKED_RESULTS]
|
||||
):
|
||||
if response_dict is None:
|
||||
continue
|
||||
logger.debug(response_dict)
|
||||
yield get_json_line(response_dict)
|
||||
except Exception:
|
||||
# exception is logged in the answer_question method, no need to re-log
|
||||
pass
|
||||
|
||||
return
|
||||
|
||||
return StreamingResponse(stream_qa_portions(), media_type="application/json")
|
||||
|
Reference in New Issue
Block a user