diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index d296a6b90f68..89b500da9854 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -82,10 +82,7 @@ INTERNAL_MODEL_VERSION = os.environ.get(
 )
 
 # If the Generative AI model requires an API key for access, otherwise can leave blank
-GEN_AI_API_KEY = (
-    os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))
-    or "dummy_llm_key"
-)
+GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))
 
 # If using GPT4All, HuggingFace Inference API, or OpenAI - specify the model version
 GEN_AI_MODEL_VERSION = os.environ.get(
diff --git a/backend/danswer/direct_qa/open_ai.py b/backend/danswer/direct_qa/open_ai.py
index 3f1e0f97e80f..2c2dc1526ec4 100644
--- a/backend/danswer/direct_qa/open_ai.py
+++ b/backend/danswer/direct_qa/open_ai.py
@@ -48,11 +48,12 @@ if API_TYPE_OPENAI in ["azure"]:  # TODO: Azure AD support ["azure_ad", "azuread
 
 
 def _ensure_openai_api_key(api_key: str | None) -> str:
-    try:
-        return api_key or get_gen_ai_api_key()
-    except ConfigNotFoundError:
+    final_api_key = api_key or get_gen_ai_api_key()
+    if final_api_key is None:
         raise OpenAIKeyMissing()
+    return final_api_key
+
 
 def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
     """
diff --git a/backend/danswer/direct_qa/qa_utils.py b/backend/danswer/direct_qa/qa_utils.py
index 838e3012c5c0..4fd5f58419ef 100644
--- a/backend/danswer/direct_qa/qa_utils.py
+++ b/backend/danswer/direct_qa/qa_utils.py
@@ -22,6 +22,7 @@ from danswer.direct_qa.qa_prompts import ANSWER_PAT
 from danswer.direct_qa.qa_prompts import QUOTE_PAT
 from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
 from danswer.dynamic_configs import get_dynamic_config_store
+from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.llm.utils import check_number_of_tokens
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import clean_model_quote
@@ -31,11 +32,15 @@ from danswer.utils.text_processing import shared_precompare_cleanup
 logger = setup_logger()
 
 
-def get_gen_ai_api_key() -> str:
-    return (
-        cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
-        or GEN_AI_API_KEY
-    )
+def get_gen_ai_api_key() -> str | None:
+    # first check if the key has been provided by the UI
+    try:
+        return cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
+    except ConfigNotFoundError:
+        pass
+
+    # if not provided by the UI, fallback to the env variable
+    return GEN_AI_API_KEY
 
 
 def extract_answer_quotes_freeform(
diff --git a/backend/danswer/llm/azure.py b/backend/danswer/llm/azure.py
index 60ee3b4095a8..e2ec1fa35c10 100644
--- a/backend/danswer/llm/azure.py
+++ b/backend/danswer/llm/azure.py
@@ -5,7 +5,6 @@ from langchain.chat_models.azure_openai import AzureChatOpenAI
 from danswer.configs.model_configs import API_BASE_OPENAI
 from danswer.configs.model_configs import API_VERSION_OPENAI
 from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
-from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.llm.llm import LangChainChatLLM
 from danswer.llm.utils import should_be_verbose
 
@@ -23,11 +22,6 @@ class AzureGPT(LangChainChatLLM):
         *args: list[Any],
         **kwargs: dict[str, Any]
     ):
-        # set a dummy API key if not specified so that LangChain doesn't throw an
-        # exception when trying to initialize the LLM which would prevent the API
-        # server from starting up
-        if not api_key:
-            api_key = GEN_AI_API_KEY
         self._llm = AzureChatOpenAI(
             model=model_version,
             openai_api_type="azure",
diff --git a/backend/danswer/llm/build.py b/backend/danswer/llm/build.py
index 9e3007001632..5ec9fae9e980 100644
--- a/backend/danswer/llm/build.py
+++ b/backend/danswer/llm/build.py
@@ -1,5 +1,3 @@
-from typing import Any
-
 from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.configs.constants import DanswerGenAIModel
 from danswer.configs.constants import ModelHostType
@@ -11,48 +9,40 @@ from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.qa_utils import get_gen_ai_api_key
-from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.llm.azure import AzureGPT
 from danswer.llm.google_colab_demo import GoogleColabDemo
 from danswer.llm.llm import LLM
 from danswer.llm.openai import OpenAIGPT
 
 
-def get_llm_from_model(model: str, **kwargs: Any) -> LLM:
-    if model == DanswerGenAIModel.OPENAI_CHAT.value:
-        if API_TYPE_OPENAI == "azure":
-            return AzureGPT(**kwargs)
-        return OpenAIGPT(**kwargs)
-    if (
-        model == DanswerGenAIModel.REQUEST.value
-        and kwargs.get("model_host_type") == ModelHostType.COLAB_DEMO
-    ):
-        return GoogleColabDemo(**kwargs)
-
-    raise ValueError(f"Unknown LLM model: {model}")
-
-
 def get_default_llm(
-    api_key: str | None = None, timeout: int = QA_TIMEOUT, **kwargs: Any
+    api_key: str | None = None,
+    timeout: int = QA_TIMEOUT,
 ) -> LLM:
     """NOTE: api_key/timeout must be a special args since we may want to check if an API key
     is valid for the default model setup OR we may want to use the default model with a different
     timeout specified."""
     if api_key is None:
-        try:
-            api_key = get_gen_ai_api_key()
-        except ConfigNotFoundError:
-            # if no API key is found, assume this model doesn't need one
-            pass
+        api_key = get_gen_ai_api_key()
 
-    return get_llm_from_model(
-        model=INTERNAL_MODEL_VERSION,
-        api_key=api_key,
-        timeout=timeout,
-        model_version=GEN_AI_MODEL_VERSION,
-        endpoint=GEN_AI_ENDPOINT,
-        model_host_type=GEN_AI_HOST_TYPE,
-        max_output_tokens=GEN_AI_MAX_OUTPUT_TOKENS,
-        temperature=GEN_AI_TEMPERATURE,
-        **kwargs,
-    )
+    model_args = {
+        # provide a dummy key since LangChain will throw an exception if not
+        # given, which would prevent server startup
+        "api_key": api_key or "dummy_api_key",
+        "timeout": timeout,
+        "model_version": GEN_AI_MODEL_VERSION,
+        "endpoint": GEN_AI_ENDPOINT,
+        "max_output_tokens": GEN_AI_MAX_OUTPUT_TOKENS,
+        "temperature": GEN_AI_TEMPERATURE,
+    }
+    if INTERNAL_MODEL_VERSION == DanswerGenAIModel.OPENAI_CHAT.value:
+        if API_TYPE_OPENAI == "azure":
+            return AzureGPT(**model_args)  # type: ignore
+        return OpenAIGPT(**model_args)  # type: ignore
+    if (
+        INTERNAL_MODEL_VERSION == DanswerGenAIModel.REQUEST.value
+        and GEN_AI_HOST_TYPE == ModelHostType.COLAB_DEMO
+    ):
+        return GoogleColabDemo(**model_args)  # type: ignore
+
+    raise ValueError(f"Unknown LLM model: {INTERNAL_MODEL_VERSION}")
diff --git a/backend/danswer/llm/openai.py b/backend/danswer/llm/openai.py
index 638aa134c810..90fce9035229 100644
--- a/backend/danswer/llm/openai.py
+++ b/backend/danswer/llm/openai.py
@@ -3,7 +3,6 @@ from typing import cast
 
 from langchain.chat_models.openai import ChatOpenAI
 
-from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.llm.llm import LangChainChatLLM
 from danswer.llm.utils import should_be_verbose
@@ -20,12 +19,6 @@ class OpenAIGPT(LangChainChatLLM):
         *args: list[Any],
         **kwargs: dict[str, Any]
     ):
-        # set a dummy API key if not specified so that LangChain doesn't throw an
-        # exception when trying to initialize the LLM which would prevent the API
-        # server from starting up
-        if not api_key:
-            api_key = GEN_AI_API_KEY
-
         self._llm = ChatOpenAI(
             model=model_version,
             openai_api_key=api_key,
diff --git a/backend/danswer/server/manage.py b/backend/danswer/server/manage.py
index d7954e639d8e..fda0ceafd786 100644
--- a/backend/danswer/server/manage.py
+++ b/backend/danswer/server/manage.py
@@ -502,14 +502,9 @@ def validate_existing_genai_api_key(
         # First time checking the key, nothing unusual
         pass
 
-    try:
-        genai_api_key = get_gen_ai_api_key()
-    except ConfigNotFoundError:
+    genai_api_key = get_gen_ai_api_key()
+    if genai_api_key is None:
         raise HTTPException(status_code=404, detail="Key not found")
-    except ValueError as e:
-        raise HTTPException(status_code=400, detail=str(e))
-
-    get_dynamic_config_store().store(check_key_time, curr_time.timestamp())
 
     try:
         is_valid = check_model_api_key_is_valid(genai_api_key)
@@ -520,6 +515,9 @@ def validate_existing_genai_api_key(
     if not is_valid:
         raise HTTPException(status_code=400, detail="Invalid API key provided")
 
+    # mark check as successful
+    get_dynamic_config_store().store(check_key_time, curr_time.timestamp())
+
 
 @router.get("/admin/genai-api-key", response_model=ApiKey)
 def get_gen_ai_api_key_from_dynamic_config_store(