mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-18 19:43:26 +02:00
Reenable option to run Danswer without Gen AI (#906)
This commit is contained in:
@@ -21,6 +21,7 @@ from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import CHUNK_SIZE
|
||||
from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.constants import DISABLED_GEN_AI_MSG
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
@@ -36,6 +37,7 @@ from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import User
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import get_default_llm_token_encode
|
||||
@@ -61,10 +63,18 @@ def generate_ai_chat_response(
|
||||
history: list[ChatMessage],
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: dict[str, int],
|
||||
llm: LLM,
|
||||
llm: LLM | None,
|
||||
llm_tokenizer: Callable,
|
||||
all_doc_useful: bool,
|
||||
) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]:
|
||||
if llm is None:
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
except GenAIDisabledException:
|
||||
# Not an error if it's a user configuration
|
||||
yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
|
||||
return
|
||||
|
||||
if query_message.prompt is None:
|
||||
raise RuntimeError("No prompt received for generating Gen AI answer.")
|
||||
|
||||
@@ -171,7 +181,11 @@ def stream_chat_message(
|
||||
"Must specify a set of documents for chat or specify search options"
|
||||
)
|
||||
|
||||
llm = get_default_llm()
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
except GenAIDisabledException:
|
||||
llm = None
|
||||
|
||||
llm_tokenizer = get_default_llm_token_encode()
|
||||
document_index = get_default_document_index()
|
||||
|
||||
|
@@ -20,7 +20,6 @@ APP_API_PREFIX = os.environ.get("API_PREFIX", "")
|
||||
#####
|
||||
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
|
||||
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400 # 1 day
|
||||
# CURRENTLY DOES NOT FULLY WORK, DON'T USE THIS
|
||||
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
|
||||
|
||||
|
||||
|
@@ -47,6 +47,14 @@ SECTION_SEPARATOR = "\n\n"
|
||||
INDEX_SEPARATOR = "==="
|
||||
|
||||
|
||||
# Messages
|
||||
DISABLED_GEN_AI_MSG = (
|
||||
"Your System Admin has disabled the Generative AI functionalities of Danswer.\n"
|
||||
"Please contact them if you wish to have this enabled.\n"
|
||||
"You can still use Danswer as a search engine."
|
||||
)
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
# Special case, document passed in via Danswer APIs without specifying a source type
|
||||
INGESTION_API = "ingestion_api"
|
||||
|
@@ -12,6 +12,7 @@ from slack_sdk.models.blocks import RadioButtonsElement
|
||||
from slack_sdk.models.blocks import SectionBlock
|
||||
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
|
||||
@@ -106,8 +107,11 @@ def build_documents_blocks(
|
||||
message_id: int | None,
|
||||
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
|
||||
) -> list[Block]:
|
||||
header_text = (
|
||||
"Retrieved Documents" if DISABLE_GENERATIVE_AI else "Reference Documents"
|
||||
)
|
||||
seen_docs_identifiers = set()
|
||||
section_blocks: list[Block] = [HeaderBlock(text="Reference Documents")]
|
||||
section_blocks: list[Block] = [HeaderBlock(text=header_text)]
|
||||
included_docs = 0
|
||||
for rank, d in enumerate(documents):
|
||||
if d.document_id in seen_docs_identifiers:
|
||||
@@ -208,6 +212,9 @@ def build_qa_response_blocks(
|
||||
favor_recent: bool,
|
||||
skip_quotes: bool = False,
|
||||
) -> list[Block]:
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return []
|
||||
|
||||
quotes_blocks: list[Block] = []
|
||||
|
||||
ai_answer_header = HeaderBlock(text="AI Answer")
|
||||
|
4
backend/danswer/llm/exceptions.py
Normal file
4
backend/danswer/llm/exceptions.py
Normal file
@@ -0,0 +1,4 @@
|
||||
class GenAIDisabledException(Exception):
|
||||
def __init__(self, message: str = "Generative AI has been turned off") -> None:
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
@@ -1,9 +1,11 @@
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
|
||||
from danswer.llm.chat_llm import DefaultMultiLLM
|
||||
from danswer.llm.custom_llm import CustomModelServer
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.gpt_4_all import DanswerGPT4All
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import get_gen_ai_api_key
|
||||
@@ -18,6 +20,9 @@ def get_default_llm(
|
||||
) -> LLM:
|
||||
"""A single place to fetch the configured LLM for Danswer
|
||||
Also allows overriding certain LLM defaults"""
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
raise GenAIDisabledException()
|
||||
|
||||
if gen_ai_model_version_override:
|
||||
model_version = gen_ai_model_version_override
|
||||
else:
|
||||
|
@@ -231,6 +231,9 @@ def get_application() -> FastAPI:
|
||||
if GEN_AI_API_ENDPOINT:
|
||||
logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
|
||||
|
||||
# Any additional model configs logged here
|
||||
get_default_llm().log_model_configs()
|
||||
|
||||
if MULTILINGUAL_QUERY_EXPANSION:
|
||||
logger.info(
|
||||
f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
|
||||
@@ -258,9 +261,6 @@ def get_application() -> FastAPI:
|
||||
logger.info("GPU is not available")
|
||||
logger.info(f"Torch Threads: {torch.get_num_threads()}")
|
||||
|
||||
# This is for the LLM, most LLMs will not need warming up
|
||||
get_default_llm().log_model_configs()
|
||||
|
||||
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
|
||||
nltk.download("stopwords", quiet=True)
|
||||
nltk.download("wordnet", quiet=True)
|
||||
|
@@ -29,6 +29,7 @@ from danswer.one_shot_answer.factory import get_question_answer_model
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.one_shot_answer.models import QueryRephrase
|
||||
from danswer.one_shot_answer.qa_block import no_gen_ai_response
|
||||
from danswer.one_shot_answer.qa_utils import combine_message_thread
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import RetrievalMetricsContainer
|
||||
@@ -191,8 +192,12 @@ def stream_answer_objects(
|
||||
llm_version=llm_override,
|
||||
)
|
||||
|
||||
full_prompt_str = qa_model.build_prompt(
|
||||
query=query_msg.message, history_str=history_str, context_chunks=llm_chunks
|
||||
full_prompt_str = (
|
||||
qa_model.build_prompt(
|
||||
query=query_msg.message, history_str=history_str, context_chunks=llm_chunks
|
||||
)
|
||||
if qa_model is not None
|
||||
else "Gen AI Disabled"
|
||||
)
|
||||
|
||||
# Create the first User query message
|
||||
@@ -207,10 +212,14 @@ def stream_answer_objects(
|
||||
commit=True,
|
||||
)
|
||||
|
||||
response_packets = qa_model.answer_question_stream(
|
||||
prompt=full_prompt_str,
|
||||
llm_context_docs=llm_chunks,
|
||||
metrics_callback=llm_metrics_callback,
|
||||
response_packets = (
|
||||
qa_model.answer_question_stream(
|
||||
prompt=full_prompt_str,
|
||||
llm_context_docs=llm_chunks,
|
||||
metrics_callback=llm_metrics_callback,
|
||||
)
|
||||
if qa_model is not None
|
||||
else no_gen_ai_response()
|
||||
)
|
||||
|
||||
# Capture outputs and errors
|
||||
|
@@ -1,6 +1,7 @@
|
||||
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.one_shot_answer.interfaces import QAModel
|
||||
from danswer.one_shot_answer.qa_block import QABlock
|
||||
@@ -19,18 +20,21 @@ def get_question_answer_model(
|
||||
chain_of_thought: bool = False,
|
||||
llm_version: str | None = None,
|
||||
qa_model_version: str | None = QA_PROMPT_OVERRIDE,
|
||||
) -> QAModel:
|
||||
) -> QAModel | None:
|
||||
if chain_of_thought:
|
||||
raise NotImplementedError("COT has been disabled")
|
||||
|
||||
system_prompt = prompt.system_prompt if prompt is not None else None
|
||||
task_prompt = prompt.task_prompt if prompt is not None else None
|
||||
|
||||
llm = get_default_llm(
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
gen_ai_model_version_override=llm_version,
|
||||
)
|
||||
try:
|
||||
llm = get_default_llm(
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
gen_ai_model_version_override=llm_version,
|
||||
)
|
||||
except GenAIDisabledException:
|
||||
return None
|
||||
|
||||
if qa_model_version == "weak":
|
||||
qa_handler: QAHandler = WeakLLMQAHandler(
|
||||
|
@@ -13,6 +13,7 @@ from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import LLMMetricsContainer
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
|
||||
from danswer.configs.constants import DISABLED_GEN_AI_MSG
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
@@ -252,6 +253,10 @@ def build_dummy_prompt(
|
||||
).strip()
|
||||
|
||||
|
||||
def no_gen_ai_response() -> Iterator[DanswerAnswerPiece]:
|
||||
yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
|
||||
|
||||
|
||||
class QABlock(QAModel):
|
||||
def __init__(self, llm: LLM, qa_handler: QAHandler) -> None:
|
||||
self._llm = llm
|
||||
|
@@ -1,3 +1,4 @@
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.prompts.answer_validation import ANSWER_VALIDITY_PROMPT
|
||||
@@ -41,12 +42,17 @@ def get_answer_validity(
|
||||
return False
|
||||
return True # If something is wrong, let's not toss away the answer
|
||||
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
except GenAIDisabledException:
|
||||
return True
|
||||
|
||||
if not answer:
|
||||
return False
|
||||
|
||||
messages = _get_answer_validation_messages(query, answer)
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
|
||||
model_output = get_default_llm().invoke(filled_llm_prompt)
|
||||
model_output = llm.invoke(filled_llm_prompt)
|
||||
logger.debug(model_output)
|
||||
|
||||
validity = _extract_validity(model_output)
|
||||
|
@@ -1,5 +1,6 @@
|
||||
from danswer.chat.chat_utils import combine_message_chain
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
@@ -23,7 +24,12 @@ def get_renamed_conversation_name(
|
||||
return messages
|
||||
|
||||
if llm is None:
|
||||
llm = get_default_llm()
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
except GenAIDisabledException:
|
||||
# This may be longer than what the LLM tends to produce but is the most
|
||||
# clear thing we can do
|
||||
return full_history[0].message
|
||||
|
||||
history_str = combine_message_chain(full_history)
|
||||
|
||||
|
@@ -5,6 +5,7 @@ from langchain.schema import SystemMessage
|
||||
from danswer.chat.chat_utils import combine_message_chain
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
@@ -68,15 +69,20 @@ def check_if_need_search(
|
||||
if disable_llm_check:
|
||||
return True
|
||||
|
||||
if llm is None:
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
except GenAIDisabledException:
|
||||
# If Generative AI is turned off the always run Search as Danswer is being used
|
||||
# as just a search engine
|
||||
return True
|
||||
|
||||
history_str = combine_message_chain(history)
|
||||
|
||||
prompt_msgs = _get_search_messages(
|
||||
question=query_message.message, history_str=history_str
|
||||
)
|
||||
|
||||
if llm is None:
|
||||
llm = get_default_llm()
|
||||
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
|
||||
require_search_output = llm.invoke(filled_llm_prompt)
|
||||
|
||||
|
@@ -1,5 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.prompts.llm_chunk_filter import CHUNK_FILTER_PROMPT
|
||||
@@ -30,14 +31,20 @@ def llm_eval_chunk(query: str, chunk_content: str) -> bool:
|
||||
return False
|
||||
return True
|
||||
|
||||
# If Gen AI is disabled, none of the messages are more "useful" than any other
|
||||
# All are marked not useful (False) so that the icon for Gen AI likes this answer
|
||||
# is not shown for any result
|
||||
try:
|
||||
llm = get_default_llm(use_fast_llm=True, timeout=5)
|
||||
except GenAIDisabledException:
|
||||
return False
|
||||
|
||||
messages = _get_usefulness_messages()
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
|
||||
# When running in a batch, it takes as long as the longest thread
|
||||
# And when running a large batch, one may fail and take the whole timeout
|
||||
# instead cap it to 5 seconds
|
||||
model_output = get_default_llm(use_fast_llm=True, timeout=5).invoke(
|
||||
filled_llm_prompt
|
||||
)
|
||||
model_output = llm.invoke(filled_llm_prompt)
|
||||
logger.debug(model_output)
|
||||
|
||||
return _extract_usefulness(model_output)
|
||||
|
@@ -3,6 +3,7 @@ from typing import cast
|
||||
|
||||
from danswer.chat.chat_utils import combine_message_chain
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
@@ -28,9 +29,17 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
|
||||
|
||||
return messages
|
||||
|
||||
try:
|
||||
llm = get_default_llm(use_fast_llm=True, timeout=5)
|
||||
except GenAIDisabledException:
|
||||
logger.warning(
|
||||
"Unable to perform multilingual query expansion, Gen AI disabled"
|
||||
)
|
||||
return query
|
||||
|
||||
messages = _get_rephrase_messages()
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
|
||||
model_output = get_default_llm().invoke(filled_llm_prompt)
|
||||
model_output = llm.invoke(filled_llm_prompt)
|
||||
logger.debug(model_output)
|
||||
|
||||
return model_output
|
||||
@@ -81,12 +90,25 @@ def history_based_query_rephrase(
|
||||
llm: LLM | None = None,
|
||||
size_heuristic: int = 200,
|
||||
punctuation_heuristic: int = 10,
|
||||
skip_first_rephrase: bool = False,
|
||||
) -> str:
|
||||
user_query = cast(str, query_message.message)
|
||||
|
||||
if not user_query:
|
||||
raise ValueError("Can't rephrase/search an empty query")
|
||||
|
||||
if llm is None:
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
except GenAIDisabledException:
|
||||
# If Generative AI is turned off, just return the original query
|
||||
return user_query
|
||||
|
||||
# For some use cases, the first query should be untouched. Later queries must be rephrased
|
||||
# due to needing context but the first query has no context.
|
||||
if skip_first_rephrase and not history:
|
||||
return user_query
|
||||
|
||||
# If it's a very large query, assume it's a copy paste which we may want to find exactly
|
||||
# or at least very closely, so don't rephrase it
|
||||
if len(user_query) >= size_heuristic:
|
||||
@@ -103,9 +125,6 @@ def history_based_query_rephrase(
|
||||
question=user_query, history_str=history_str
|
||||
)
|
||||
|
||||
if llm is None:
|
||||
llm = get_default_llm()
|
||||
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
|
||||
rephrased_query = llm.invoke(filled_llm_prompt)
|
||||
|
||||
@@ -130,13 +149,17 @@ def thread_based_query_rephrase(
|
||||
if count_punctuation(user_query) >= punctuation_heuristic:
|
||||
return user_query
|
||||
|
||||
if llm is None:
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
except GenAIDisabledException:
|
||||
# If Generative AI is turned off, just return the original query
|
||||
return user_query
|
||||
|
||||
prompt_msgs = get_contextual_rephrase_messages(
|
||||
question=user_query, history_str=history_str
|
||||
)
|
||||
|
||||
if llm is None:
|
||||
llm = get_default_llm()
|
||||
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
|
||||
rephrased_query = llm.invoke(filled_llm_prompt)
|
||||
|
||||
|
@@ -4,6 +4,7 @@ from collections.abc import Iterator
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_QUERY_ANSWERABILITY
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.prompts.constants import ANSWERABLE_PAT
|
||||
@@ -48,9 +49,14 @@ def get_query_answerability(
|
||||
if skip_check:
|
||||
return "Query Answerability Evaluation feature is turned off", True
|
||||
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
except GenAIDisabledException:
|
||||
return "Generative AI is turned off - skipping check", True
|
||||
|
||||
messages = get_query_validation_messages(user_query)
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
|
||||
model_output = get_default_llm().invoke(filled_llm_prompt)
|
||||
model_output = llm.invoke(filled_llm_prompt)
|
||||
|
||||
reasoning = extract_answerability_reasoning(model_output)
|
||||
answerable = extract_answerability_bool(model_output)
|
||||
@@ -70,10 +76,21 @@ def stream_query_answerability(
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
except GenAIDisabledException:
|
||||
yield get_json_line(
|
||||
QueryValidationResponse(
|
||||
reasoning="Generative AI is turned off - skipping check",
|
||||
answerable=True,
|
||||
).dict()
|
||||
)
|
||||
return
|
||||
|
||||
messages = get_query_validation_messages(user_query)
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
|
||||
try:
|
||||
tokens = get_default_llm().stream(filled_llm_prompt)
|
||||
tokens = llm.stream(filled_llm_prompt)
|
||||
reasoning_pat_found = False
|
||||
model_output = ""
|
||||
hold_answerable = ""
|
||||
|
@@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector import fetch_unique_document_sources
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.prompts.constants import SOURCES_KEY
|
||||
@@ -145,13 +146,18 @@ def extract_source_filter(
|
||||
logger.warning("LLM failed to provide a valid Source Filter output")
|
||||
return None
|
||||
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
except GenAIDisabledException:
|
||||
return None
|
||||
|
||||
valid_sources = fetch_unique_document_sources(db_session)
|
||||
if not valid_sources:
|
||||
return None
|
||||
|
||||
messages = _get_source_filter_messages(query=query, valid_sources=valid_sources)
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
|
||||
model_output = get_default_llm().invoke(filled_llm_prompt)
|
||||
model_output = llm.invoke(filled_llm_prompt)
|
||||
logger.debug(model_output)
|
||||
|
||||
return _extract_source_filters_from_llm_out(model_output)
|
||||
|
@@ -5,6 +5,7 @@ from datetime import timezone
|
||||
|
||||
from dateutil.parser import parse
|
||||
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.prompts.filter_extration import TIME_FILTER_PROMPT
|
||||
@@ -145,9 +146,14 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
|
||||
|
||||
return None, False
|
||||
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
except GenAIDisabledException:
|
||||
return None, False
|
||||
|
||||
messages = _get_time_filter_messages(query)
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
|
||||
model_output = get_default_llm().invoke(filled_llm_prompt)
|
||||
model_output = llm.invoke(filled_llm_prompt)
|
||||
logger.debug(model_output)
|
||||
|
||||
return _extract_time_filter_from_llm_out(model_output)
|
||||
|
@@ -9,7 +9,6 @@ from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
|
||||
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
@@ -22,6 +21,7 @@ from danswer.db.models import User
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.dynamic_configs import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.utils import get_gen_ai_api_key
|
||||
from danswer.llm.utils import test_llm
|
||||
@@ -100,9 +100,6 @@ def document_hidden_update(
|
||||
def validate_existing_genai_api_key(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> None:
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return
|
||||
|
||||
# Only validate every so often
|
||||
check_key_time = "genai_api_key_last_check_time"
|
||||
kv_store = get_dynamic_config_store()
|
||||
@@ -120,7 +117,11 @@ def validate_existing_genai_api_key(
|
||||
|
||||
genai_api_key = get_gen_ai_api_key()
|
||||
|
||||
llm = get_default_llm(api_key=genai_api_key, timeout=10)
|
||||
try:
|
||||
llm = get_default_llm(api_key=genai_api_key, timeout=10)
|
||||
except GenAIDisabledException:
|
||||
return
|
||||
|
||||
is_valid = test_llm(llm)
|
||||
|
||||
if not is_valid:
|
||||
@@ -165,6 +166,9 @@ def store_genai_api_key(
|
||||
if not is_valid:
|
||||
raise HTTPException(400, "Invalid API key provided")
|
||||
|
||||
get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
|
||||
except GenAIDisabledException:
|
||||
# If Disable Generative AI is set, no need to verify, just store the key for later use
|
||||
get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
|
@@ -40,6 +40,7 @@ services:
|
||||
- DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
|
||||
- DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
|
||||
- DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
|
||||
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
|
||||
# Query Options
|
||||
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-} # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
|
||||
- HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
|
||||
@@ -96,6 +97,7 @@ services:
|
||||
- DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
|
||||
- DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
|
||||
- DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
|
||||
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
|
||||
# Query Options
|
||||
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-} # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
|
||||
- HYBRID_ALPHA=${HYBRID_ALPHA:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
|
||||
|
Reference in New Issue
Block a user