2025-02-03 20:07:57 -08:00

85 lines
2.7 KiB
Python

from langchain.schema import BaseMessage
from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from onyx.chat.chat_utils import combine_message_chain
from onyx.chat.prompt_builder.utils import translate_onyx_msg_to_langchain
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from onyx.db.models import ChatMessage
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import dict_based_prompt_to_langchain_prompt
from onyx.llm.utils import message_to_string
from onyx.prompts.chat_prompts import AGGRESSIVE_SEARCH_TEMPLATE
from onyx.prompts.chat_prompts import NO_SEARCH
from onyx.prompts.chat_prompts import REQUIRE_SEARCH_HINT
from onyx.prompts.chat_prompts import REQUIRE_SEARCH_SYSTEM_MSG
from onyx.prompts.chat_prompts import SKIP_SEARCH
from onyx.utils.logger import setup_logger
# Module-level logger; used below to emit debug output of the LLM's search decision.
logger = setup_logger()
def check_if_need_search_multi_message(
    query_message: ChatMessage,
    history: list[ChatMessage],
    llm: LLM,
) -> bool:
    """Decide whether the latest chat message should trigger a search.

    The LLM is shown a system message (REQUIRE_SEARCH_SYSTEM_MSG), the prior
    chat messages translated to LangChain messages, and finally the new query
    text followed by REQUIRE_SEARCH_HINT.

    Args:
        query_message: The newest message; only its ``message`` text is read.
        history: Earlier messages of the conversation, in the order they are
            passed to the LLM.
        llm: Model that is invoked with the assembled prompt.

    Returns:
        True when a search should be run. This is always the case when
        ``history`` is empty or DISABLE_LLM_CHOOSE_SEARCH is set (the LLM is
        not called then). Otherwise False is returned only if the model output
        contains, case-insensitively, the first whitespace-separated token of
        NO_SEARCH followed by a single space.
    """
    # First turn of a conversation, or LLM choice globally turned off: search.
    if not history or DISABLE_LLM_CHOOSE_SEARCH:
        return True

    # System instruction, then the translated history, then the new query
    # with the hint appended (elements are evaluated in this order).
    conversation: list[BaseMessage] = [
        SystemMessage(content=REQUIRE_SEARCH_SYSTEM_MSG),
        *(translate_onyx_msg_to_langchain(past_msg) for past_msg in history),
        HumanMessage(content=f"{query_message.message}\n\n{REQUIRE_SEARCH_HINT}"),
    ]
    llm_reply = message_to_string(llm.invoke(conversation))

    # The trailing space is part of the marker that is searched for.
    no_search_marker = (NO_SEARCH.split()[0] + " ").lower()
    return no_search_marker not in llm_reply.lower()
def check_if_need_search(
    query: str,
    history: list[PreviousMessage],
    llm: LLM,
) -> bool:
    """Ask the LLM whether ``query`` needs a search, given the chat history.

    The history is flattened with ``combine_message_chain`` (limited to
    GEN_AI_HISTORY_CUTOFF tokens), inserted together with the query into
    AGGRESSIVE_SEARCH_TEMPLATE, and sent to the model as a single user message.

    Args:
        query: The question to evaluate.
        history: Previous messages of the conversation.
        llm: Model that is invoked with the filled-in prompt.

    Returns:
        True when a search should be run. This is always the case when
        DISABLE_LLM_CHOOSE_SEARCH is set (the LLM is not called then).
        Otherwise True is returned unless the model output contains,
        case-insensitively, the first whitespace-separated token of
        SKIP_SEARCH.
    """
    # Choosing is globally disabled, use search
    if DISABLE_LLM_CHOOSE_SEARCH:
        return True

    chat_history_text = combine_message_chain(
        messages=history, token_limit=GEN_AI_HISTORY_CUTOFF
    )

    # Single user-role message holding the stripped, filled-in template.
    user_prompt = AGGRESSIVE_SEARCH_TEMPLATE.format(
        final_query=query, chat_history=chat_history_text
    ).strip()
    llm_prompt = dict_based_prompt_to_langchain_prompt(
        [{"role": "user", "content": user_prompt}]
    )

    prediction = message_to_string(llm.invoke(llm_prompt))
    logger.debug(f"Run search prediction: {prediction}")

    skip_keyword = SKIP_SEARCH.split()[0].lower()
    return skip_keyword not in prediction.lower()