Fix API key specification bug

This commit is contained in:
Weves 2023-09-01 09:52:17 -07:00 committed by Chris Weaver
parent bddf03cd54
commit 0d4244f990
2 changed files with 19 additions and 6 deletions

View File

@@ -36,12 +36,12 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
if not model_api_key:
return False
qa_model = get_default_qa_model(api_key=model_api_key, timeout=5)
llm = get_default_llm(api_key=model_api_key, timeout=5)
# try for up to 2 timeouts (e.g. 10 seconds in total)
for _ in range(2):
try:
qa_model.answer_question("Do not respond", [])
llm.invoke("Do not respond")
return True
except AuthenticationError:
return False

View File

@@ -3,12 +3,13 @@ from typing import Any
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.model_configs import API_TYPE_OPENAI
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.azure import AzureGPT
from danswer.llm.llm import LLM
from danswer.llm.openai import OpenAIGPT
@@ -23,14 +24,26 @@ def get_llm_from_model(model: str, **kwargs: Any) -> LLM:
raise ValueError(f"Unknown LLM model: {model}")
def get_default_llm(
    api_key: str | None = None, timeout: int = QA_TIMEOUT, **kwargs: Any
) -> LLM:
    """Build the default LLM from the app's model configuration.

    NOTE: api_key/timeout must be special args since we may want to check
    if an API key is valid for the default model setup OR we may want to use the
    default model with a different timeout specified.

    Args:
        api_key: Explicit API key to use. If None, the key is looked up from
            dynamic config; a missing key is tolerated on the assumption that
            the configured model may not require one.
        timeout: Request timeout in seconds for the underlying model call.
        **kwargs: Additional keyword arguments forwarded to the model constructor.

    Returns:
        An LLM instance configured per INTERNAL_MODEL_VERSION and the
        GEN_AI_* settings.
    """
    if api_key is None:
        try:
            api_key = get_gen_ai_api_key()
        except ConfigNotFoundError:
            # if no API key is found, assume this model doesn't need one
            pass

    return get_llm_from_model(
        model=INTERNAL_MODEL_VERSION,
        api_key=api_key,
        timeout=timeout,
        model_version=GEN_AI_MODEL_VERSION,
        endpoint=GEN_AI_ENDPOINT,
        model_host_type=GEN_AI_HOST_TYPE,
        max_output_tokens=GEN_AI_MAX_OUTPUT_TOKENS,
        **kwargs,
    )