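"""Let the LLM decide whether a chat query needs a document search.

Both checks default to running the search and only skip it when the model
explicitly answers with the no-search/skip keyword; the
DISABLE_LLM_CHOOSE_SEARCH config short-circuits both to always search.
"""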
from langchain.schema import BaseMessage
from langchain.schema import HumanMessage
from langchain.schema import SystemMessage

from onyx.chat.chat_utils import combine_message_chain
from onyx.chat.prompt_builder.utils import translate_onyx_msg_to_langchain
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from onyx.db.models import ChatMessage
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import dict_based_prompt_to_langchain_prompt
from onyx.llm.utils import message_to_string
from onyx.prompts.chat_prompts import AGGRESSIVE_SEARCH_TEMPLATE
from onyx.prompts.chat_prompts import NO_SEARCH
from onyx.prompts.chat_prompts import REQUIRE_SEARCH_HINT
from onyx.prompts.chat_prompts import REQUIRE_SEARCH_SYSTEM_MSG
from onyx.prompts.chat_prompts import SKIP_SEARCH
from onyx.utils.logger import setup_logger

logger = setup_logger()

def check_if_need_search_multi_message(
    query_message: ChatMessage,
    history: list[ChatMessage],
    llm: LLM,
) -> bool:
    """Ask the LLM whether the latest message of a multi-turn chat needs a search.

    Returns True (run the search) unless the model explicitly answers with the
    no-search keyword.
    """
    # Retrieve on start or when choosing is globally disabled
    if not history or DISABLE_LLM_CHOOSE_SEARCH:
        return True

    # Build the full conversation as LangChain messages, framed by the
    # require-search system prompt.
    prompt_msgs: list[BaseMessage] = [SystemMessage(content=REQUIRE_SEARCH_SYSTEM_MSG)]
    prompt_msgs.extend([translate_onyx_msg_to_langchain(msg) for msg in history])

    last_query = query_message.message

    prompt_msgs.append(HumanMessage(content=f"{last_query}\n\n{REQUIRE_SEARCH_HINT}"))

    model_out = message_to_string(llm.invoke(prompt_msgs))

    # Skip the search only if the output contains the first token of the
    # NO_SEARCH keyword (with a trailing space, case-insensitive).
    if (NO_SEARCH.split()[0] + " ").lower() in model_out.lower():
        return False

    return True


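# A minimal usage sketch (not part of the original module): how a caller might
# gate retrieval on this check. `latest_message`, `prior_messages`, and `llm`
# are hypothetical stand-ins for a ChatMessage, its history, and an LLM
# instance.
#
#   if check_if_need_search_multi_message(
#       query_message=latest_message, history=prior_messages, llm=llm
#   ):
#       ...  # run the retrieval pipeline

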
def check_if_need_search(
    query: str,
    history: list[PreviousMessage],
    llm: LLM,
) -> bool:
    """Single-prompt variant: decide via one templated user message whether
    the query (with the flattened chat history) requires a search.
    """

    def _get_search_messages(
        question: str,
        history_str: str,
    ) -> list[dict[str, str]]:
        messages = [
            {
                "role": "user",
                "content": AGGRESSIVE_SEARCH_TEMPLATE.format(
                    final_query=question, chat_history=history_str
                ).strip(),
            },
        ]

        return messages

    # Choosing is globally disabled, use search
    if DISABLE_LLM_CHOOSE_SEARCH:
        return True

    # Flatten prior turns into one string, truncated to the configured
    # history token limit.
    history_str = combine_message_chain(
        messages=history, token_limit=GEN_AI_HISTORY_CUTOFF
    )

    prompt_msgs = _get_search_messages(question=query, history_str=history_str)

    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
    require_search_output = message_to_string(llm.invoke(filled_llm_prompt))

    logger.debug(f"Run search prediction: {require_search_output}")

    # Search unless the output contains the first token of the SKIP_SEARCH
    # keyword (case-insensitive).
    return (SKIP_SEARCH.split()[0]).lower() not in require_search_output.lower()
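

# A minimal usage sketch (not part of the original module): the single-string
# variant works the same way, with `user_query` (str), `previous_messages`
# (list[PreviousMessage]), and `llm` assumed for illustration.
#
#   if check_if_need_search(query=user_query, history=previous_messages, llm=llm):
#       ...  # run the retrieval pipeline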