mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 12:03:54 +02:00
Adjust the way LLM class is instantiated + fix issue where .env file GEN_AI_API_KEY wasn't being used (#630)
This commit is contained in:
@@ -82,10 +82,7 @@ INTERNAL_MODEL_VERSION = os.environ.get(
|
||||
)
|
||||
|
||||
# If the Generative AI model requires an API key for access, otherwise can leave blank
|
||||
GEN_AI_API_KEY = (
|
||||
os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))
|
||||
or "dummy_llm_key"
|
||||
)
|
||||
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))
|
||||
|
||||
# If using GPT4All, HuggingFace Inference API, or OpenAI - specify the model version
|
||||
GEN_AI_MODEL_VERSION = os.environ.get(
|
||||
|
@@ -48,11 +48,12 @@ if API_TYPE_OPENAI in ["azure"]: # TODO: Azure AD support ["azure_ad", "azuread
|
||||
|
||||
|
||||
def _ensure_openai_api_key(api_key: str | None) -> str:
|
||||
try:
|
||||
return api_key or get_gen_ai_api_key()
|
||||
except ConfigNotFoundError:
|
||||
final_api_key = api_key or get_gen_ai_api_key()
|
||||
if final_api_key is None:
|
||||
raise OpenAIKeyMissing()
|
||||
|
||||
return final_api_key
|
||||
|
||||
|
||||
def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
|
||||
"""
|
||||
|
@@ -22,6 +22,7 @@ from danswer.direct_qa.qa_prompts import ANSWER_PAT
|
||||
from danswer.direct_qa.qa_prompts import QUOTE_PAT
|
||||
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
|
||||
from danswer.dynamic_configs import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import clean_model_quote
|
||||
@@ -31,11 +32,15 @@ from danswer.utils.text_processing import shared_precompare_cleanup
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_gen_ai_api_key() -> str:
|
||||
return (
|
||||
cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
|
||||
or GEN_AI_API_KEY
|
||||
)
|
||||
def get_gen_ai_api_key() -> str | None:
|
||||
# first check if the key has been provided by the UI
|
||||
try:
|
||||
return cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
|
||||
except ConfigNotFoundError:
|
||||
pass
|
||||
|
||||
# if not provided by the UI, fallback to the env variable
|
||||
return GEN_AI_API_KEY
|
||||
|
||||
|
||||
def extract_answer_quotes_freeform(
|
||||
|
@@ -5,7 +5,6 @@ from langchain.chat_models.azure_openai import AzureChatOpenAI
|
||||
from danswer.configs.model_configs import API_BASE_OPENAI
|
||||
from danswer.configs.model_configs import API_VERSION_OPENAI
|
||||
from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
|
||||
from danswer.configs.model_configs import GEN_AI_API_KEY
|
||||
from danswer.llm.llm import LangChainChatLLM
|
||||
from danswer.llm.utils import should_be_verbose
|
||||
|
||||
@@ -23,11 +22,6 @@ class AzureGPT(LangChainChatLLM):
|
||||
*args: list[Any],
|
||||
**kwargs: dict[str, Any]
|
||||
):
|
||||
# set a dummy API key if not specified so that LangChain doesn't throw an
|
||||
# exception when trying to initialize the LLM which would prevent the API
|
||||
# server from starting up
|
||||
if not api_key:
|
||||
api_key = GEN_AI_API_KEY
|
||||
self._llm = AzureChatOpenAI(
|
||||
model=model_version,
|
||||
openai_api_type="azure",
|
||||
|
@@ -1,5 +1,3 @@
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.app_configs import QA_TIMEOUT
|
||||
from danswer.configs.constants import DanswerGenAIModel
|
||||
from danswer.configs.constants import ModelHostType
|
||||
@@ -11,48 +9,40 @@ from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
|
||||
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.llm.azure import AzureGPT
|
||||
from danswer.llm.google_colab_demo import GoogleColabDemo
|
||||
from danswer.llm.llm import LLM
|
||||
from danswer.llm.openai import OpenAIGPT
|
||||
|
||||
|
||||
def get_llm_from_model(model: str, **kwargs: Any) -> LLM:
|
||||
if model == DanswerGenAIModel.OPENAI_CHAT.value:
|
||||
if API_TYPE_OPENAI == "azure":
|
||||
return AzureGPT(**kwargs)
|
||||
return OpenAIGPT(**kwargs)
|
||||
if (
|
||||
model == DanswerGenAIModel.REQUEST.value
|
||||
and kwargs.get("model_host_type") == ModelHostType.COLAB_DEMO
|
||||
):
|
||||
return GoogleColabDemo(**kwargs)
|
||||
|
||||
raise ValueError(f"Unknown LLM model: {model}")
|
||||
|
||||
|
||||
def get_default_llm(
|
||||
api_key: str | None = None, timeout: int = QA_TIMEOUT, **kwargs: Any
|
||||
api_key: str | None = None,
|
||||
timeout: int = QA_TIMEOUT,
|
||||
) -> LLM:
|
||||
"""NOTE: api_key/timeout must be a special args since we may want to check
|
||||
if an API key is valid for the default model setup OR we may want to use the
|
||||
default model with a different timeout specified."""
|
||||
if api_key is None:
|
||||
try:
|
||||
api_key = get_gen_ai_api_key()
|
||||
except ConfigNotFoundError:
|
||||
# if no API key is found, assume this model doesn't need one
|
||||
pass
|
||||
api_key = get_gen_ai_api_key()
|
||||
|
||||
return get_llm_from_model(
|
||||
model=INTERNAL_MODEL_VERSION,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
model_version=GEN_AI_MODEL_VERSION,
|
||||
endpoint=GEN_AI_ENDPOINT,
|
||||
model_host_type=GEN_AI_HOST_TYPE,
|
||||
max_output_tokens=GEN_AI_MAX_OUTPUT_TOKENS,
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
**kwargs,
|
||||
)
|
||||
model_args = {
|
||||
# provide a dummy key since LangChain will throw an exception if not
|
||||
# given, which would prevent server startup
|
||||
"api_key": api_key or "dummy_api_key",
|
||||
"timeout": timeout,
|
||||
"model_version": GEN_AI_MODEL_VERSION,
|
||||
"endpoint": GEN_AI_ENDPOINT,
|
||||
"max_output_tokens": GEN_AI_MAX_OUTPUT_TOKENS,
|
||||
"temperature": GEN_AI_TEMPERATURE,
|
||||
}
|
||||
if INTERNAL_MODEL_VERSION == DanswerGenAIModel.OPENAI_CHAT.value:
|
||||
if API_TYPE_OPENAI == "azure":
|
||||
return AzureGPT(**model_args) # type: ignore
|
||||
return OpenAIGPT(**model_args) # type: ignore
|
||||
if (
|
||||
INTERNAL_MODEL_VERSION == DanswerGenAIModel.REQUEST.value
|
||||
and GEN_AI_HOST_TYPE == ModelHostType.COLAB_DEMO
|
||||
):
|
||||
return GoogleColabDemo(**model_args) # type: ignore
|
||||
|
||||
raise ValueError(f"Unknown LLM model: {INTERNAL_MODEL_VERSION}")
|
||||
|
@@ -3,7 +3,6 @@ from typing import cast
|
||||
|
||||
from langchain.chat_models.openai import ChatOpenAI
|
||||
|
||||
from danswer.configs.model_configs import GEN_AI_API_KEY
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.llm.llm import LangChainChatLLM
|
||||
from danswer.llm.utils import should_be_verbose
|
||||
@@ -20,12 +19,6 @@ class OpenAIGPT(LangChainChatLLM):
|
||||
*args: list[Any],
|
||||
**kwargs: dict[str, Any]
|
||||
):
|
||||
# set a dummy API key if not specified so that LangChain doesn't throw an
|
||||
# exception when trying to initialize the LLM which would prevent the API
|
||||
# server from starting up
|
||||
if not api_key:
|
||||
api_key = GEN_AI_API_KEY
|
||||
|
||||
self._llm = ChatOpenAI(
|
||||
model=model_version,
|
||||
openai_api_key=api_key,
|
||||
|
@@ -502,14 +502,9 @@ def validate_existing_genai_api_key(
|
||||
# First time checking the key, nothing unusual
|
||||
pass
|
||||
|
||||
try:
|
||||
genai_api_key = get_gen_ai_api_key()
|
||||
except ConfigNotFoundError:
|
||||
genai_api_key = get_gen_ai_api_key()
|
||||
if genai_api_key is None:
|
||||
raise HTTPException(status_code=404, detail="Key not found")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
get_dynamic_config_store().store(check_key_time, curr_time.timestamp())
|
||||
|
||||
try:
|
||||
is_valid = check_model_api_key_is_valid(genai_api_key)
|
||||
@@ -520,6 +515,9 @@ def validate_existing_genai_api_key(
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=400, detail="Invalid API key provided")
|
||||
|
||||
# mark check as successful
|
||||
get_dynamic_config_store().store(check_key_time, curr_time.timestamp())
|
||||
|
||||
|
||||
@router.get("/admin/genai-api-key", response_model=ApiKey)
|
||||
def get_gen_ai_api_key_from_dynamic_config_store(
|
||||
|
Reference in New Issue
Block a user