Adjust the way LLM class is instantiated + fix issue where .env file GEN_AI_API_KEY wasn't being used (#630)
@@ -82,10 +82,7 @@ INTERNAL_MODEL_VERSION = os.environ.get(
 )
 
 # If the Generative AI model requires an API key for access, otherwise can leave blank
-GEN_AI_API_KEY = (
-    os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))
-    or "dummy_llm_key"
-)
+GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))
 
 # If using GPT4All, HuggingFace Inference API, or OpenAI - specify the model version
 GEN_AI_MODEL_VERSION = os.environ.get(
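A note on the effect of this config change: GEN_AI_API_KEY now resolves to None when neither GEN_AI_API_KEY nor OPENAI_API_KEY is set, instead of silently becoming the placeholder "dummy_llm_key", so downstream code can detect the missing-key case. A minimal standalone sketch of the new resolution order (only the environment lookup is shown; the surrounding module is assumed):

import os

# Resolution order after this change: GEN_AI_API_KEY, then OPENAI_API_KEY, then None.
# Dropping the old "dummy_llm_key" fallback lets callers distinguish "no key
# configured" from "placeholder key".
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))

if __name__ == "__main__":
    print("resolved key:", GEN_AI_API_KEY)  # None when neither variable is set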
@@ -48,11 +48,12 @@ if API_TYPE_OPENAI in ["azure"]:  # TODO: Azure AD support ["azure_ad", "azuread
 
 
 def _ensure_openai_api_key(api_key: str | None) -> str:
-    try:
-        return api_key or get_gen_ai_api_key()
-    except ConfigNotFoundError:
-        raise OpenAIKeyMissing()
+    final_api_key = api_key or get_gen_ai_api_key()
+    if final_api_key is None:
+        raise OpenAIKeyMissing()
+
+    return final_api_key
 
 
 def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
     """
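Because get_gen_ai_api_key can now return None (see the next hunk), the helper switches from catching ConfigNotFoundError to an explicit None check. A hedged, self-contained sketch of that control flow; OpenAIKeyMissing and the key lookup are stand-ins here, not the actual danswer implementations:

import os


class OpenAIKeyMissing(Exception):
    """Stand-in for the danswer exception raised when no OpenAI key is available."""


def get_gen_ai_api_key() -> str | None:
    # stand-in lookup: the real function checks the dynamic config store first
    return os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))


def _ensure_openai_api_key(api_key: str | None) -> str:
    # same shape as the new code: prefer the explicit argument, fall back to the
    # configured key, and raise only when both are missing
    final_api_key = api_key or get_gen_ai_api_key()
    if final_api_key is None:
        raise OpenAIKeyMissing()
    return final_api_key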
@@ -22,6 +22,7 @@ from danswer.direct_qa.qa_prompts import ANSWER_PAT
 from danswer.direct_qa.qa_prompts import QUOTE_PAT
 from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
 from danswer.dynamic_configs import get_dynamic_config_store
+from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.llm.utils import check_number_of_tokens
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import clean_model_quote
@@ -31,11 +32,15 @@ from danswer.utils.text_processing import shared_precompare_cleanup
 logger = setup_logger()
 
 
-def get_gen_ai_api_key() -> str:
-    return (
-        cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
-        or GEN_AI_API_KEY
-    )
+def get_gen_ai_api_key() -> str | None:
+    # first check if the key has been provided by the UI
+    try:
+        return cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
+    except ConfigNotFoundError:
+        pass
+
+    # if not provided by the UI, fallback to the env variable
+    return GEN_AI_API_KEY
 
 
 def extract_answer_quotes_freeform(
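This is the core of the .env fix: a key stored via the admin UI still wins, but when the dynamic config store has no entry the function now returns the GEN_AI_API_KEY environment value (possibly None) instead of raising. A standalone sketch of that lookup order, with an in-memory dict standing in for the dynamic config store:

import os
from typing import cast

GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"

# in-memory stand-in for the dynamic config store that backs the admin UI
_UI_CONFIG_STORE: dict[str, str] = {}


class ConfigNotFoundError(KeyError):
    """Stand-in for danswer.dynamic_configs.interface.ConfigNotFoundError."""


def _load_from_ui_store(key: str) -> str:
    if key not in _UI_CONFIG_STORE:
        raise ConfigNotFoundError(key)
    return _UI_CONFIG_STORE[key]


def get_gen_ai_api_key() -> str | None:
    # first check if the key has been provided by the UI
    try:
        return cast(str, _load_from_ui_store(GEN_AI_API_KEY_STORAGE_KEY))
    except ConfigNotFoundError:
        pass
    # if not provided by the UI, fall back to the env variable (possibly None)
    return os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))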
@@ -5,7 +5,6 @@ from langchain.chat_models.azure_openai import AzureChatOpenAI
 from danswer.configs.model_configs import API_BASE_OPENAI
 from danswer.configs.model_configs import API_VERSION_OPENAI
 from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
-from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.llm.llm import LangChainChatLLM
 from danswer.llm.utils import should_be_verbose
 
@@ -23,11 +22,6 @@ class AzureGPT(LangChainChatLLM):
         *args: list[Any],
         **kwargs: dict[str, Any]
     ):
-        # set a dummy API key if not specified so that LangChain doesn't throw an
-        # exception when trying to initialize the LLM which would prevent the API
-        # server from starting up
-        if not api_key:
-            api_key = GEN_AI_API_KEY
         self._llm = AzureChatOpenAI(
             model=model_version,
             openai_api_type="azure",
@@ -1,5 +1,3 @@
-from typing import Any
-
 from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.configs.constants import DanswerGenAIModel
 from danswer.configs.constants import ModelHostType
@@ -11,48 +9,40 @@ from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.qa_utils import get_gen_ai_api_key
-from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.llm.azure import AzureGPT
 from danswer.llm.google_colab_demo import GoogleColabDemo
 from danswer.llm.llm import LLM
 from danswer.llm.openai import OpenAIGPT
 
 
-def get_llm_from_model(model: str, **kwargs: Any) -> LLM:
-    if model == DanswerGenAIModel.OPENAI_CHAT.value:
-        if API_TYPE_OPENAI == "azure":
-            return AzureGPT(**kwargs)
-        return OpenAIGPT(**kwargs)
-    if (
-        model == DanswerGenAIModel.REQUEST.value
-        and kwargs.get("model_host_type") == ModelHostType.COLAB_DEMO
-    ):
-        return GoogleColabDemo(**kwargs)
-
-    raise ValueError(f"Unknown LLM model: {model}")
-
-
 def get_default_llm(
-    api_key: str | None = None, timeout: int = QA_TIMEOUT, **kwargs: Any
+    api_key: str | None = None,
+    timeout: int = QA_TIMEOUT,
 ) -> LLM:
     """NOTE: api_key/timeout must be a special args since we may want to check
     if an API key is valid for the default model setup OR we may want to use the
     default model with a different timeout specified."""
     if api_key is None:
-        try:
-            api_key = get_gen_ai_api_key()
-        except ConfigNotFoundError:
-            # if no API key is found, assume this model doesn't need one
-            pass
+        api_key = get_gen_ai_api_key()
 
-    return get_llm_from_model(
-        model=INTERNAL_MODEL_VERSION,
-        api_key=api_key,
-        timeout=timeout,
-        model_version=GEN_AI_MODEL_VERSION,
-        endpoint=GEN_AI_ENDPOINT,
-        model_host_type=GEN_AI_HOST_TYPE,
-        max_output_tokens=GEN_AI_MAX_OUTPUT_TOKENS,
-        temperature=GEN_AI_TEMPERATURE,
-        **kwargs,
-    )
+    model_args = {
+        # provide a dummy key since LangChain will throw an exception if not
+        # given, which would prevent server startup
+        "api_key": api_key or "dummy_api_key",
+        "timeout": timeout,
+        "model_version": GEN_AI_MODEL_VERSION,
+        "endpoint": GEN_AI_ENDPOINT,
+        "max_output_tokens": GEN_AI_MAX_OUTPUT_TOKENS,
+        "temperature": GEN_AI_TEMPERATURE,
+    }
+    if INTERNAL_MODEL_VERSION == DanswerGenAIModel.OPENAI_CHAT.value:
+        if API_TYPE_OPENAI == "azure":
+            return AzureGPT(**model_args)  # type: ignore
+        return OpenAIGPT(**model_args)  # type: ignore
+    if (
+        INTERNAL_MODEL_VERSION == DanswerGenAIModel.REQUEST.value
+        and GEN_AI_HOST_TYPE == ModelHostType.COLAB_DEMO
+    ):
+        return GoogleColabDemo(**model_args)  # type: ignore
+
+    raise ValueError(f"Unknown LLM model: {INTERNAL_MODEL_VERSION}")
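The former get_llm_from_model helper is folded into get_default_llm: the kwargs passthrough is gone, the arguments are collected once into model_args, the dummy key is substituted only at construction time, and dispatch happens on INTERNAL_MODEL_VERSION and GEN_AI_HOST_TYPE. A simplified, self-contained sketch of that shape, with placeholder classes and config constants instead of the danswer ones:

from typing import Any


class StubLLM:
    """Placeholder for the danswer LLM wrappers; just records its kwargs."""

    def __init__(self, flavor: str, **kwargs: Any) -> None:
        self.flavor = flavor
        self.kwargs = kwargs


# placeholder constants standing in for the danswer.configs.* values
INTERNAL_MODEL_VERSION = "openai-chat-completion"
API_TYPE_OPENAI = "openai"
GEN_AI_MODEL_VERSION = "gpt-3.5-turbo"
QA_TIMEOUT = 10


def get_default_llm(api_key: str | None = None, timeout: int = QA_TIMEOUT) -> StubLLM:
    model_args = {
        # a placeholder key keeps construction from failing when nothing is
        # configured; requests made with it will simply be rejected upstream
        "api_key": api_key or "dummy_api_key",
        "timeout": timeout,
        "model_version": GEN_AI_MODEL_VERSION,
    }
    if INTERNAL_MODEL_VERSION == "openai-chat-completion":
        if API_TYPE_OPENAI == "azure":
            return StubLLM(flavor="azure", **model_args)
        return StubLLM(flavor="openai", **model_args)
    raise ValueError(f"Unknown LLM model: {INTERNAL_MODEL_VERSION}")


if __name__ == "__main__":
    llm = get_default_llm()
    print(llm.flavor, llm.kwargs["api_key"])  # openai dummy_api_key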
@@ -3,7 +3,6 @@ from typing import cast
 
 from langchain.chat_models.openai import ChatOpenAI
 
-from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.llm.llm import LangChainChatLLM
 from danswer.llm.utils import should_be_verbose
@@ -20,12 +19,6 @@ class OpenAIGPT(LangChainChatLLM):
         *args: list[Any],
         **kwargs: dict[str, Any]
     ):
-        # set a dummy API key if not specified so that LangChain doesn't throw an
-        # exception when trying to initialize the LLM which would prevent the API
-        # server from starting up
-        if not api_key:
-            api_key = GEN_AI_API_KEY
-
         self._llm = ChatOpenAI(
             model=model_version,
             openai_api_key=api_key,
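Together with the AzureGPT hunk above, this removes the last place where a wrapper silently substituted GEN_AI_API_KEY: the constructors now pass whatever key they were given straight to LangChain, and the "dummy_api_key" placeholder is injected only by get_default_llm. A small sketch of that division of responsibility, using a stub in place of langchain's ChatOpenAI:

class StubChatOpenAI:
    """Stand-in for langchain's ChatOpenAI, which refuses to initialize without a key."""

    def __init__(self, model: str, openai_api_key: str | None) -> None:
        if not openai_api_key:
            raise ValueError("openai_api_key must be provided")
        self.model = model
        self.openai_api_key = openai_api_key


class OpenAIGPTSketch:
    """After the change: no key-fallback logic in the wrapper itself."""

    def __init__(self, api_key: str, model_version: str) -> None:
        self._llm = StubChatOpenAI(model=model_version, openai_api_key=api_key)


# the caller (get_default_llm) supplies either the real key or "dummy_api_key"
OpenAIGPTSketch(api_key="dummy_api_key", model_version="gpt-3.5-turbo")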
@@ -502,14 +502,9 @@ def validate_existing_genai_api_key(
         # First time checking the key, nothing unusual
         pass
 
-    try:
-        genai_api_key = get_gen_ai_api_key()
-    except ConfigNotFoundError:
+    genai_api_key = get_gen_ai_api_key()
+    if genai_api_key is None:
         raise HTTPException(status_code=404, detail="Key not found")
-    except ValueError as e:
-        raise HTTPException(status_code=400, detail=str(e))
-
-    get_dynamic_config_store().store(check_key_time, curr_time.timestamp())
 
     try:
         is_valid = check_model_api_key_is_valid(genai_api_key)
@@ -520,6 +515,9 @@ def validate_existing_genai_api_key(
     if not is_valid:
         raise HTTPException(status_code=400, detail="Invalid API key provided")
 
+    # mark check as successful
+    get_dynamic_config_store().store(check_key_time, curr_time.timestamp())
+
 
 @router.get("/admin/genai-api-key", response_model=ApiKey)
 def get_gen_ai_api_key_from_dynamic_config_store(
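The validation endpoint follows the same pattern: a missing key is now signalled by None rather than ConfigNotFoundError, the ValueError branch is dropped, and the "last checked" timestamp is stored only after the key actually validates. A dependency-free sketch of the new flow, with stand-ins for the FastAPI exception, the key lookup, and the config store (the check_key_time default here is illustrative):

import time


class HTTPException(Exception):
    """Stand-in for fastapi.HTTPException so the sketch has no dependencies."""

    def __init__(self, status_code: int, detail: str) -> None:
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail


_config_store: dict[str, float] = {}  # stand-in for the dynamic config store


def get_gen_ai_api_key() -> str | None:
    import os

    return os.environ.get("GEN_AI_API_KEY", os.environ.get("OPENAI_API_KEY"))


def check_model_api_key_is_valid(key: str) -> bool:
    return bool(key)  # stand-in for a real round trip to the model provider


def validate_genai_api_key_sketch(check_key_time: str = "genai_api_key_last_check_time") -> None:
    genai_api_key = get_gen_ai_api_key()
    if genai_api_key is None:
        raise HTTPException(status_code=404, detail="Key not found")

    if not check_model_api_key_is_valid(genai_api_key):
        raise HTTPException(status_code=400, detail="Invalid API key provided")

    # mark check as successful only once the key has validated
    _config_store[check_key_time] = time.time()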