Adjust the way LLM class is instantiated + fix issue where .env file GEN_AI_API_KEY wasn't being used (#630)

This commit is contained in:
Chris Weaver
2023-10-25 22:33:18 -07:00
committed by GitHub
parent 604e511c09
commit 76275b29d4
7 changed files with 44 additions and 66 deletions

View File

@@ -82,10 +82,7 @@ INTERNAL_MODEL_VERSION = os.environ.get(
)
# If the Generative AI model requires an API key for access, otherwise can leave blank
# NOTE: falls back to OPENAI_API_KEY for backwards-compatibility. Resolves to None
# (rather than a dummy placeholder) when neither env var is set, so downstream
# code can distinguish "no key configured" from a real key.
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))
# If using GPT4All, HuggingFace Inference API, or OpenAI - specify the model version
GEN_AI_MODEL_VERSION = os.environ.get(

View File

@@ -48,11 +48,12 @@ if API_TYPE_OPENAI in ["azure"]: # TODO: Azure AD support ["azure_ad", "azuread
def _ensure_openai_api_key(api_key: str | None) -> str:
    """Return a usable OpenAI API key.

    Prefers the explicitly supplied `api_key`; otherwise falls back to the
    key resolved by `get_gen_ai_api_key()` (UI-provided or env-provided).

    Raises:
        OpenAIKeyMissing: if no key is available from either source.
    """
    final_api_key = api_key or get_gen_ai_api_key()
    if final_api_key is None:
        raise OpenAIKeyMissing()
    return final_api_key
def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
"""

View File

@@ -22,6 +22,7 @@ from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.utils import check_number_of_tokens
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
@@ -31,11 +32,15 @@ from danswer.utils.text_processing import shared_precompare_cleanup
logger = setup_logger()
def get_gen_ai_api_key() -> str | None:
    """Fetch the Generative AI API key, or None if no key is configured.

    Precedence: a key stored via the admin UI (dynamic config store) wins;
    otherwise fall back to the GEN_AI_API_KEY env-derived setting (which may
    itself be None when neither GEN_AI_API_KEY nor OPENAI_API_KEY is set).
    """
    # first check if the key has been provided by the UI
    try:
        return cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
    except ConfigNotFoundError:
        pass

    # if not provided by the UI, fallback to the env variable
    return GEN_AI_API_KEY
def extract_answer_quotes_freeform(

View File

@@ -5,7 +5,6 @@ from langchain.chat_models.azure_openai import AzureChatOpenAI
from danswer.configs.model_configs import API_BASE_OPENAI
from danswer.configs.model_configs import API_VERSION_OPENAI
from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose
@@ -23,11 +22,6 @@ class AzureGPT(LangChainChatLLM):
*args: list[Any],
**kwargs: dict[str, Any]
):
# set a dummy API key if not specified so that LangChain doesn't throw an
# exception when trying to initialize the LLM which would prevent the API
# server from starting up
if not api_key:
api_key = GEN_AI_API_KEY
self._llm = AzureChatOpenAI(
model=model_version,
openai_api_type="azure",

View File

@@ -1,5 +1,3 @@
from typing import Any
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
@@ -11,48 +9,40 @@ from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.azure import AzureGPT
from danswer.llm.google_colab_demo import GoogleColabDemo
from danswer.llm.llm import LLM
from danswer.llm.openai import OpenAIGPT
def get_default_llm(
    api_key: str | None = None,
    timeout: int = QA_TIMEOUT,
) -> LLM:
    """Build the application's default LLM instance.

    NOTE: api_key/timeout must be special args since we may want to check
    if an API key is valid for the default model setup OR we may want to use
    the default model with a different timeout specified.

    Args:
        api_key: explicit key to use; when None, resolved via
            get_gen_ai_api_key() (UI-stored key, then env var).
        timeout: request timeout passed through to the underlying model.

    Raises:
        ValueError: if INTERNAL_MODEL_VERSION does not name a known model.
    """
    if api_key is None:
        api_key = get_gen_ai_api_key()

    model_args = {
        # provide a dummy key since LangChain will throw an exception if not
        # given, which would prevent server startup
        "api_key": api_key or "dummy_api_key",
        "timeout": timeout,
        "model_version": GEN_AI_MODEL_VERSION,
        "endpoint": GEN_AI_ENDPOINT,
        "max_output_tokens": GEN_AI_MAX_OUTPUT_TOKENS,
        "temperature": GEN_AI_TEMPERATURE,
    }

    if INTERNAL_MODEL_VERSION == DanswerGenAIModel.OPENAI_CHAT.value:
        if API_TYPE_OPENAI == "azure":
            return AzureGPT(**model_args)  # type: ignore
        return OpenAIGPT(**model_args)  # type: ignore

    if (
        INTERNAL_MODEL_VERSION == DanswerGenAIModel.REQUEST.value
        and GEN_AI_HOST_TYPE == ModelHostType.COLAB_DEMO
    ):
        return GoogleColabDemo(**model_args)  # type: ignore

    raise ValueError(f"Unknown LLM model: {INTERNAL_MODEL_VERSION}")

View File

@@ -3,7 +3,6 @@ from typing import cast
from langchain.chat_models.openai import ChatOpenAI
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.llm import LangChainChatLLM
from danswer.llm.utils import should_be_verbose
@@ -20,12 +19,6 @@ class OpenAIGPT(LangChainChatLLM):
*args: list[Any],
**kwargs: dict[str, Any]
):
# set a dummy API key if not specified so that LangChain doesn't throw an
# exception when trying to initialize the LLM which would prevent the API
# server from starting up
if not api_key:
api_key = GEN_AI_API_KEY
self._llm = ChatOpenAI(
model=model_version,
openai_api_key=api_key,

View File

@@ -502,14 +502,9 @@ def validate_existing_genai_api_key(
# First time checking the key, nothing unusual
pass
try:
genai_api_key = get_gen_ai_api_key()
except ConfigNotFoundError:
genai_api_key = get_gen_ai_api_key()
if genai_api_key is None:
raise HTTPException(status_code=404, detail="Key not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
get_dynamic_config_store().store(check_key_time, curr_time.timestamp())
try:
is_valid = check_model_api_key_is_valid(genai_api_key)
@@ -520,6 +515,9 @@ def validate_existing_genai_api_key(
if not is_valid:
raise HTTPException(status_code=400, detail="Invalid API key provided")
# mark check as successful
get_dynamic_config_store().store(check_key_time, curr_time.timestamp())
@router.get("/admin/genai-api-key", response_model=ApiKey)
def get_gen_ai_api_key_from_dynamic_config_store(