diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py
index 3920fb765d1c..31073da3a9c5 100644
--- a/backend/danswer/chat/process_message.py
+++ b/backend/danswer/chat/process_message.py
@@ -21,6 +21,7 @@ from danswer.chat.models import QADocsResponse
 from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import CHUNK_SIZE
 from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
+from danswer.configs.constants import DISABLED_GEN_AI_MSG
 from danswer.configs.constants import MessageType
 from danswer.db.chat import create_db_search_doc
 from danswer.db.chat import create_new_chat_message
@@ -36,6 +37,7 @@ from danswer.db.models import SearchDoc as DbSearchDoc
 from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
 from danswer.indexing.models import InferenceChunk
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import get_default_llm_token_encode
@@ -61,10 +63,18 @@ def generate_ai_chat_response(
     history: list[ChatMessage],
     context_docs: list[LlmDoc],
     doc_id_to_rank_map: dict[str, int],
-    llm: LLM,
+    llm: LLM | None,
     llm_tokenizer: Callable,
     all_doc_useful: bool,
 ) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]:
+    if llm is None:
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # Not an error if Gen AI was disabled intentionally by the admin
+            yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
+            return
+
     if query_message.prompt is None:
         raise RuntimeError("No prompt received for generating Gen AI answer.")
 
@@ -171,7 +181,11 @@ def stream_chat_message(
             "Must specify a set of documents for chat or specify search options"
         )
 
-    llm = get_default_llm()
+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        llm = None
+
     llm_tokenizer = get_default_llm_token_encode()
 
     document_index = get_default_document_index()
diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py
index d4f869156713..ca39b84cf683 100644
--- a/backend/danswer/configs/app_configs.py
+++ b/backend/danswer/configs/app_configs.py
@@ -20,7 +20,6 @@ APP_API_PREFIX = os.environ.get("API_PREFIX", "")
 #####
 BLURB_SIZE = 128  # Number Encoder Tokens included in the chunk blurb
 GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400  # 1 day
-# CURRENTLY DOES NOT FULLY WORK, DON'T USE THIS
 DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
 
 
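Taken together, the two hunks above make the chat path degrade gracefully instead of crashing: the streaming generator resolves the LLM lazily and, if the factory reports that Gen AI is off, emits one canned answer piece and stops. A minimal standalone sketch of that control flow (the names below mirror the diff, but the stub flag, message, and factory are simplified stand-ins, not the real Danswer imports):

```python
from collections.abc import Iterator

class GenAIDisabledException(Exception):
    """Stand-in for danswer.llm.exceptions.GenAIDisabledException."""

DISABLE_GENERATIVE_AI = True  # stand-in for the env-driven flag in app_configs.py
DISABLED_GEN_AI_MSG = "Generative AI is disabled. Search still works."

def get_default_llm() -> object:
    # Mirrors the factory: raise instead of returning a half-configured LLM
    if DISABLE_GENERATIVE_AI:
        raise GenAIDisabledException()
    return object()  # the real code returns a configured LLM

def generate_ai_chat_response() -> Iterator[str]:
    try:
        llm = get_default_llm()
    except GenAIDisabledException:
        # Not an error: surface the canned message as the whole "answer"
        yield DISABLED_GEN_AI_MSG
        return
    yield f"streamed answer from {llm}"

print(list(generate_ai_chat_response()))
# ['Generative AI is disabled. Search still works.']
```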
diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py
index d86f980c9d72..8e57bb6bdf08 100644
--- a/backend/danswer/configs/constants.py
+++ b/backend/danswer/configs/constants.py
@@ -47,6 +47,14 @@ SECTION_SEPARATOR = "\n\n"
 INDEX_SEPARATOR = "==="
 
 
+# Messages
+DISABLED_GEN_AI_MSG = (
+    "Your System Admin has disabled the Generative AI features of Danswer.\n"
+    "Please contact them if you wish to have this enabled.\n"
+    "You can still use Danswer as a search engine."
+)
+
+
 class DocumentSource(str, Enum):
     # Special case, document passed in via Danswer APIs without specifying a source type
     INGESTION_API = "ingestion_api"
diff --git a/backend/danswer/danswerbot/slack/blocks.py b/backend/danswer/danswerbot/slack/blocks.py
index f20c894d2d3f..65717ba5613f 100644
--- a/backend/danswer/danswerbot/slack/blocks.py
+++ b/backend/danswer/danswerbot/slack/blocks.py
@@ -12,6 +12,7 @@ from slack_sdk.models.blocks import RadioButtonsElement
 from slack_sdk.models.blocks import SectionBlock
 
 from danswer.chat.models import DanswerQuote
+from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import SearchFeedbackType
 from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
@@ -106,8 +107,11 @@ def build_documents_blocks(
     message_id: int | None,
     num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
 ) -> list[Block]:
+    header_text = (
+        "Retrieved Documents" if DISABLE_GENERATIVE_AI else "Reference Documents"
+    )
     seen_docs_identifiers = set()
-    section_blocks: list[Block] = [HeaderBlock(text="Reference Documents")]
+    section_blocks: list[Block] = [HeaderBlock(text=header_text)]
     included_docs = 0
     for rank, d in enumerate(documents):
         if d.document_id in seen_docs_identifiers:
@@ -208,6 +212,9 @@ def build_qa_response_blocks(
     favor_recent: bool,
     skip_quotes: bool = False,
 ) -> list[Block]:
+    if DISABLE_GENERATIVE_AI:
+        return []
+
     quotes_blocks: list[Block] = []
 
     ai_answer_header = HeaderBlock(text="AI Answer")
diff --git a/backend/danswer/llm/exceptions.py b/backend/danswer/llm/exceptions.py
new file mode 100644
index 000000000000..0cdb893c83b8
--- /dev/null
+++ b/backend/danswer/llm/exceptions.py
@@ -0,0 +1,4 @@
+class GenAIDisabledException(Exception):
+    def __init__(self, message: str = "Generative AI has been turned off") -> None:
+        self.message = message
+        super().__init__(self.message)
diff --git a/backend/danswer/llm/factory.py b/backend/danswer/llm/factory.py
index bc4062e294ab..fca6a9c14243 100644
--- a/backend/danswer/llm/factory.py
+++ b/backend/danswer/llm/factory.py
@@ -1,9 +1,11 @@
+from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.chat_configs import QA_TIMEOUT
 from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.llm.chat_llm import DefaultMultiLLM
 from danswer.llm.custom_llm import CustomModelServer
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.gpt_4_all import DanswerGPT4All
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import get_gen_ai_api_key
@@ -18,6 +20,9 @@ def get_default_llm(
 ) -> LLM:
     """A single place to fetch the configured LLM for Danswer
     Also allows overriding certain LLM defaults"""
+    if DISABLE_GENERATIVE_AI:
+        raise GenAIDisabledException()
+
     if gen_ai_model_version_override:
         model_version = gen_ai_model_version_override
     else:
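Centralizing the check in `get_default_llm()` means every LLM consumer inherits the kill switch without re-reading the flag, and the factory's return type can stay `LLM` rather than `LLM | None`, so each call site is pushed into an explicit `except` branch with its own fallback. A sketch of why the raising style is easier on callers and the type checker than a `None` return (hypothetical minimal types, not the Danswer ones):

```python
class LLM:  # minimal stand-in for danswer.llm.interfaces.LLM
    def invoke(self, prompt: str) -> str:
        return "model output for " + prompt

class GenAIDisabledException(Exception):
    pass

DISABLE_GENERATIVE_AI = True  # pretend the admin turned Gen AI off

def get_default_llm() -> LLM:  # return type stays honest: never None
    if DISABLE_GENERATIVE_AI:
        raise GenAIDisabledException()
    return LLM()

def some_secondary_flow(query: str) -> str:
    try:
        llm = get_default_llm()
    except GenAIDisabledException:
        return query  # each caller picks a flow-specific fallback
    return llm.invoke(query)

assert some_secondary_flow("hello") == "hello"
```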
diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index aeaa028b2249..342d589d4f6b 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -231,6 +231,9 @@ def get_application() -> FastAPI:
     if GEN_AI_API_ENDPOINT:
         logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
 
+    # Any additional model configs logged here
+    get_default_llm().log_model_configs()
+
     if MULTILINGUAL_QUERY_EXPANSION:
         logger.info(
             f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
@@ -258,9 +261,6 @@ def get_application() -> FastAPI:
         logger.info("GPU is not available")
     logger.info(f"Torch Threads: {torch.get_num_threads()}")
 
-    # This is for the LLM, most LLMs will not need warming up
-    get_default_llm().log_model_configs()
-
     logger.info("Verifying query preprocessing (NLTK) data is downloaded")
     nltk.download("stopwords", quiet=True)
     nltk.download("wordnet", quiet=True)
diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py
index 94a2921c6c23..1ebbd2980a03 100644
--- a/backend/danswer/one_shot_answer/answer_question.py
+++ b/backend/danswer/one_shot_answer/answer_question.py
@@ -29,6 +29,7 @@ from danswer.one_shot_answer.factory import get_question_answer_model
 from danswer.one_shot_answer.models import DirectQARequest
 from danswer.one_shot_answer.models import OneShotQAResponse
 from danswer.one_shot_answer.models import QueryRephrase
+from danswer.one_shot_answer.qa_block import no_gen_ai_response
 from danswer.one_shot_answer.qa_utils import combine_message_thread
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
@@ -191,8 +192,12 @@ def stream_answer_objects(
         llm_version=llm_override,
     )
 
-    full_prompt_str = qa_model.build_prompt(
-        query=query_msg.message, history_str=history_str, context_chunks=llm_chunks
+    full_prompt_str = (
+        qa_model.build_prompt(
+            query=query_msg.message, history_str=history_str, context_chunks=llm_chunks
+        )
+        if qa_model is not None
+        else "Gen AI Disabled"
     )
 
     # Create the first User query message
@@ -207,10 +212,14 @@
         commit=True,
     )
 
-    response_packets = qa_model.answer_question_stream(
-        prompt=full_prompt_str,
-        llm_context_docs=llm_chunks,
-        metrics_callback=llm_metrics_callback,
+    response_packets = (
+        qa_model.answer_question_stream(
+            prompt=full_prompt_str,
+            llm_context_docs=llm_chunks,
+            metrics_callback=llm_metrics_callback,
+        )
+        if qa_model is not None
+        else no_gen_ai_response()
     )
 
     # Capture outputs and errors
diff --git a/backend/danswer/one_shot_answer/factory.py b/backend/danswer/one_shot_answer/factory.py
index 47be1fd25e36..122ed6ac06fb 100644
--- a/backend/danswer/one_shot_answer/factory.py
+++ b/backend/danswer/one_shot_answer/factory.py
@@ -1,6 +1,7 @@
 from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
 from danswer.configs.chat_configs import QA_TIMEOUT
 from danswer.db.models import Prompt
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.one_shot_answer.interfaces import QAModel
 from danswer.one_shot_answer.qa_block import QABlock
@@ -19,18 +20,21 @@ def get_question_answer_model(
     chain_of_thought: bool = False,
     llm_version: str | None = None,
     qa_model_version: str | None = QA_PROMPT_OVERRIDE,
-) -> QAModel:
+) -> QAModel | None:
     if chain_of_thought:
         raise NotImplementedError("COT has been disabled")
 
     system_prompt = prompt.system_prompt if prompt is not None else None
     task_prompt = prompt.task_prompt if prompt is not None else None
 
-    llm = get_default_llm(
-        api_key=api_key,
-        timeout=timeout,
-        gen_ai_model_version_override=llm_version,
-    )
+    try:
+        llm = get_default_llm(
+            api_key=api_key,
+            timeout=timeout,
+            gen_ai_model_version_override=llm_version,
+        )
+    except GenAIDisabledException:
+        return None
 
     if qa_model_version == "weak":
         qa_handler: QAHandler = WeakLLMQAHandler(
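In the one-shot flow the factory returns `None` instead of raising, and `stream_answer_objects` threads that `None` through two conditional expressions: the prompt string falls back to the placeholder "Gen AI Disabled" (seemingly so the recorded user message still has a prompt value), and the response packets come from the `no_gen_ai_response()` generator defined in qa_block.py below. A compact sketch of the same None-propagation, using stand-in types rather than the real `QAModel` interface:

```python
from collections.abc import Iterator

DISABLED_GEN_AI_MSG = "Gen AI is turned off"

class QAModel:  # stand-in for the real interface
    def build_prompt(self, query: str) -> str:
        return "PROMPT: " + query
    def answer_question_stream(self, prompt: str) -> Iterator[str]:
        yield "answer for " + prompt

def no_gen_ai_response() -> Iterator[str]:
    yield DISABLED_GEN_AI_MSG

def stream_answer_objects(qa_model: QAModel | None, query: str) -> Iterator[str]:
    # Same conditional-expression pattern as the diff above
    full_prompt_str = (
        qa_model.build_prompt(query) if qa_model is not None else "Gen AI Disabled"
    )
    yield from (
        qa_model.answer_question_stream(full_prompt_str)
        if qa_model is not None
        else no_gen_ai_response()
    )

print(list(stream_answer_objects(None, "why is the sky blue?")))
# ['Gen AI is turned off']
```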
diff --git a/backend/danswer/one_shot_answer/qa_block.py b/backend/danswer/one_shot_answer/qa_block.py
index 455b23cb126b..3a9bfdf0339c 100644
--- a/backend/danswer/one_shot_answer/qa_block.py
+++ b/backend/danswer/one_shot_answer/qa_block.py
@@ -13,6 +13,7 @@ from danswer.chat.models import LlmDoc
 from danswer.chat.models import LLMMetricsContainer
 from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
+from danswer.configs.constants import DISABLED_GEN_AI_MSG
 from danswer.indexing.models import InferenceChunk
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import check_number_of_tokens
@@ -252,6 +253,10 @@ def build_dummy_prompt(
     ).strip()
 
 
+def no_gen_ai_response() -> Iterator[DanswerAnswerPiece]:
+    yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
+
+
 class QABlock(QAModel):
     def __init__(self, llm: LLM, qa_handler: QAHandler) -> None:
         self._llm = llm
diff --git a/backend/danswer/secondary_llm_flows/answer_validation.py b/backend/danswer/secondary_llm_flows/answer_validation.py
index 4bcd9e9cff40..88a153da4174 100644
--- a/backend/danswer/secondary_llm_flows/answer_validation.py
+++ b/backend/danswer/secondary_llm_flows/answer_validation.py
@@ -1,3 +1,4 @@
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.answer_validation import ANSWER_VALIDITY_PROMPT
@@ -41,12 +42,17 @@ def get_answer_validity(
             return False
         return True  # If something is wrong, let's not toss away the answer
 
+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        return True
+
     if not answer:
         return False
 
     messages = _get_answer_validation_messages(query, answer)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)
 
     validity = _extract_validity(model_output)
diff --git a/backend/danswer/secondary_llm_flows/chat_session_naming.py b/backend/danswer/secondary_llm_flows/chat_session_naming.py
index c54a2afefceb..e65c8dd36b9b 100644
--- a/backend/danswer/secondary_llm_flows/chat_session_naming.py
+++ b/backend/danswer/secondary_llm_flows/chat_session_naming.py
@@ -1,5 +1,6 @@
 from danswer.chat.chat_utils import combine_message_chain
 from danswer.db.models import ChatMessage
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
@@ -23,7 +24,12 @@ def get_renamed_conversation_name(
         return messages
 
     if llm is None:
-        llm = get_default_llm()
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # This may be longer than what the LLM tends to produce, but it is
+            # the clearest fallback we have
+            return full_history[0].message
 
     history_str = combine_message_chain(full_history)
 
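`no_gen_ai_response()` matters because callers consume answers as iterators of packets; by yielding the disabled-notice as an ordinary answer piece, the downstream streaming code (citation handling, JSON lines, Slack blocks) needs no special case. The session-naming fallback follows the same philosophy of substituting a value that is already at hand. A sketch with simplified stand-ins for the real types:

```python
from collections.abc import Iterator

DISABLED_GEN_AI_MSG = "Generative AI has been disabled by your admin."

def no_gen_ai_response() -> Iterator[str]:
    # Same shape as a real answer stream: one piece, then exhaustion
    yield DISABLED_GEN_AI_MSG

def consume(packets: Iterator[str]) -> str:
    # Downstream code is identical whether packets came from an LLM or not
    return "".join(packets)

def rename_chat_session(first_user_message: str, gen_ai_enabled: bool) -> str:
    if not gen_ai_enabled:
        # Fallback: reuse the user's first message as the session name;
        # possibly longer than an LLM-written title, but always meaningful
        return first_user_message
    return "LLM-generated title"  # placeholder for the real LLM call

assert consume(no_gen_ai_response()) == DISABLED_GEN_AI_MSG
assert rename_chat_session("How do I rotate my API key?", False).startswith("How")
```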
diff --git a/backend/danswer/secondary_llm_flows/choose_search.py b/backend/danswer/secondary_llm_flows/choose_search.py
index d0bbfeea5f54..626b10775b18 100644
--- a/backend/danswer/secondary_llm_flows/choose_search.py
+++ b/backend/danswer/secondary_llm_flows/choose_search.py
@@ -5,6 +5,7 @@ from langchain.schema import SystemMessage
 from danswer.chat.chat_utils import combine_message_chain
 from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
 from danswer.db.models import ChatMessage
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
@@ -68,15 +69,20 @@ def check_if_need_search(
     if disable_llm_check:
         return True
 
+    if llm is None:
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # If Generative AI is turned off, always run Search, since Danswer
+            # is being used as just a search engine
+            return True
+
     history_str = combine_message_chain(history)
 
     prompt_msgs = _get_search_messages(
         question=query_message.message, history_str=history_str
     )
 
-    if llm is None:
-        llm = get_default_llm()
-
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
 
     require_search_output = llm.invoke(filled_llm_prompt)
diff --git a/backend/danswer/secondary_llm_flows/chunk_usefulness.py b/backend/danswer/secondary_llm_flows/chunk_usefulness.py
index b977947bf44d..2db06bdbafe2 100644
--- a/backend/danswer/secondary_llm_flows/chunk_usefulness.py
+++ b/backend/danswer/secondary_llm_flows/chunk_usefulness.py
@@ -1,5 +1,6 @@
 from collections.abc import Callable
 
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.llm_chunk_filter import CHUNK_FILTER_PROMPT
@@ -30,14 +31,20 @@ def llm_eval_chunk(query: str, chunk_content: str) -> bool:
             return False
         return True
 
+    # If Gen AI is disabled, no chunk is considered more "useful" than any other.
+    # All are marked not useful (False) so that the "Gen AI liked this answer"
+    # icon is not shown for any result
+    try:
+        llm = get_default_llm(use_fast_llm=True, timeout=5)
+    except GenAIDisabledException:
+        return False
+
     messages = _get_usefulness_messages()
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
     # When running in a batch, it takes as long as the longest thread
     # And when running a large batch, one may fail and take the whole timeout
     # instead cap it to 5 seconds
-    model_output = get_default_llm(use_fast_llm=True, timeout=5).invoke(
-        filled_llm_prompt
-    )
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)
 
     return _extract_usefulness(model_output)
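The secondary flows above choose their fallbacks deliberately rather than uniformly: `check_if_need_search` fails open to `True` (a tool acting as a search engine should always search), while `llm_eval_chunk` fails closed to `False` (no chunk should carry an AI-useful badge when no AI actually judged it). The contrast in one runnable sketch, with hypothetical minimal versions of the two flows:

```python
class GenAIDisabledException(Exception):
    pass

def get_default_llm(**kwargs: object) -> object:
    raise GenAIDisabledException()  # pretend the admin turned Gen AI off

def check_if_need_search(query: str) -> bool:
    try:
        llm = get_default_llm()
    except GenAIDisabledException:
        return True   # fail open: Danswer is acting as a plain search engine
    ...  # real code would ask the LLM whether retrieval is needed

def llm_eval_chunk(query: str, chunk: str) -> bool:
    try:
        llm = get_default_llm(use_fast_llm=True, timeout=5)
    except GenAIDisabledException:
        return False  # fail closed: no "useful" icon without an actual judgment
    ...  # real code would ask the LLM to rate the chunk

assert check_if_need_search("anything") is True
assert llm_eval_chunk("anything", "some chunk") is False
```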
diff --git a/backend/danswer/secondary_llm_flows/query_expansion.py b/backend/danswer/secondary_llm_flows/query_expansion.py
index 1be27d545478..9a494607460f 100644
--- a/backend/danswer/secondary_llm_flows/query_expansion.py
+++ b/backend/danswer/secondary_llm_flows/query_expansion.py
@@ -3,6 +3,7 @@ from typing import cast
 
 from danswer.chat.chat_utils import combine_message_chain
 from danswer.db.models import ChatMessage
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
@@ -28,9 +29,17 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
 
         return messages
 
+    try:
+        llm = get_default_llm(use_fast_llm=True, timeout=5)
+    except GenAIDisabledException:
+        logger.warning(
+            "Unable to perform multilingual query expansion, Gen AI disabled"
+        )
+        return query
+
     messages = _get_rephrase_messages()
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)
 
     return model_output
@@ -81,12 +90,25 @@ def history_based_query_rephrase(
     llm: LLM | None = None,
     size_heuristic: int = 200,
     punctuation_heuristic: int = 10,
+    skip_first_rephrase: bool = False,
 ) -> str:
     user_query = cast(str, query_message.message)
 
     if not user_query:
         raise ValueError("Can't rephrase/search an empty query")
 
+    if llm is None:
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # If Generative AI is turned off, just return the original query
+            return user_query
+
+    # For some use cases, the first query should be left untouched. Later queries
+    # must be rephrased because they need conversation context, which the first query lacks.
+    if skip_first_rephrase and not history:
+        return user_query
+
     # If it's a very large query, assume it's a copy paste which we may want to find exactly
     # or at least very closely, so don't rephrase it
     if len(user_query) >= size_heuristic:
@@ -103,9 +125,6 @@ def history_based_query_rephrase(
         question=user_query, history_str=history_str
     )
 
-    if llm is None:
-        llm = get_default_llm()
-
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
 
     rephrased_query = llm.invoke(filled_llm_prompt)
@@ -130,13 +149,17 @@ def thread_based_query_rephrase(
     if count_punctuation(user_query) >= punctuation_heuristic:
         return user_query
 
+    if llm is None:
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # If Generative AI is turned off, just return the original query
+            return user_query
+
     prompt_msgs = get_contextual_rephrase_messages(
         question=user_query, history_str=history_str
     )
 
-    if llm is None:
-        llm = get_default_llm()
-
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
 
     rephrased_query = llm.invoke(filled_llm_prompt)
diff --git a/backend/danswer/secondary_llm_flows/query_validation.py b/backend/danswer/secondary_llm_flows/query_validation.py
index 5f5aa1402198..22ba49e68cd4 100644
--- a/backend/danswer/secondary_llm_flows/query_validation.py
+++ b/backend/danswer/secondary_llm_flows/query_validation.py
@@ -4,6 +4,7 @@ from collections.abc import Iterator
 from danswer.chat.models import DanswerAnswerPiece
 from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import DISABLE_LLM_QUERY_ANSWERABILITY
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.constants import ANSWERABLE_PAT
@@ -48,9 +49,14 @@ def get_query_answerability(
     if skip_check:
         return "Query Answerability Evaluation feature is turned off", True
 
+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        return "Generative AI is turned off - skipping check", True
+
     messages = get_query_validation_messages(user_query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)
 
     reasoning = extract_answerability_reasoning(model_output)
     answerable = extract_answerability_bool(model_output)
@@ -70,10 +76,21 @@ def stream_query_answerability(
         )
         return
 
+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        yield get_json_line(
+            QueryValidationResponse(
+                reasoning="Generative AI is turned off - skipping check",
+                answerable=True,
+            ).dict()
+        )
+        return
+
     messages = get_query_validation_messages(user_query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
     try:
-        tokens = get_default_llm().stream(filled_llm_prompt)
+        tokens = llm.stream(filled_llm_prompt)
         reasoning_pat_found = False
         model_output = ""
         hold_answerable = ""
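`history_based_query_rephrase` now has four ways to return the user's query untouched: Gen AI disabled, a history-less first message when the new `skip_first_rephrase` flag is set (a drive-by addition in this diff), an over-long copy-paste, and punctuation-heavy text. Roughly, as a sketch with the same guard order and illustrative thresholds rather than the real signature:

```python
import string

def count_punctuation(text: str) -> int:
    return sum(1 for ch in text if ch in string.punctuation)

def history_based_query_rephrase(
    user_query: str,
    history: list[str],
    gen_ai_enabled: bool,
    skip_first_rephrase: bool = False,
    size_heuristic: int = 200,
    punctuation_heuristic: int = 10,
) -> str:
    if not gen_ai_enabled:
        return user_query  # no LLM, no rephrase
    if skip_first_rephrase and not history:
        return user_query  # first query has no conversation context to fold in
    if len(user_query) >= size_heuristic:
        return user_query  # likely a copy-paste the user wants matched closely
    if count_punctuation(user_query) >= punctuation_heuristic:
        return user_query
    return "LLM-rephrased: " + user_query  # placeholder for the real LLM call

assert history_based_query_rephrase("first question", [], True, True) == "first question"
```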
diff --git a/backend/danswer/secondary_llm_flows/source_filter.py b/backend/danswer/secondary_llm_flows/source_filter.py
index 5abd7bb388f6..969bd92829ed 100644
--- a/backend/danswer/secondary_llm_flows/source_filter.py
+++ b/backend/danswer/secondary_llm_flows/source_filter.py
@@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
 from danswer.configs.constants import DocumentSource
 from danswer.db.connector import fetch_unique_document_sources
 from danswer.db.engine import get_sqlalchemy_engine
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.constants import SOURCES_KEY
@@ -145,13 +146,18 @@ def extract_source_filter(
         logger.warning("LLM failed to provide a valid Source Filter output")
         return None
 
+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        return None
+
     valid_sources = fetch_unique_document_sources(db_session)
     if not valid_sources:
         return None
 
     messages = _get_source_filter_messages(query=query, valid_sources=valid_sources)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)
 
     return _extract_source_filters_from_llm_out(model_output)
diff --git a/backend/danswer/secondary_llm_flows/time_filter.py b/backend/danswer/secondary_llm_flows/time_filter.py
index d68cb20f5b47..be2799f8f4e3 100644
--- a/backend/danswer/secondary_llm_flows/time_filter.py
+++ b/backend/danswer/secondary_llm_flows/time_filter.py
@@ -5,6 +5,7 @@ from datetime import timezone
 
 from dateutil.parser import parse
 
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.filter_extration import TIME_FILTER_PROMPT
@@ -145,9 +146,14 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
 
         return None, False
 
+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        return None, False
+
     messages = _get_time_filter_messages(query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)
 
     return _extract_time_filter_from_llm_out(model_output)
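Both filter extractors treat "Gen AI off" the same as "the LLM could not produce a usable filter": `extract_source_filter` returns `None` and `extract_time_filter` returns `(None, False)`, i.e. no source restriction, no time cutoff, and no recency preference, so retrieval proceeds unfiltered. A sketch of that contract (the always-on flag here stands in for the caught exception):

```python
from datetime import datetime

GEN_AI_OFF = True  # stands in for catching GenAIDisabledException

def extract_source_filter(query: str) -> list[str] | None:
    if GEN_AI_OFF:
        return None  # None means "do not restrict by source"
    ...  # real code would ask the LLM which sources the query implies

def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
    if GEN_AI_OFF:
        return None, False  # (no cutoff date, do not favor recent docs)
    ...  # real code would ask the LLM for a date range

assert extract_source_filter("slack messages about billing") is None
assert extract_time_filter("docs from last week") == (None, False)
```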
diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py
index b0c09e6f6480..4fbe07256542 100644
--- a/backend/danswer/server/manage/administrative.py
+++ b/backend/danswer/server/manage/administrative.py
@@ -9,7 +9,6 @@ from fastapi import HTTPException
 from sqlalchemy.orm import Session
 
 from danswer.auth.users import current_admin_user
-from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
 from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
 from danswer.db.connector_credential_pair import get_connector_credential_pair
@@ -22,6 +21,7 @@ from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
 from danswer.dynamic_configs import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import get_gen_ai_api_key
 from danswer.llm.utils import test_llm
@@ -100,9 +100,6 @@ def document_hidden_update(
 def validate_existing_genai_api_key(
     _: User = Depends(current_admin_user),
 ) -> None:
-    if DISABLE_GENERATIVE_AI:
-        return
-
     # Only validate every so often
     check_key_time = "genai_api_key_last_check_time"
     kv_store = get_dynamic_config_store()
@@ -120,7 +117,11 @@
 
     genai_api_key = get_gen_ai_api_key()
 
-    llm = get_default_llm(api_key=genai_api_key, timeout=10)
+    try:
+        llm = get_default_llm(api_key=genai_api_key, timeout=10)
+    except GenAIDisabledException:
+        return
+
     is_valid = test_llm(llm)
 
     if not is_valid:
@@ -165,6 +166,9 @@ def store_genai_api_key(
         if not is_valid:
             raise HTTPException(400, "Invalid API key provided")
 
+        get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
+    except GenAIDisabledException:
+        # If Disable Generative AI is set, no need to verify; just store the key for later use
         get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
     except RuntimeError as e:
         raise HTTPException(400, str(e))
diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml
index 0b7a730c9efb..94d59132871e 100644
--- a/deployment/docker_compose/docker-compose.dev.yml
+++ b/deployment/docker_compose/docker-compose.dev.yml
@@ -40,6 +40,7 @@ services:
       - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
       - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
       - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
+      - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
       # Query Options
       - DOC_TIME_DECAY=${DOC_TIME_DECAY:-}  # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
       - HYBRID_ALPHA=${HYBRID_ALPHA:-}  # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
@@ -96,6 +97,7 @@
       - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
       - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
       - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
+      - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
       # Query Options
       - DOC_TIME_DECAY=${DOC_TIME_DECAY:-}  # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
       - HYBRID_ALPHA=${HYBRID_ALPHA:-}  # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
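Finally, the flag is plumbed through both services in the compose file so it can be set from the host environment. Note how app_configs.py (the first file in this diff) parses it: only the literal string "true", case-insensitively, enables the kill switch; values like "1" or "yes" silently leave Gen AI on. A quick check of that parsing expression:

```python
import os

def parse_flag(raw: str) -> bool:
    # Same expression as app_configs.py:
    # os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
    return raw.lower() == "true"

assert parse_flag("true") and parse_flag("TRUE")
assert not parse_flag("1") and not parse_flag("yes") and not parse_flag("")

os.environ["DISABLE_GENERATIVE_AI"] = "true"  # e.g. exported before `docker compose up`
assert os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
```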