from collections.abc import Callable

from onyx.chat.chat_utils import combine_message_chain
from onyx.configs.chat_configs import DISABLE_LLM_QUERY_REPHRASE
from onyx.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from onyx.db.models import ChatMessage
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_default_llms
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import dict_based_prompt_to_langchain_prompt
from onyx.llm.utils import message_to_string
from onyx.prompts.chat_prompts import HISTORY_QUERY_REPHRASE
from onyx.prompts.miscellaneous_prompts import LANGUAGE_REPHRASE_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import count_punctuation
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel

logger = setup_logger()


def llm_multilingual_query_expansion(query: str, language: str) -> str:
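    """Use the fast default LLM to rephrase the query into the target language.

    Returns the original query unchanged if Gen AI is disabled.
    """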
    def _get_rephrase_messages() -> list[dict[str, str]]:
        messages = [
            {
                "role": "user",
                "content": LANGUAGE_REPHRASE_PROMPT.format(
                    query=query, target_language=language
                ),
            },
        ]
        return messages

    try:
        _, fast_llm = get_default_llms(timeout=5)
    except GenAIDisabledException:
        logger.warning(
            "Unable to perform multilingual query expansion, Gen AI disabled"
        )
        return query

    messages = _get_rephrase_messages()
    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
    model_output = message_to_string(fast_llm.invoke(filled_llm_prompt))
    logger.debug(model_output)
    return model_output


def multilingual_query_expansion(
    query: str,
    expansion_languages: list[str],
    use_threads: bool = True,
) -> list[str]:
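    """Rephrase the query into each of the given languages.

    With use_threads=True, the per-language LLM calls run in parallel.
    """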
    languages = [language.strip() for language in expansion_languages]
    if use_threads:
        functions_with_args: list[tuple[Callable, tuple]] = [
            (llm_multilingual_query_expansion, (query, language))
            for language in languages
        ]
        query_rephrases = run_functions_tuples_in_parallel(functions_with_args)
        return query_rephrases
    else:
        query_rephrases = [
            llm_multilingual_query_expansion(query, language) for language in languages
        ]
        return query_rephrases


def get_contextual_rephrase_messages(
    question: str,
    history_str: str,
    prompt_template: str = HISTORY_QUERY_REPHRASE,
) -> list[dict[str, str]]:
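    """Build the single user message that asks the LLM to rephrase the
    question in light of the chat history."""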
    messages = [
        {
            "role": "user",
            "content": prompt_template.format(
                question=question, chat_history=history_str
            ),
        },
    ]
    return messages


def history_based_query_rephrase(
    query: str,
    history: list[ChatMessage] | list[PreviousMessage],
    llm: LLM,
    size_heuristic: int = 200,
    punctuation_heuristic: int = 10,
    skip_first_rephrase: bool = True,
    prompt_template: str = HISTORY_QUERY_REPHRASE,
) -> str:
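    """Rephrase the latest query using the chat history for context.

    Several heuristics skip the LLM call and return the query unchanged; see
    the comments below.
    """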
    # Globally disabled, just use the exact user query
    if DISABLE_LLM_QUERY_REPHRASE:
        return query

    # For some use cases, the first query should be untouched. Later queries must be
    # rephrased because they need context, but the first query has none.
    if skip_first_rephrase and not history:
        return query

    # If the query is very long, assume it is a copy-paste that the user wants matched
    # exactly, or at least very closely, so do not rephrase it.
    if len(query) >= size_heuristic:
        return query

    # An unusually high number of punctuation characters suggests the query is not
    # natural language, so do not rephrase it.
    if count_punctuation(query) >= punctuation_heuristic:
        return query

    history_str = combine_message_chain(
        messages=history, token_limit=GEN_AI_HISTORY_CUTOFF
    )

    prompt_msgs = get_contextual_rephrase_messages(
        question=query, history_str=history_str, prompt_template=prompt_template
    )

    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
    rephrased_query = message_to_string(llm.invoke(filled_llm_prompt))
    logger.debug(f"Rephrased combined query: {rephrased_query}")
    return rephrased_query


def thread_based_query_rephrase(
    user_query: str,
    history_str: str,
    llm: LLM | None = None,
    size_heuristic: int = 200,
    punctuation_heuristic: int = 10,
) -> str:
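    """Rephrase a query against an already-combined history string.

    Applies the same size and punctuation heuristics as
    history_based_query_rephrase; falls back to the default LLM when none is
    provided.
    """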
    # Without history there is no context to fold in, so keep the query as-is
    if not history_str:
        return user_query

    # Very long queries are likely copy-pastes that should be matched closely,
    # so do not rephrase them
    if len(user_query) >= size_heuristic:
        return user_query

    # Heavy punctuation suggests the query is not natural language, so skip rephrasing
    if count_punctuation(user_query) >= punctuation_heuristic:
        return user_query

    if llm is None:
        try:
            llm, _ = get_default_llms()
        except GenAIDisabledException:
            # If Generative AI is turned off, just return the original query
            return user_query

    prompt_msgs = get_contextual_rephrase_messages(
        question=user_query, history_str=history_str
    )

    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
    rephrased_query = message_to_string(llm.invoke(filled_llm_prompt))
    logger.debug(f"Rephrased combined query: {rephrased_query}")
    return rephrased_query
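

# A minimal usage sketch (assumes LLM providers are configured so that
# get_default_llms() succeeds; the example query and languages are made up):
#
#     rephrases = multilingual_query_expansion(
#         query="how do I reset my password?",
#         expansion_languages=["French", "German"],
#     )
#     # -> one rephrased query per language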