Reenable option to run Danswer without Gen AI (#906)

Authored by Yuhong Sun on 2024-01-03 18:31:16 -08:00, committed by GitHub
parent 20441df4a4
commit 6b6b3daab7
20 changed files with 181 additions and 43 deletions

View File

@@ -21,6 +21,7 @@ from danswer.chat.models import QADocsResponse
 from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import CHUNK_SIZE
 from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
+from danswer.configs.constants import DISABLED_GEN_AI_MSG
 from danswer.configs.constants import MessageType
 from danswer.db.chat import create_db_search_doc
 from danswer.db.chat import create_new_chat_message
@@ -36,6 +37,7 @@ from danswer.db.models import SearchDoc as DbSearchDoc
 from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
 from danswer.indexing.models import InferenceChunk
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import get_default_llm_token_encode
@@ -61,10 +63,18 @@ def generate_ai_chat_response(
     history: list[ChatMessage],
     context_docs: list[LlmDoc],
     doc_id_to_rank_map: dict[str, int],
-    llm: LLM,
+    llm: LLM | None,
     llm_tokenizer: Callable,
     all_doc_useful: bool,
 ) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]:
+    if llm is None:
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # Not an error if it's a user configuration
+            yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
+            return
+
     if query_message.prompt is None:
         raise RuntimeError("No prompt received for generating Gen AI answer.")
@@ -171,7 +181,11 @@ def stream_chat_message(
             "Must specify a set of documents for chat or specify search options"
         )

-    llm = get_default_llm()
+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        llm = None
+
     llm_tokenizer = get_default_llm_token_encode()
     document_index = get_default_document_index()

View File

@@ -20,7 +20,6 @@ APP_API_PREFIX = os.environ.get("API_PREFIX", "")
 #####
 BLURB_SIZE = 128  # Number Encoder Tokens included in the chunk blurb
 GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400  # 1 day
-# CURRENTLY DOES NOT FULLY WORK, DON'T USE THIS
 DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"

View File

@@ -47,6 +47,14 @@ SECTION_SEPARATOR = "\n\n"
 INDEX_SEPARATOR = "==="

+# Messages
+DISABLED_GEN_AI_MSG = (
+    "Your System Admin has disabled the Generative AI functionalities of Danswer.\n"
+    "Please contact them if you wish to have this enabled.\n"
+    "You can still use Danswer as a search engine."
+)
+
+
 class DocumentSource(str, Enum):
     # Special case, document passed in via Danswer APIs without specifying a source type
     INGESTION_API = "ingestion_api"

View File

@@ -12,6 +12,7 @@ from slack_sdk.models.blocks import RadioButtonsElement
 from slack_sdk.models.blocks import SectionBlock

 from danswer.chat.models import DanswerQuote
+from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import SearchFeedbackType
 from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
@@ -106,8 +107,11 @@ def build_documents_blocks(
     message_id: int | None,
     num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
 ) -> list[Block]:
+    header_text = (
+        "Retrieved Documents" if DISABLE_GENERATIVE_AI else "Reference Documents"
+    )
     seen_docs_identifiers = set()
-    section_blocks: list[Block] = [HeaderBlock(text="Reference Documents")]
+    section_blocks: list[Block] = [HeaderBlock(text=header_text)]
     included_docs = 0
     for rank, d in enumerate(documents):
         if d.document_id in seen_docs_identifiers:
@@ -208,6 +212,9 @@ def build_qa_response_blocks(
     favor_recent: bool,
     skip_quotes: bool = False,
 ) -> list[Block]:
+    if DISABLE_GENERATIVE_AI:
+        return []
+
     quotes_blocks: list[Block] = []

     ai_answer_header = HeaderBlock(text="AI Answer")

View File

@@ -0,0 +1,4 @@
+class GenAIDisabledException(Exception):
+    def __init__(self, message: str = "Generative AI has been turned off") -> None:
+        self.message = message
+        super().__init__(self.message)
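
Note: this new exception is the hinge of the whole change. get_default_llm() (see the factory diff below) raises it whenever the admin has disabled Gen AI, and every call site catches it and degrades to search-only behavior. A minimal sketch of that caller-side pattern, assuming only the imports shown in these diffs (the helper function itself is hypothetical):

from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm


def answer_or_fallback(prompt: str, fallback: str) -> str:
    # Hypothetical helper: degrade gracefully instead of erroring out
    # when the System Admin has turned Generative AI off.
    try:
        llm = get_default_llm()
    except GenAIDisabledException:
        return fallback
    return llm.invoke(prompt)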

View File

@@ -1,9 +1,11 @@
+from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.chat_configs import QA_TIMEOUT
 from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.llm.chat_llm import DefaultMultiLLM
 from danswer.llm.custom_llm import CustomModelServer
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.gpt_4_all import DanswerGPT4All
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import get_gen_ai_api_key
@@ -18,6 +20,9 @@ def get_default_llm(
 ) -> LLM:
     """A single place to fetch the configured LLM for Danswer
     Also allows overriding certain LLM defaults"""
+    if DISABLE_GENERATIVE_AI:
+        raise GenAIDisabledException()
+
     if gen_ai_model_version_override:
         model_version = gen_ai_model_version_override
     else:

View File

@@ -231,6 +231,9 @@ def get_application() -> FastAPI:
         if GEN_AI_API_ENDPOINT:
             logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")

+        # Any additional model configs logged here
+        get_default_llm().log_model_configs()
+
         if MULTILINGUAL_QUERY_EXPANSION:
             logger.info(
                 f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
@@ -258,9 +261,6 @@ def get_application() -> FastAPI:
             logger.info("GPU is not available")
         logger.info(f"Torch Threads: {torch.get_num_threads()}")

-        # This is for the LLM, most LLMs will not need warming up
-        get_default_llm().log_model_configs()
-
         logger.info("Verifying query preprocessing (NLTK) data is downloaded")
         nltk.download("stopwords", quiet=True)
         nltk.download("wordnet", quiet=True)

View File

@@ -29,6 +29,7 @@ from danswer.one_shot_answer.factory import get_question_answer_model
 from danswer.one_shot_answer.models import DirectQARequest
 from danswer.one_shot_answer.models import OneShotQAResponse
 from danswer.one_shot_answer.models import QueryRephrase
+from danswer.one_shot_answer.qa_block import no_gen_ai_response
 from danswer.one_shot_answer.qa_utils import combine_message_thread
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
@@ -191,8 +192,12 @@ def stream_answer_objects(
         llm_version=llm_override,
     )

-    full_prompt_str = qa_model.build_prompt(
-        query=query_msg.message, history_str=history_str, context_chunks=llm_chunks
+    full_prompt_str = (
+        qa_model.build_prompt(
+            query=query_msg.message, history_str=history_str, context_chunks=llm_chunks
+        )
+        if qa_model is not None
+        else "Gen AI Disabled"
     )

     # Create the first User query message
@@ -207,10 +212,14 @@ def stream_answer_objects(
         commit=True,
     )

-    response_packets = qa_model.answer_question_stream(
-        prompt=full_prompt_str,
-        llm_context_docs=llm_chunks,
-        metrics_callback=llm_metrics_callback,
+    response_packets = (
+        qa_model.answer_question_stream(
+            prompt=full_prompt_str,
+            llm_context_docs=llm_chunks,
+            metrics_callback=llm_metrics_callback,
+        )
+        if qa_model is not None
+        else no_gen_ai_response()
     )

     # Capture outputs and errors

View File

@@ -1,6 +1,7 @@
 from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
 from danswer.configs.chat_configs import QA_TIMEOUT
 from danswer.db.models import Prompt
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.one_shot_answer.interfaces import QAModel
 from danswer.one_shot_answer.qa_block import QABlock
@@ -19,18 +20,21 @@ def get_question_answer_model(
     chain_of_thought: bool = False,
     llm_version: str | None = None,
     qa_model_version: str | None = QA_PROMPT_OVERRIDE,
-) -> QAModel:
+) -> QAModel | None:
     if chain_of_thought:
         raise NotImplementedError("COT has been disabled")

     system_prompt = prompt.system_prompt if prompt is not None else None
     task_prompt = prompt.task_prompt if prompt is not None else None

-    llm = get_default_llm(
-        api_key=api_key,
-        timeout=timeout,
-        gen_ai_model_version_override=llm_version,
-    )
+    try:
+        llm = get_default_llm(
+            api_key=api_key,
+            timeout=timeout,
+            gen_ai_model_version_override=llm_version,
+        )
+    except GenAIDisabledException:
+        return None

     if qa_model_version == "weak":
         qa_handler: QAHandler = WeakLLMQAHandler(

View File

@@ -13,6 +13,7 @@ from danswer.chat.models import LlmDoc
 from danswer.chat.models import LLMMetricsContainer
 from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
+from danswer.configs.constants import DISABLED_GEN_AI_MSG
 from danswer.indexing.models import InferenceChunk
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import check_number_of_tokens
@@ -252,6 +253,10 @@ def build_dummy_prompt(
     ).strip()


+def no_gen_ai_response() -> Iterator[DanswerAnswerPiece]:
+    yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
+
+
 class QABlock(QAModel):
     def __init__(self, llm: LLM, qa_handler: QAHandler) -> None:
         self._llm = llm
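
Note: because no_gen_ai_response() is a generator, it can stand in wherever a real token stream is expected, which is exactly how stream_answer_objects uses it in the earlier hunk. An illustrative consumption loop (not part of the commit):

from danswer.one_shot_answer.qa_block import no_gen_ai_response

# The stand-in stream yields a single DanswerAnswerPiece carrying
# DISABLED_GEN_AI_MSG and is iterated like any real answer stream.
for packet in no_gen_ai_response():
    print(packet.answer_piece)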

View File

@@ -1,3 +1,4 @@
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.answer_validation import ANSWER_VALIDITY_PROMPT
@@ -41,12 +42,17 @@ def get_answer_validity(
             return False
         return True  # If something is wrong, let's not toss away the answer

+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        return True
+
     if not answer:
         return False

     messages = _get_answer_validation_messages(query, answer)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)

     validity = _extract_validity(model_output)

View File

@@ -1,5 +1,6 @@
 from danswer.chat.chat_utils import combine_message_chain
 from danswer.db.models import ChatMessage
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
@@ -23,7 +24,12 @@ def get_renamed_conversation_name(
         return messages

     if llm is None:
-        llm = get_default_llm()
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # This may be longer than what the LLM tends to produce but is the most
+            # clear thing we can do
+            return full_history[0].message

     history_str = combine_message_chain(full_history)

View File

@@ -5,6 +5,7 @@ from langchain.schema import SystemMessage
 from danswer.chat.chat_utils import combine_message_chain
 from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
 from danswer.db.models import ChatMessage
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
@@ -68,15 +69,20 @@ def check_if_need_search(
     if disable_llm_check:
         return True

+    if llm is None:
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # If Generative AI is turned off, always run Search, as Danswer is
+            # being used as just a search engine
+            return True
+
     history_str = combine_message_chain(history)

     prompt_msgs = _get_search_messages(
         question=query_message.message, history_str=history_str
     )

-    if llm is None:
-        llm = get_default_llm()
-
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
     require_search_output = llm.invoke(filled_llm_prompt)

View File

@@ -1,5 +1,6 @@
 from collections.abc import Callable

+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.llm_chunk_filter import CHUNK_FILTER_PROMPT
@@ -30,14 +31,20 @@ def llm_eval_chunk(query: str, chunk_content: str) -> bool:
             return False
         return True

+    # If Gen AI is disabled, none of the messages are more "useful" than any other
+    # All are marked not useful (False) so that the "Gen AI likes this answer" icon
+    # is not shown for any result
+    try:
+        llm = get_default_llm(use_fast_llm=True, timeout=5)
+    except GenAIDisabledException:
+        return False
+
     messages = _get_usefulness_messages()
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
     # When running in a batch, it takes as long as the longest thread
     # And when running a large batch, one may fail and take the whole timeout
     # instead cap it to 5 seconds
-    model_output = get_default_llm(use_fast_llm=True, timeout=5).invoke(
-        filled_llm_prompt
-    )
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)

     return _extract_usefulness(model_output)

View File

@@ -3,6 +3,7 @@ from typing import cast

 from danswer.chat.chat_utils import combine_message_chain
 from danswer.db.models import ChatMessage
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
@@ -28,9 +29,17 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
         return messages

+    try:
+        llm = get_default_llm(use_fast_llm=True, timeout=5)
+    except GenAIDisabledException:
+        logger.warning(
+            "Unable to perform multilingual query expansion, Gen AI disabled"
+        )
+        return query
+
     messages = _get_rephrase_messages()
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)

     return model_output
@@ -81,12 +90,25 @@ def history_based_query_rephrase(
     llm: LLM | None = None,
     size_heuristic: int = 200,
     punctuation_heuristic: int = 10,
+    skip_first_rephrase: bool = False,
 ) -> str:
     user_query = cast(str, query_message.message)

     if not user_query:
         raise ValueError("Can't rephrase/search an empty query")

+    if llm is None:
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # If Generative AI is turned off, just return the original query
+            return user_query
+
+    # For some use cases, the first query should be untouched. Later queries must be
+    # rephrased due to needing context, but the first query has no context.
+    if skip_first_rephrase and not history:
+        return user_query
+
     # If it's a very large query, assume it's a copy paste which we may want to find exactly
     # or at least very closely, so don't rephrase it
     if len(user_query) >= size_heuristic:
@@ -103,9 +125,6 @@ def history_based_query_rephrase(
         question=user_query, history_str=history_str
     )

-    if llm is None:
-        llm = get_default_llm()
-
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
     rephrased_query = llm.invoke(filled_llm_prompt)
@@ -130,13 +149,17 @@ def thread_based_query_rephrase(
     if count_punctuation(user_query) >= punctuation_heuristic:
         return user_query

+    if llm is None:
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # If Generative AI is turned off, just return the original query
+            return user_query
+
     prompt_msgs = get_contextual_rephrase_messages(
         question=user_query, history_str=history_str
     )

-    if llm is None:
-        llm = get_default_llm()
-
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
     rephrased_query = llm.invoke(filled_llm_prompt)
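
Note: the new skip_first_rephrase flag lets a caller leave the opening query of a thread untouched while follow-ups, which need context, are still rephrased. A usage sketch; the wrapper function and the import path are assumptions inferred from these hunks, not code from the commit:

from danswer.db.models import ChatMessage
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase


def rephrase_for_search(msg: ChatMessage, history: list[ChatMessage]) -> str:
    # With an empty history the query passes through unchanged; once
    # history exists, it is rephrased using that context.
    return history_based_query_rephrase(
        query_message=msg,
        history=history,
        skip_first_rephrase=True,
    )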

View File

@@ -4,6 +4,7 @@ from collections.abc import Iterator

 from danswer.chat.models import DanswerAnswerPiece
 from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import DISABLE_LLM_QUERY_ANSWERABILITY
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.constants import ANSWERABLE_PAT
@@ -48,9 +49,14 @@ def get_query_answerability(
     if skip_check:
         return "Query Answerability Evaluation feature is turned off", True

+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        return "Generative AI is turned off - skipping check", True
+
     messages = get_query_validation_messages(user_query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)

     reasoning = extract_answerability_reasoning(model_output)
     answerable = extract_answerability_bool(model_output)
@@ -70,10 +76,21 @@ def stream_query_answerability(
             )
             return

+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        yield get_json_line(
+            QueryValidationResponse(
+                reasoning="Generative AI is turned off - skipping check",
+                answerable=True,
+            ).dict()
+        )
+        return
+
     messages = get_query_validation_messages(user_query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
     try:
-        tokens = get_default_llm().stream(filled_llm_prompt)
+        tokens = llm.stream(filled_llm_prompt)
         reasoning_pat_found = False
         model_output = ""
         hold_answerable = ""

View File

@@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
 from danswer.configs.constants import DocumentSource
 from danswer.db.connector import fetch_unique_document_sources
 from danswer.db.engine import get_sqlalchemy_engine
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.constants import SOURCES_KEY
@@ -145,13 +146,18 @@ def extract_source_filter(
         logger.warning("LLM failed to provide a valid Source Filter output")
         return None

+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        return None
+
     valid_sources = fetch_unique_document_sources(db_session)
     if not valid_sources:
         return None

     messages = _get_source_filter_messages(query=query, valid_sources=valid_sources)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)

     return _extract_source_filters_from_llm_out(model_output)

View File

@@ -5,6 +5,7 @@ from datetime import timezone

 from dateutil.parser import parse

+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.filter_extration import TIME_FILTER_PROMPT
@@ -145,9 +146,14 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
         return None, False

+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        return None, False
+
     messages = _get_time_filter_messages(query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)

     return _extract_time_filter_from_llm_out(model_output)

View File

@@ -9,7 +9,6 @@ from fastapi import HTTPException
 from sqlalchemy.orm import Session

 from danswer.auth.users import current_admin_user
-from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
 from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
 from danswer.db.connector_credential_pair import get_connector_credential_pair
@@ -22,6 +21,7 @@ from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
 from danswer.dynamic_configs import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import get_gen_ai_api_key
 from danswer.llm.utils import test_llm
@@ -100,9 +100,6 @@ def document_hidden_update(
 def validate_existing_genai_api_key(
     _: User = Depends(current_admin_user),
 ) -> None:
-    if DISABLE_GENERATIVE_AI:
-        return
-
     # Only validate every so often
     check_key_time = "genai_api_key_last_check_time"
     kv_store = get_dynamic_config_store()
@@ -120,7 +117,11 @@ def validate_existing_genai_api_key(
     genai_api_key = get_gen_ai_api_key()

-    llm = get_default_llm(api_key=genai_api_key, timeout=10)
+    try:
+        llm = get_default_llm(api_key=genai_api_key, timeout=10)
+    except GenAIDisabledException:
+        return
+
     is_valid = test_llm(llm)

     if not is_valid:
@@ -165,6 +166,9 @@ def store_genai_api_key(
         if not is_valid:
             raise HTTPException(400, "Invalid API key provided")
+        get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
+    except GenAIDisabledException:
+        # If Disable Generative AI is set, no need to verify, just store the key for later use
         get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
     except RuntimeError as e:
         raise HTTPException(400, str(e))

View File

@@ -40,6 +40,7 @@ services:
       - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
       - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
      - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
+      - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
       # Query Options
       - DOC_TIME_DECAY=${DOC_TIME_DECAY:-}  # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
       - HYBRID_ALPHA=${HYBRID_ALPHA:-}  # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
@@ -96,6 +97,7 @@ services:
       - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
       - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
       - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
+      - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
       # Query Options
       - DOC_TIME_DECAY=${DOC_TIME_DECAY:-}  # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
       - HYBRID_ALPHA=${HYBRID_ALPHA:-}  # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)