Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-20 04:37:09 +02:00)

Reenable option to run Danswer without Gen AI (#906)
@@ -21,6 +21,7 @@ from danswer.chat.models import QADocsResponse
 from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import CHUNK_SIZE
 from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
+from danswer.configs.constants import DISABLED_GEN_AI_MSG
 from danswer.configs.constants import MessageType
 from danswer.db.chat import create_db_search_doc
 from danswer.db.chat import create_new_chat_message
@@ -36,6 +37,7 @@ from danswer.db.models import SearchDoc as DbSearchDoc
 from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
 from danswer.indexing.models import InferenceChunk
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import get_default_llm_token_encode
@@ -61,10 +63,18 @@ def generate_ai_chat_response(
     history: list[ChatMessage],
     context_docs: list[LlmDoc],
     doc_id_to_rank_map: dict[str, int],
-    llm: LLM,
+    llm: LLM | None,
     llm_tokenizer: Callable,
     all_doc_useful: bool,
 ) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]:
+    if llm is None:
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # Not an error if it's a user configuration
+            yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
+            return
+
     if query_message.prompt is None:
         raise RuntimeError("No prompt received for generating Gen AI answer.")

@@ -171,7 +181,11 @@ def stream_chat_message(
         "Must specify a set of documents for chat or specify search options"
     )

-    llm = get_default_llm()
+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        llm = None

     llm_tokenizer = get_default_llm_token_encode()
     document_index = get_default_document_index()

@@ -20,7 +20,6 @@ APP_API_PREFIX = os.environ.get("API_PREFIX", "")
 #####
 BLURB_SIZE = 128  # Number Encoder Tokens included in the chunk blurb
 GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400  # 1 day
-# CURRENTLY DOES NOT FULLY WORK, DON'T USE THIS
 DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"

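As context for the hunk above: DISABLE_GENERATIVE_AI is a plain environment-variable kill switch. A standalone Python sketch of the same parsing rule (illustrative, not a file from the repo):

import os

# Mirrors the config line above: only the string "true" (any casing) enables the switch.
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"

if __name__ == "__main__":
    print(f"Generative AI disabled: {DISABLE_GENERATIVE_AI}")
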
@@ -47,6 +47,14 @@ SECTION_SEPARATOR = "\n\n"
 INDEX_SEPARATOR = "==="


+# Messages
+DISABLED_GEN_AI_MSG = (
+    "Your System Admin has disabled the Generative AI functionalities of Danswer.\n"
+    "Please contact them if you wish to have this enabled.\n"
+    "You can still use Danswer as a search engine."
+)
+
+
 class DocumentSource(str, Enum):
     # Special case, document passed in via Danswer APIs without specifying a source type
     INGESTION_API = "ingestion_api"

@@ -12,6 +12,7 @@ from slack_sdk.models.blocks import RadioButtonsElement
 from slack_sdk.models.blocks import SectionBlock

 from danswer.chat.models import DanswerQuote
+from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import SearchFeedbackType
 from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
@@ -106,8 +107,11 @@ def build_documents_blocks(
     message_id: int | None,
     num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
 ) -> list[Block]:
+    header_text = (
+        "Retrieved Documents" if DISABLE_GENERATIVE_AI else "Reference Documents"
+    )
     seen_docs_identifiers = set()
-    section_blocks: list[Block] = [HeaderBlock(text="Reference Documents")]
+    section_blocks: list[Block] = [HeaderBlock(text=header_text)]
     included_docs = 0
     for rank, d in enumerate(documents):
         if d.document_id in seen_docs_identifiers:
@@ -208,6 +212,9 @@ def build_qa_response_blocks(
     favor_recent: bool,
     skip_quotes: bool = False,
 ) -> list[Block]:
+    if DISABLE_GENERATIVE_AI:
+        return []
+
     quotes_blocks: list[Block] = []

     ai_answer_header = HeaderBlock(text="AI Answer")

backend/danswer/llm/exceptions.py (new file, +4)
@@ -0,0 +1,4 @@
+class GenAIDisabledException(Exception):
+    def __init__(self, message: str = "Generative AI has been turned off") -> None:
+        self.message = message
+        super().__init__(self.message)

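The new exception is a control-flow signal rather than an error: the LLM factory raises it, and each call site catches it to pick its own degraded behavior. A standalone sketch of that gate pattern, with stubbed stand-ins so it runs outside the repo (the return value of get_default_llm here is illustrative):

class GenAIDisabledException(Exception):
    def __init__(self, message: str = "Generative AI has been turned off") -> None:
        self.message = message
        super().__init__(self.message)


DISABLE_GENERATIVE_AI = True  # in the repo this comes from danswer.configs.app_configs


def get_default_llm() -> str:
    # Mirrors the gate added to the factory below: raise before doing any setup.
    if DISABLE_GENERATIVE_AI:
        raise GenAIDisabledException()
    return "real LLM instance"


try:
    llm = get_default_llm()
except GenAIDisabledException:
    # Each call site chooses a fallback instead of crashing.
    llm = None

print("LLM available:", llm is not None)
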
@@ -1,9 +1,11 @@
+from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.chat_configs import QA_TIMEOUT
 from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.llm.chat_llm import DefaultMultiLLM
 from danswer.llm.custom_llm import CustomModelServer
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.gpt_4_all import DanswerGPT4All
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import get_gen_ai_api_key
@@ -18,6 +20,9 @@ def get_default_llm(
 ) -> LLM:
     """A single place to fetch the configured LLM for Danswer
     Also allows overriding certain LLM defaults"""
+    if DISABLE_GENERATIVE_AI:
+        raise GenAIDisabledException()
+
     if gen_ai_model_version_override:
         model_version = gen_ai_model_version_override
     else:

@@ -231,6 +231,9 @@ def get_application() -> FastAPI:
         if GEN_AI_API_ENDPOINT:
             logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")

+        # Any additional model configs logged here
+        get_default_llm().log_model_configs()
+
         if MULTILINGUAL_QUERY_EXPANSION:
             logger.info(
                 f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
@@ -258,9 +261,6 @@ def get_application() -> FastAPI:
             logger.info("GPU is not available")
         logger.info(f"Torch Threads: {torch.get_num_threads()}")

-        # This is for the LLM, most LLMs will not need warming up
-        get_default_llm().log_model_configs()
-
         logger.info("Verifying query preprocessing (NLTK) data is downloaded")
         nltk.download("stopwords", quiet=True)
         nltk.download("wordnet", quiet=True)

@@ -29,6 +29,7 @@ from danswer.one_shot_answer.factory import get_question_answer_model
 from danswer.one_shot_answer.models import DirectQARequest
 from danswer.one_shot_answer.models import OneShotQAResponse
 from danswer.one_shot_answer.models import QueryRephrase
+from danswer.one_shot_answer.qa_block import no_gen_ai_response
 from danswer.one_shot_answer.qa_utils import combine_message_thread
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
@@ -191,8 +192,12 @@ def stream_answer_objects(
         llm_version=llm_override,
     )

-    full_prompt_str = qa_model.build_prompt(
-        query=query_msg.message, history_str=history_str, context_chunks=llm_chunks
-    )
+    full_prompt_str = (
+        qa_model.build_prompt(
+            query=query_msg.message, history_str=history_str, context_chunks=llm_chunks
+        )
+        if qa_model is not None
+        else "Gen AI Disabled"
+    )

     # Create the first User query message
@@ -207,10 +212,14 @@ def stream_answer_objects(
         commit=True,
     )

-    response_packets = qa_model.answer_question_stream(
-        prompt=full_prompt_str,
-        llm_context_docs=llm_chunks,
-        metrics_callback=llm_metrics_callback,
-    )
+    response_packets = (
+        qa_model.answer_question_stream(
+            prompt=full_prompt_str,
+            llm_context_docs=llm_chunks,
+            metrics_callback=llm_metrics_callback,
+        )
+        if qa_model is not None
+        else no_gen_ai_response()
+    )

     # Capture outputs and errors

@@ -1,6 +1,7 @@
 from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
 from danswer.configs.chat_configs import QA_TIMEOUT
 from danswer.db.models import Prompt
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.one_shot_answer.interfaces import QAModel
 from danswer.one_shot_answer.qa_block import QABlock
@@ -19,18 +20,21 @@ def get_question_answer_model(
     chain_of_thought: bool = False,
     llm_version: str | None = None,
     qa_model_version: str | None = QA_PROMPT_OVERRIDE,
-) -> QAModel:
+) -> QAModel | None:
     if chain_of_thought:
         raise NotImplementedError("COT has been disabled")

     system_prompt = prompt.system_prompt if prompt is not None else None
     task_prompt = prompt.task_prompt if prompt is not None else None

-    llm = get_default_llm(
-        api_key=api_key,
-        timeout=timeout,
-        gen_ai_model_version_override=llm_version,
-    )
+    try:
+        llm = get_default_llm(
+            api_key=api_key,
+            timeout=timeout,
+            gen_ai_model_version_override=llm_version,
+        )
+    except GenAIDisabledException:
+        return None

     if qa_model_version == "weak":
         qa_handler: QAHandler = WeakLLMQAHandler(

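Note the second fallback style in this commit: the factory above returns None instead of raising, so its callers (the stream_answer_objects hunks earlier) use conditional expressions rather than try/except. A standalone sketch of that Optional-return style, with a stub QAModel (names and values are illustrative):

class QAModel:  # stub of danswer.one_shot_answer.interfaces.QAModel
    def build_prompt(self, query: str) -> str:
        return f"PROMPT<{query}>"


GEN_AI_DISABLED = True  # stands in for the GenAIDisabledException path


def get_question_answer_model() -> QAModel | None:
    # None signals "Gen AI off" to callers, matching the new return type above.
    return None if GEN_AI_DISABLED else QAModel()


qa_model = get_question_answer_model()
# Same shape as the diff: a static placeholder replaces the real prompt when disabled.
full_prompt_str = (
    qa_model.build_prompt("hello") if qa_model is not None else "Gen AI Disabled"
)
print(full_prompt_str)
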
@@ -13,6 +13,7 @@ from danswer.chat.models import LlmDoc
 from danswer.chat.models import LLMMetricsContainer
 from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
+from danswer.configs.constants import DISABLED_GEN_AI_MSG
 from danswer.indexing.models import InferenceChunk
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import check_number_of_tokens
@@ -252,6 +253,10 @@ def build_dummy_prompt(
     ).strip()


+def no_gen_ai_response() -> Iterator[DanswerAnswerPiece]:
+    yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
+
+
 class QABlock(QAModel):
     def __init__(self, llm: LLM, qa_handler: QAHandler) -> None:
         self._llm = llm

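Since no_gen_ai_response is a one-item generator, streaming consumers iterate it exactly like real model output, so no downstream streaming code has to change. A standalone sketch of the substitution, with DanswerAnswerPiece stubbed as a dataclass:

from collections.abc import Iterator
from dataclasses import dataclass

DISABLED_GEN_AI_MSG = "Gen AI is disabled."  # stand-in for the constant added above


@dataclass
class DanswerAnswerPiece:  # stub of danswer.chat.models.DanswerAnswerPiece
    answer_piece: str


def no_gen_ai_response() -> Iterator[DanswerAnswerPiece]:
    yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)


# Consumers drain the stand-in stream the same way they drain LLM tokens.
for packet in no_gen_ai_response():
    print(packet.answer_piece)
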
@@ -1,3 +1,4 @@
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.answer_validation import ANSWER_VALIDITY_PROMPT
@@ -41,12 +42,17 @@ def get_answer_validity(
             return False
         return True  # If something is wrong, let's not toss away the answer

+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        return True
+
     if not answer:
         return False

     messages = _get_answer_validation_messages(query, answer)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)

     validity = _extract_validity(model_output)

@@ -1,5 +1,6 @@
 from danswer.chat.chat_utils import combine_message_chain
 from danswer.db.models import ChatMessage
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
@@ -23,7 +24,12 @@ def get_renamed_conversation_name(
         return messages

     if llm is None:
-        llm = get_default_llm()
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # This may be longer than what the LLM tends to produce but is the most
+            # clear thing we can do
+            return full_history[0].message

     history_str = combine_message_chain(full_history)

@@ -5,6 +5,7 @@ from langchain.schema import SystemMessage
 from danswer.chat.chat_utils import combine_message_chain
 from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
 from danswer.db.models import ChatMessage
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
@@ -68,15 +69,20 @@ def check_if_need_search(
     if disable_llm_check:
         return True

+    if llm is None:
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # If Generative AI is turned off, then always run Search, as Danswer is
+            # being used as just a search engine
+            return True
+
     history_str = combine_message_chain(history)

     prompt_msgs = _get_search_messages(
         question=query_message.message, history_str=history_str
     )

-    if llm is None:
-        llm = get_default_llm()
-
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
     require_search_output = llm.invoke(filled_llm_prompt)

@@ -1,5 +1,6 @@
 from collections.abc import Callable

+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.llm_chunk_filter import CHUNK_FILTER_PROMPT
@@ -30,14 +31,20 @@ def llm_eval_chunk(query: str, chunk_content: str) -> bool:
             return False
         return True

+    # If Gen AI is disabled, none of the messages are more "useful" than any other
+    # All are marked not useful (False) so that the "Gen AI likes this answer" icon
+    # is not shown for any result
+    try:
+        llm = get_default_llm(use_fast_llm=True, timeout=5)
+    except GenAIDisabledException:
+        return False
+
     messages = _get_usefulness_messages()
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
     # When running in a batch, it takes as long as the longest thread
     # And when running a large batch, one may fail and take the whole timeout
     # instead cap it to 5 seconds
-    model_output = get_default_llm(use_fast_llm=True, timeout=5).invoke(
-        filled_llm_prompt
-    )
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)

     return _extract_usefulness(model_output)

@@ -3,6 +3,7 @@ from typing import cast

 from danswer.chat.chat_utils import combine_message_chain
 from danswer.db.models import ChatMessage
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
@@ -28,9 +29,17 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:

         return messages

+    try:
+        llm = get_default_llm(use_fast_llm=True, timeout=5)
+    except GenAIDisabledException:
+        logger.warning(
+            "Unable to perform multilingual query expansion, Gen AI disabled"
+        )
+        return query
+
     messages = _get_rephrase_messages()
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)

     return model_output
@@ -81,12 +90,25 @@ def history_based_query_rephrase(
     llm: LLM | None = None,
     size_heuristic: int = 200,
     punctuation_heuristic: int = 10,
+    skip_first_rephrase: bool = False,
 ) -> str:
     user_query = cast(str, query_message.message)

     if not user_query:
         raise ValueError("Can't rephrase/search an empty query")

+    if llm is None:
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # If Generative AI is turned off, just return the original query
+            return user_query
+
+    # For some use cases, the first query should be untouched. Later queries must be rephrased
+    # due to needing context but the first query has no context.
+    if skip_first_rephrase and not history:
+        return user_query
+
     # If it's a very large query, assume it's a copy paste which we may want to find exactly
     # or at least very closely, so don't rephrase it
     if len(user_query) >= size_heuristic:
@@ -103,9 +125,6 @@ def history_based_query_rephrase(
         question=user_query, history_str=history_str
     )

-    if llm is None:
-        llm = get_default_llm()
-
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
     rephrased_query = llm.invoke(filled_llm_prompt)

@@ -130,13 +149,17 @@ def thread_based_query_rephrase(
     if count_punctuation(user_query) >= punctuation_heuristic:
         return user_query

+    if llm is None:
+        try:
+            llm = get_default_llm()
+        except GenAIDisabledException:
+            # If Generative AI is turned off, just return the original query
+            return user_query
+
     prompt_msgs = get_contextual_rephrase_messages(
         question=user_query, history_str=history_str
     )

-    if llm is None:
-        llm = get_default_llm()
-
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
     rephrased_query = llm.invoke(filled_llm_prompt)

@@ -4,6 +4,7 @@ from collections.abc import Iterator
 from danswer.chat.models import DanswerAnswerPiece
 from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import DISABLE_LLM_QUERY_ANSWERABILITY
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.constants import ANSWERABLE_PAT
@@ -48,9 +49,14 @@ def get_query_answerability(
     if skip_check:
         return "Query Answerability Evaluation feature is turned off", True

+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        return "Generative AI is turned off - skipping check", True
+
     messages = get_query_validation_messages(user_query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)

     reasoning = extract_answerability_reasoning(model_output)
     answerable = extract_answerability_bool(model_output)
@@ -70,10 +76,21 @@ def stream_query_answerability(
         )
         return

+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        yield get_json_line(
+            QueryValidationResponse(
+                reasoning="Generative AI is turned off - skipping check",
+                answerable=True,
+            ).dict()
+        )
+        return
+
     messages = get_query_validation_messages(user_query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
     try:
-        tokens = get_default_llm().stream(filled_llm_prompt)
+        tokens = llm.stream(filled_llm_prompt)
         reasoning_pat_found = False
         model_output = ""
         hold_answerable = ""

@@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
 from danswer.configs.constants import DocumentSource
 from danswer.db.connector import fetch_unique_document_sources
 from danswer.db.engine import get_sqlalchemy_engine
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.constants import SOURCES_KEY
@@ -145,13 +146,18 @@ def extract_source_filter(
         logger.warning("LLM failed to provide a valid Source Filter output")
         return None

+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        return None
+
     valid_sources = fetch_unique_document_sources(db_session)
     if not valid_sources:
         return None

     messages = _get_source_filter_messages(query=query, valid_sources=valid_sources)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)

     return _extract_source_filters_from_llm_out(model_output)

@@ -5,6 +5,7 @@ from datetime import timezone

 from dateutil.parser import parse

+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.prompts.filter_extration import TIME_FILTER_PROMPT
@@ -145,9 +146,14 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]:

         return None, False

+    try:
+        llm = get_default_llm()
+    except GenAIDisabledException:
+        return None, False
+
     messages = _get_time_filter_messages(query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = llm.invoke(filled_llm_prompt)
     logger.debug(model_output)

     return _extract_time_filter_from_llm_out(model_output)

@@ -9,7 +9,6 @@ from fastapi import HTTPException
 from sqlalchemy.orm import Session

 from danswer.auth.users import current_admin_user
-from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
 from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
 from danswer.db.connector_credential_pair import get_connector_credential_pair
@@ -22,6 +21,7 @@ from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
 from danswer.dynamic_configs import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import get_gen_ai_api_key
 from danswer.llm.utils import test_llm
@@ -100,9 +100,6 @@ def document_hidden_update(
 def validate_existing_genai_api_key(
     _: User = Depends(current_admin_user),
 ) -> None:
-    if DISABLE_GENERATIVE_AI:
-        return
-
     # Only validate every so often
     check_key_time = "genai_api_key_last_check_time"
     kv_store = get_dynamic_config_store()
@@ -120,7 +117,11 @@ def validate_existing_genai_api_key(

     genai_api_key = get_gen_ai_api_key()

-    llm = get_default_llm(api_key=genai_api_key, timeout=10)
+    try:
+        llm = get_default_llm(api_key=genai_api_key, timeout=10)
+    except GenAIDisabledException:
+        return
+
     is_valid = test_llm(llm)

     if not is_valid:
@@ -165,6 +166,9 @@ def store_genai_api_key(
         if not is_valid:
             raise HTTPException(400, "Invalid API key provided")

+        get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
+    except GenAIDisabledException:
+        # If Disable Generative AI is set, no need to verify, just store the key for later use
         get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
     except RuntimeError as e:
         raise HTTPException(400, str(e))

@@ -40,6 +40,7 @@ services:
       - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
       - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
       - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
+      - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
       # Query Options
       - DOC_TIME_DECAY=${DOC_TIME_DECAY:-}  # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
       - HYBRID_ALPHA=${HYBRID_ALPHA:-}  # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
@@ -96,6 +97,7 @@ services:
       - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
       - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
       - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
+      - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
       # Query Options
       - DOC_TIME_DECAY=${DOC_TIME_DECAY:-}  # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
      - HYBRID_ALPHA=${HYBRID_ALPHA:-}  # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)