From 0d4244f9903b0c869263a1dd0b97f8bb7aa9cc3d Mon Sep 17 00:00:00 2001 From: Weves Date: Fri, 1 Sep 2023 09:52:17 -0700 Subject: [PATCH] Fix API key specification bug --- backend/danswer/direct_qa/llm_utils.py | 4 ++-- backend/danswer/llm/build.py | 21 +++++++++++++++++---- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/backend/danswer/direct_qa/llm_utils.py b/backend/danswer/direct_qa/llm_utils.py index c87cd343d..e98734f32 100644 --- a/backend/danswer/direct_qa/llm_utils.py +++ b/backend/danswer/direct_qa/llm_utils.py @@ -36,12 +36,12 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool: if not model_api_key: return False - qa_model = get_default_qa_model(api_key=model_api_key, timeout=5) + llm = get_default_llm(api_key=model_api_key, timeout=5) # try for up to 2 timeouts (e.g. 10 seconds in total) for _ in range(2): try: - qa_model.answer_question("Do not respond", []) + llm.invoke("Do not respond") return True except AuthenticationError: return False diff --git a/backend/danswer/llm/build.py b/backend/danswer/llm/build.py index 31939065c..39c9cd3f7 100644 --- a/backend/danswer/llm/build.py +++ b/backend/danswer/llm/build.py @@ -3,12 +3,13 @@ from typing import Any from danswer.configs.app_configs import QA_TIMEOUT from danswer.configs.constants import DanswerGenAIModel from danswer.configs.model_configs import API_TYPE_OPENAI -from danswer.configs.model_configs import GEN_AI_API_KEY from danswer.configs.model_configs import GEN_AI_ENDPOINT from danswer.configs.model_configs import GEN_AI_HOST_TYPE from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS from danswer.configs.model_configs import GEN_AI_MODEL_VERSION from danswer.configs.model_configs import INTERNAL_MODEL_VERSION +from danswer.direct_qa.qa_utils import get_gen_ai_api_key +from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.llm.azure import AzureGPT from danswer.llm.llm import LLM from danswer.llm.openai import OpenAIGPT @@ -23,14 +24,26 @@ def get_llm_from_model(model: str, **kwargs: Any) -> LLM: raise ValueError(f"Unknown LLM model: {model}") -def get_default_llm(**kwargs: Any) -> LLM: +def get_default_llm( + api_key: str | None = None, timeout: int = QA_TIMEOUT, **kwargs: Any +) -> LLM: + """NOTE: api_key/timeout must be a special args since we may want to check + if an API key is valid for the default model setup OR we may want to use the + default model with a different timeout specified.""" + if api_key is None: + try: + api_key = get_gen_ai_api_key() + except ConfigNotFoundError: + # if no API key is found, assume this model doesn't need one + pass + return get_llm_from_model( model=INTERNAL_MODEL_VERSION, - api_key=GEN_AI_API_KEY, + api_key=api_key, + timeout=timeout, model_version=GEN_AI_MODEL_VERSION, endpoint=GEN_AI_ENDPOINT, model_host_type=GEN_AI_HOST_TYPE, - timeout=QA_TIMEOUT, max_output_tokens=GEN_AI_MAX_OUTPUT_TOKENS, **kwargs, )