Reenable option to run Danswer without Gen AI (#906)

Yuhong Sun
2024-01-03 18:31:16 -08:00
committed by GitHub
parent 20441df4a4
commit 6b6b3daab7
20 changed files with 181 additions and 43 deletions
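
Every file below applies the same pattern: get_default_llm() now raises
GenAIDisabledException when DISABLE_GENERATIVE_AI is set, and each call site
catches the exception and falls back to LLM-free behavior instead of failing.
A minimal sketch of that pattern (illustrative only, not code from this
commit; some_llm_backed_feature is a hypothetical caller):

    from danswer.llm.exceptions import GenAIDisabledException
    from danswer.llm.factory import get_default_llm

    def some_llm_backed_feature(query: str) -> str:
        try:
            llm = get_default_llm()
        except GenAIDisabledException:
            # Gen AI being off is a user configuration, not an error,
            # so degrade to a fallback that needs no LLM
            return query
        return llm.invoke(query)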

View File

@@ -21,6 +21,7 @@ from danswer.chat.models import QADocsResponse
 from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import CHUNK_SIZE
 from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
+from danswer.configs.constants import DISABLED_GEN_AI_MSG
 from danswer.configs.constants import MessageType
 from danswer.db.chat import create_db_search_doc
 from danswer.db.chat import create_new_chat_message
@@ -36,6 +37,7 @@ from danswer.db.models import SearchDoc as DbSearchDoc
 from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
 from danswer.indexing.models import InferenceChunk
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import get_default_llm_token_encode
@@ -61,10 +63,18 @@ def generate_ai_chat_response(
     history: list[ChatMessage],
     context_docs: list[LlmDoc],
     doc_id_to_rank_map: dict[str, int],
-    llm: LLM,
+    llm: LLM | None,
     llm_tokenizer: Callable,
     all_doc_useful: bool,
 ) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]:
+    if llm is None:
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # Not an error if it's a user configuration
+            yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
+            return
+
     if query_message.prompt is None:
         raise RuntimeError("No prompt received for generating Gen AI answer.")
@@ -171,7 +181,11 @@ def stream_chat_message(
"Must specify a set of documents for chat or specify search options"
)
llm = get_default_llm()
try:
llm = get_default_llm()
except GenAIDisabledException:
llm = None
llm_tokenizer = get_default_llm_token_encode()
document_index = get_default_document_index()

View File

@@ -20,7 +20,6 @@ APP_API_PREFIX = os.environ.get("API_PREFIX", "")
 #####
 BLURB_SIZE = 128  # Number Encoder Tokens included in the chunk blurb
 GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400  # 1 day
-# CURRENTLY DOES NOT FULLY WORK, DON'T USE THIS
 DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"

View File

@@ -47,6 +47,14 @@ SECTION_SEPARATOR = "\n\n"
 INDEX_SEPARATOR = "==="

+# Messages
+DISABLED_GEN_AI_MSG = (
+    "Your System Admin has disabled the Generative AI functionalities of Danswer.\n"
+    "Please contact them if you wish to have this enabled.\n"
+    "You can still use Danswer as a search engine."
+)
+

 class DocumentSource(str, Enum):
     # Special case, document passed in via Danswer APIs without specifying a source type
     INGESTION_API = "ingestion_api"

View File

@@ -12,6 +12,7 @@ from slack_sdk.models.blocks import RadioButtonsElement
 from slack_sdk.models.blocks import SectionBlock

 from danswer.chat.models import DanswerQuote
+from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import SearchFeedbackType
 from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
@@ -106,8 +107,11 @@ def build_documents_blocks(
     message_id: int | None,
     num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
 ) -> list[Block]:
+    header_text = (
+        "Retrieved Documents" if DISABLE_GENERATIVE_AI else "Reference Documents"
+    )
     seen_docs_identifiers = set()
-    section_blocks: list[Block] = [HeaderBlock(text="Reference Documents")]
+    section_blocks: list[Block] = [HeaderBlock(text=header_text)]
     included_docs = 0
     for rank, d in enumerate(documents):
         if d.document_id in seen_docs_identifiers:
@@ -208,6 +212,9 @@ def build_qa_response_blocks(
     favor_recent: bool,
     skip_quotes: bool = False,
 ) -> list[Block]:
+    if DISABLE_GENERATIVE_AI:
+        return []
+
     quotes_blocks: list[Block] = []
     ai_answer_header = HeaderBlock(text="AI Answer")

View File

@@ -0,0 +1,4 @@
+class GenAIDisabledException(Exception):
+    def __init__(self, message: str = "Generative AI has been turned off") -> None:
+        self.message = message
+        super().__init__(self.message)
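
For illustration (not part of the commit), the exception carries a default
message that callers could surface:

    try:
        raise GenAIDisabledException()
    except GenAIDisabledException as e:
        print(e.message)  # "Generative AI has been turned off"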

View File

@@ -1,9 +1,11 @@
+from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.chat_configs import QA_TIMEOUT
 from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.llm.chat_llm import DefaultMultiLLM
 from danswer.llm.custom_llm import CustomModelServer
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.gpt_4_all import DanswerGPT4All
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import get_gen_ai_api_key
@@ -18,6 +20,9 @@ def get_default_llm(
 ) -> LLM:
     """A single place to fetch the configured LLM for Danswer
     Also allows overriding certain LLM defaults"""
+    if DISABLE_GENERATIVE_AI:
+        raise GenAIDisabledException()
+
     if gen_ai_model_version_override:
         model_version = gen_ai_model_version_override
     else:
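
Two things worth noting about this gate. It runs before any of the factory's
overrides, so callers cannot bypass it with parameters; and since
DISABLE_GENERATIVE_AI is evaluated once when app_configs.py is imported, the
flag must be in the environment before the Danswer process starts (hence the
docker-compose change at the end of this commit). Illustrative calls, assuming
the flag is exported as "true":

    get_default_llm()                              # raises GenAIDisabledException
    get_default_llm(use_fast_llm=True, timeout=5)  # raises all the same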

View File

@@ -231,6 +231,9 @@ def get_application() -> FastAPI:
     if GEN_AI_API_ENDPOINT:
         logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")

+    # Any additional model configs logged here
+    get_default_llm().log_model_configs()
+
     if MULTILINGUAL_QUERY_EXPANSION:
         logger.info(
             f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
@@ -258,9 +261,6 @@ def get_application() -> FastAPI:
logger.info("GPU is not available")
logger.info(f"Torch Threads: {torch.get_num_threads()}")
# This is for the LLM, most LLMs will not need warming up
get_default_llm().log_model_configs()
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)

View File

@@ -29,6 +29,7 @@ from danswer.one_shot_answer.factory import get_question_answer_model
 from danswer.one_shot_answer.models import DirectQARequest
 from danswer.one_shot_answer.models import OneShotQAResponse
 from danswer.one_shot_answer.models import QueryRephrase
+from danswer.one_shot_answer.qa_block import no_gen_ai_response
 from danswer.one_shot_answer.qa_utils import combine_message_thread
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
@@ -191,8 +192,12 @@ def stream_answer_objects(
         llm_version=llm_override,
     )

-    full_prompt_str = qa_model.build_prompt(
-        query=query_msg.message, history_str=history_str, context_chunks=llm_chunks
+    full_prompt_str = (
+        qa_model.build_prompt(
+            query=query_msg.message, history_str=history_str, context_chunks=llm_chunks
+        )
+        if qa_model is not None
+        else "Gen AI Disabled"
     )

     # Create the first User query message
@@ -207,10 +212,14 @@ def stream_answer_objects(
         commit=True,
     )

-    response_packets = qa_model.answer_question_stream(
-        prompt=full_prompt_str,
-        llm_context_docs=llm_chunks,
-        metrics_callback=llm_metrics_callback,
+    response_packets = (
+        qa_model.answer_question_stream(
+            prompt=full_prompt_str,
+            llm_context_docs=llm_chunks,
+            metrics_callback=llm_metrics_callback,
+        )
+        if qa_model is not None
+        else no_gen_ai_response()
     )

     # Capture outputs and errors

View File

@@ -1,6 +1,7 @@
 from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
 from danswer.configs.chat_configs import QA_TIMEOUT
 from danswer.db.models import Prompt
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.one_shot_answer.interfaces import QAModel
 from danswer.one_shot_answer.qa_block import QABlock
@@ -19,18 +20,21 @@ def get_question_answer_model(
     chain_of_thought: bool = False,
     llm_version: str | None = None,
     qa_model_version: str | None = QA_PROMPT_OVERRIDE,
-) -> QAModel:
+) -> QAModel | None:
     if chain_of_thought:
         raise NotImplementedError("COT has been disabled")

     system_prompt = prompt.system_prompt if prompt is not None else None
     task_prompt = prompt.task_prompt if prompt is not None else None

-    llm = get_default_llm(
-        api_key=api_key,
-        timeout=timeout,
-        gen_ai_model_version_override=llm_version,
-    )
+    try:
+        llm = get_default_llm(
+            api_key=api_key,
+            timeout=timeout,
+            gen_ai_model_version_override=llm_version,
+        )
+    except GenAIDisabledException:
+        return None

     if qa_model_version == "weak":
         qa_handler: QAHandler = WeakLLMQAHandler(

View File

@@ -13,6 +13,7 @@ from danswer.chat.models import LlmDoc
 from danswer.chat.models import LLMMetricsContainer
 from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
+from danswer.configs.constants import DISABLED_GEN_AI_MSG
 from danswer.indexing.models import InferenceChunk
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import check_number_of_tokens
@@ -252,6 +253,10 @@ def build_dummy_prompt(
     ).strip()


+def no_gen_ai_response() -> Iterator[DanswerAnswerPiece]:
+    yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)


 class QABlock(QAModel):
     def __init__(self, llm: LLM, qa_handler: QAHandler) -> None:
         self._llm = llm
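
For illustration (not in the commit): a streaming consumer can iterate the
fallback generator exactly like a normal answer stream, since it yields the
same DanswerAnswerPiece packets:

    for packet in no_gen_ai_response():
        print(packet.answer_piece)  # prints DISABLED_GEN_AI_MSG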

View File

@@ -1,3 +1,4 @@
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.answer_validation import ANSWER_VALIDITY_PROMPT
@@ -41,12 +42,17 @@ def get_answer_validity(
             return False
         return True  # If something is wrong, let's not toss away the answer

+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        return True
+
     if not answer:
         return False

     messages = _get_answer_validation_messages(query, answer)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)

     validity = _extract_validity(model_output)

View File

@@ -1,5 +1,6 @@
 from danswer.chat.chat_utils import combine_message_chain
 from danswer.db.models import ChatMessage
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
@@ -23,7 +24,12 @@ def get_renamed_conversation_name(
         return messages

     if llm is None:
-        llm = get_default_llm()
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # This may be longer than what the LLM tends to produce, but it is
+            # the clearest thing we can do
+            return full_history[0].message

     history_str = combine_message_chain(full_history)

View File

@@ -5,6 +5,7 @@ from langchain.schema import SystemMessage
 from danswer.chat.chat_utils import combine_message_chain
 from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
 from danswer.db.models import ChatMessage
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
@@ -68,15 +69,20 @@ def check_if_need_search(
     if disable_llm_check:
         return True

+    if llm is None:
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # If Generative AI is turned off, then always run Search, as Danswer
+            # is being used as just a search engine
+            return True
+
     history_str = combine_message_chain(history)

     prompt_msgs = _get_search_messages(
         question=query_message.message, history_str=history_str
     )

-    if llm is None:
-        llm = get_default_llm()
-
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
     require_search_output = llm.invoke(filled_llm_prompt)

View File

@@ -1,5 +1,6 @@
 from collections.abc import Callable

+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.llm_chunk_filter import CHUNK_FILTER_PROMPT
@@ -30,14 +31,20 @@ def llm_eval_chunk(query: str, chunk_content: str) -> bool:
             return False
         return True

+    # If Gen AI is disabled, none of the chunks are more "useful" than any other.
+    # All are marked not useful (False) so that the "Gen AI likes this answer"
+    # icon is not shown for any result
+    try:
+        llm = get_default_llm(use_fast_llm=True, timeout=5)
+    except GenAIDisabledException:
+        return False
+
     messages = _get_usefulness_messages()
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
     # When running in a batch, it takes as long as the longest thread
     # And when running a large batch, one may fail and take the whole timeout
     # instead cap it to 5 seconds
-    model_output = get_default_llm(use_fast_llm=True, timeout=5).invoke(
-        filled_llm_prompt
-    )
+    model_output = llm.invoke(filled_llm_prompt)

     logger.debug(model_output)

     return _extract_usefulness(model_output)

View File

@@ -3,6 +3,7 @@ from typing import cast
 from danswer.chat.chat_utils import combine_message_chain
 from danswer.db.models import ChatMessage
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
@@ -28,9 +29,17 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
         return messages

+    try:
+        llm = get_default_llm(use_fast_llm=True, timeout=5)
+    except GenAIDisabledException:
+        logger.warning(
+            "Unable to perform multilingual query expansion, Gen AI disabled"
+        )
+        return query
+
     messages = _get_rephrase_messages()
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)

     return model_output
@@ -81,12 +90,25 @@ def history_based_query_rephrase(
     llm: LLM | None = None,
     size_heuristic: int = 200,
     punctuation_heuristic: int = 10,
+    skip_first_rephrase: bool = False,
 ) -> str:
     user_query = cast(str, query_message.message)

     if not user_query:
         raise ValueError("Can't rephrase/search an empty query")

+    if llm is None:
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # If Generative AI is turned off, just return the original query
+            return user_query
+
+    # For some use cases, the first query should be untouched. Later queries must be
+    # rephrased due to needing context, but the first query has no context.
+    if skip_first_rephrase and not history:
+        return user_query
+
     # If it's a very large query, assume it's a copy paste which we may want to find exactly
     # or at least very closely, so don't rephrase it
     if len(user_query) >= size_heuristic:
@@ -103,9 +125,6 @@ def history_based_query_rephrase(
         question=user_query, history_str=history_str
     )

-    if llm is None:
-        llm = get_default_llm()
-
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
     rephrased_query = llm.invoke(filled_llm_prompt)
@@ -130,13 +149,17 @@ def thread_based_query_rephrase(
     if count_punctuation(user_query) >= punctuation_heuristic:
         return user_query

+    if llm is None:
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # If Generative AI is turned off, just return the original query
+            return user_query
+
     prompt_msgs = get_contextual_rephrase_messages(
         question=user_query, history_str=history_str
     )

-    if llm is None:
-        llm = get_default_llm()
-
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
     rephrased_query = llm.invoke(filled_llm_prompt)
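
For illustration, a hypothetical call showing the new skip_first_rephrase flag
(the keyword arguments are assumptions based on the signature above): with no
history, the first query is searched verbatim and only follow-ups get
rephrased:

    rephrased = history_based_query_rephrase(
        query_message=first_message,
        history=[],
        skip_first_rephrase=True,
    )
    # returns first_message.message unchanged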

View File

@@ -4,6 +4,7 @@ from collections.abc import Iterator
 from danswer.chat.models import DanswerAnswerPiece
 from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import DISABLE_LLM_QUERY_ANSWERABILITY
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.constants import ANSWERABLE_PAT
@@ -48,9 +49,14 @@ def get_query_answerability(
     if skip_check:
         return "Query Answerability Evaluation feature is turned off", True

+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        return "Generative AI is turned off - skipping check", True
+
     messages = get_query_validation_messages(user_query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)

     reasoning = extract_answerability_reasoning(model_output)
     answerable = extract_answerability_bool(model_output)
@@ -70,10 +76,21 @@ def stream_query_answerability(
         )
         return

+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        yield get_json_line(
+            QueryValidationResponse(
+                reasoning="Generative AI is turned off - skipping check",
+                answerable=True,
+            ).dict()
+        )
+        return
+
     messages = get_query_validation_messages(user_query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
     try:
-        tokens = get_default_llm().stream(filled_llm_prompt)
+        tokens = llm.stream(filled_llm_prompt)
         reasoning_pat_found = False
         model_output = ""
         hold_answerable = ""

View File

@@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
 from danswer.configs.constants import DocumentSource
 from danswer.db.connector import fetch_unique_document_sources
 from danswer.db.engine import get_sqlalchemy_engine
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.constants import SOURCES_KEY
@@ -145,13 +146,18 @@ def extract_source_filter(
logger.warning("LLM failed to provide a valid Source Filter output")
return None
try:
llm = get_default_llm()
except GenAIDisabledException:
return None
valid_sources = fetch_unique_document_sources(db_session)
if not valid_sources:
return None
messages = _get_source_filter_messages(query=query, valid_sources=valid_sources)
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = get_default_llm().invoke(filled_llm_prompt)
model_output = llm.invoke(filled_llm_prompt)
logger.debug(model_output)
return _extract_source_filters_from_llm_out(model_output)

View File

@@ -5,6 +5,7 @@ from datetime import timezone
 from dateutil.parser import parse

+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.filter_extration import TIME_FILTER_PROMPT
@@ -145,9 +146,14 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
         return None, False

+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        return None, False
+
     messages = _get_time_filter_messages(query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)

     return _extract_time_filter_from_llm_out(model_output)

View File

@@ -9,7 +9,6 @@ from fastapi import HTTPException
 from sqlalchemy.orm import Session

 from danswer.auth.users import current_admin_user
-from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
 from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
 from danswer.db.connector_credential_pair import get_connector_credential_pair
@@ -22,6 +21,7 @@ from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
 from danswer.dynamic_configs import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import get_gen_ai_api_key
 from danswer.llm.utils import test_llm
@@ -100,9 +100,6 @@ def document_hidden_update(
 def validate_existing_genai_api_key(
     _: User = Depends(current_admin_user),
 ) -> None:
-    if DISABLE_GENERATIVE_AI:
-        return
-
     # Only validate every so often
     check_key_time = "genai_api_key_last_check_time"
     kv_store = get_dynamic_config_store()
@@ -120,7 +117,11 @@ def validate_existing_genai_api_key(
     genai_api_key = get_gen_ai_api_key()

-    llm = get_default_llm(api_key=genai_api_key, timeout=10)
+    try:
+        llm = get_default_llm(api_key=genai_api_key, timeout=10)
+    except GenAIDisabledException:
+        return
+
     is_valid = test_llm(llm)

     if not is_valid:
if not is_valid:
@@ -165,6 +166,9 @@ def store_genai_api_key(
         if not is_valid:
             raise HTTPException(400, "Invalid API key provided")

         get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
+    except GenAIDisabledException:
+        # If Disable Generative AI is set, no need to verify; just store the key for later use
+        get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
     except RuntimeError as e:
         raise HTTPException(400, str(e))

View File

@@ -40,6 +40,7 @@ services:
       - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
       - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
       - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
+      - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
       # Query Options
       - DOC_TIME_DECAY=${DOC_TIME_DECAY:-}  # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
       - HYBRID_ALPHA=${HYBRID_ALPHA:-}  # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
@@ -96,6 +97,7 @@ services:
       - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
       - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
       - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
+      - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
       # Query Options
       - DOC_TIME_DECAY=${DOC_TIME_DECAY:-}  # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
       - HYBRID_ALPHA=${HYBRID_ALPHA:-}  # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
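
With both containers receiving the flag, running Danswer as a pure search
engine should come down to exporting DISABLE_GENERATIVE_AI=true in the
deployment environment (for example, in the .env file that docker compose
reads) before bringing the stack up.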