Adjust the way LLM class is instantiated + fix issue where .env file GEN_AI_API_KEY wasn't being used (#630)

This commit is contained in:
Chris Weaver
2023-10-25 22:33:18 -07:00
committed by GitHub
parent 604e511c09
commit 76275b29d4
7 changed files with 44 additions and 66 deletions

View File

@@ -82,10 +82,7 @@ INTERNAL_MODEL_VERSION = os.environ.get(
) )
# If the Generative AI model requires an API key for access, otherwise can leave blank # If the Generative AI model requires an API key for access, otherwise can leave blank
GEN_AI_API_KEY = ( GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))
os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))
or "dummy_llm_key"
)
# If using GPT4All, HuggingFace Inference API, or OpenAI - specify the model version # If using GPT4All, HuggingFace Inference API, or OpenAI - specify the model version
GEN_AI_MODEL_VERSION = os.environ.get( GEN_AI_MODEL_VERSION = os.environ.get(

View File

@@ -48,11 +48,12 @@ if API_TYPE_OPENAI in ["azure"]: # TODO: Azure AD support ["azure_ad", "azuread
def _ensure_openai_api_key(api_key: str | None) -> str: def _ensure_openai_api_key(api_key: str | None) -> str:
try: final_api_key = api_key or get_gen_ai_api_key()
return api_key or get_gen_ai_api_key() if final_api_key is None:
except ConfigNotFoundError:
raise OpenAIKeyMissing() raise OpenAIKeyMissing()
return final_api_key
def _build_openai_settings(**kwargs: Any) -> dict[str, Any]: def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
""" """

View File

@@ -22,6 +22,7 @@ from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import QUOTE_PAT from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.dynamic_configs import get_dynamic_config_store from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.utils import check_number_of_tokens from danswer.llm.utils import check_number_of_tokens
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote from danswer.utils.text_processing import clean_model_quote
@@ -31,11 +32,15 @@ from danswer.utils.text_processing import shared_precompare_cleanup
logger = setup_logger() logger = setup_logger()
def get_gen_ai_api_key() -> str: def get_gen_ai_api_key() -> str | None:
return ( # first check if the key has been provided by the UI
cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY)) try:
or GEN_AI_API_KEY return cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
) except ConfigNotFoundError:
pass
# if not provided by the UI, fallback to the env variable
return GEN_AI_API_KEY
def extract_answer_quotes_freeform( def extract_answer_quotes_freeform(

View File

@@ -5,7 +5,6 @@ from langchain.chat_models.azure_openai import AzureChatOpenAI
from danswer.configs.model_configs import API_BASE_OPENAI from danswer.configs.model_configs import API_BASE_OPENAI
from danswer.configs.model_configs import API_VERSION_OPENAI from danswer.configs.model_configs import API_VERSION_OPENAI
from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.llm.llm import LangChainChatLLM from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose from danswer.llm.utils import should_be_verbose
@@ -23,11 +22,6 @@ class AzureGPT(LangChainChatLLM):
*args: list[Any], *args: list[Any],
**kwargs: dict[str, Any] **kwargs: dict[str, Any]
): ):
# set a dummy API key if not specified so that LangChain doesn't throw an
# exception when trying to initialize the LLM which would prevent the API
# server from starting up
if not api_key:
api_key = GEN_AI_API_KEY
self._llm = AzureChatOpenAI( self._llm = AzureChatOpenAI(
model=model_version, model=model_version,
openai_api_type="azure", openai_api_type="azure",

View File

@@ -1,5 +1,3 @@
from typing import Any
from danswer.configs.app_configs import QA_TIMEOUT from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType from danswer.configs.constants import ModelHostType
@@ -11,48 +9,40 @@ from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.qa_utils import get_gen_ai_api_key from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.azure import AzureGPT from danswer.llm.azure import AzureGPT
from danswer.llm.google_colab_demo import GoogleColabDemo from danswer.llm.google_colab_demo import GoogleColabDemo
from danswer.llm.llm import LLM from danswer.llm.llm import LLM
from danswer.llm.openai import OpenAIGPT from danswer.llm.openai import OpenAIGPT
def get_llm_from_model(model: str, **kwargs: Any) -> LLM:
if model == DanswerGenAIModel.OPENAI_CHAT.value:
if API_TYPE_OPENAI == "azure":
return AzureGPT(**kwargs)
return OpenAIGPT(**kwargs)
if (
model == DanswerGenAIModel.REQUEST.value
and kwargs.get("model_host_type") == ModelHostType.COLAB_DEMO
):
return GoogleColabDemo(**kwargs)
raise ValueError(f"Unknown LLM model: {model}")
def get_default_llm( def get_default_llm(
api_key: str | None = None, timeout: int = QA_TIMEOUT, **kwargs: Any api_key: str | None = None,
timeout: int = QA_TIMEOUT,
) -> LLM: ) -> LLM:
"""NOTE: api_key/timeout must be a special args since we may want to check """NOTE: api_key/timeout must be a special args since we may want to check
if an API key is valid for the default model setup OR we may want to use the if an API key is valid for the default model setup OR we may want to use the
default model with a different timeout specified.""" default model with a different timeout specified."""
if api_key is None: if api_key is None:
try:
api_key = get_gen_ai_api_key() api_key = get_gen_ai_api_key()
except ConfigNotFoundError:
# if no API key is found, assume this model doesn't need one
pass
return get_llm_from_model( model_args = {
model=INTERNAL_MODEL_VERSION, # provide a dummy key since LangChain will throw an exception if not
api_key=api_key, # given, which would prevent server startup
timeout=timeout, "api_key": api_key or "dummy_api_key",
model_version=GEN_AI_MODEL_VERSION, "timeout": timeout,
endpoint=GEN_AI_ENDPOINT, "model_version": GEN_AI_MODEL_VERSION,
model_host_type=GEN_AI_HOST_TYPE, "endpoint": GEN_AI_ENDPOINT,
max_output_tokens=GEN_AI_MAX_OUTPUT_TOKENS, "max_output_tokens": GEN_AI_MAX_OUTPUT_TOKENS,
temperature=GEN_AI_TEMPERATURE, "temperature": GEN_AI_TEMPERATURE,
**kwargs, }
) if INTERNAL_MODEL_VERSION == DanswerGenAIModel.OPENAI_CHAT.value:
if API_TYPE_OPENAI == "azure":
return AzureGPT(**model_args) # type: ignore
return OpenAIGPT(**model_args) # type: ignore
if (
INTERNAL_MODEL_VERSION == DanswerGenAIModel.REQUEST.value
and GEN_AI_HOST_TYPE == ModelHostType.COLAB_DEMO
):
return GoogleColabDemo(**model_args) # type: ignore
raise ValueError(f"Unknown LLM model: {INTERNAL_MODEL_VERSION}")

View File

@@ -3,7 +3,6 @@ from typing import cast
from langchain.chat_models.openai import ChatOpenAI from langchain.chat_models.openai import ChatOpenAI
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_TEMPERATURE from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.llm import LangChainChatLLM from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose from danswer.llm.utils import should_be_verbose
@@ -20,12 +19,6 @@ class OpenAIGPT(LangChainChatLLM):
*args: list[Any], *args: list[Any],
**kwargs: dict[str, Any] **kwargs: dict[str, Any]
): ):
# set a dummy API key if not specified so that LangChain doesn't throw an
# exception when trying to initialize the LLM which would prevent the API
# server from starting up
if not api_key:
api_key = GEN_AI_API_KEY
self._llm = ChatOpenAI( self._llm = ChatOpenAI(
model=model_version, model=model_version,
openai_api_key=api_key, openai_api_key=api_key,

View File

@@ -502,14 +502,9 @@ def validate_existing_genai_api_key(
# First time checking the key, nothing unusual # First time checking the key, nothing unusual
pass pass
try:
genai_api_key = get_gen_ai_api_key() genai_api_key = get_gen_ai_api_key()
except ConfigNotFoundError: if genai_api_key is None:
raise HTTPException(status_code=404, detail="Key not found") raise HTTPException(status_code=404, detail="Key not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
get_dynamic_config_store().store(check_key_time, curr_time.timestamp())
try: try:
is_valid = check_model_api_key_is_valid(genai_api_key) is_valid = check_model_api_key_is_valid(genai_api_key)
@@ -520,6 +515,9 @@ def validate_existing_genai_api_key(
if not is_valid: if not is_valid:
raise HTTPException(status_code=400, detail="Invalid API key provided") raise HTTPException(status_code=400, detail="Invalid API key provided")
# mark check as successful
get_dynamic_config_store().store(check_key_time, curr_time.timestamp())
@router.get("/admin/genai-api-key", response_model=ApiKey) @router.get("/admin/genai-api-key", response_model=ApiKey)
def get_gen_ai_api_key_from_dynamic_config_store( def get_gen_ai_api_key_from_dynamic_config_store(