from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy.orm import Session

from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
from onyx.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from onyx.context.search.models import InferenceChunk
from onyx.db.models import Persona
from onyx.db.prompts import get_default_prompt
from onyx.db.search_settings import get_multilingual_expansion
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.llm.interfaces import LLMConfig
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_number_of_tokens
from onyx.llm.utils import get_max_input_tokens
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from onyx.prompts.constants import DEFAULT_IGNORE_STATEMENT
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import build_complete_context_str
from onyx.prompts.prompt_utils import build_task_prompt_reminders
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
from onyx.prompts.token_counts import (
    CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
)
from onyx.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from onyx.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from onyx.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from onyx.utils.logger import setup_logger

logger = setup_logger()


def get_prompt_tokens(prompt_config: PromptConfig) -> int:
    # Note: currently only default prompts support datetime awareness, not custom prompts
    return (
        check_number_of_tokens(prompt_config.system_prompt)
        + check_number_of_tokens(prompt_config.task_prompt)
        + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
        + CITATION_STATEMENT_TOKEN_CNT
        + CITATION_REMINDER_TOKEN_CNT
        + (LANGUAGE_HINT_TOKEN_CNT if get_multilingual_expansion() else 0)
        + (ADDITIONAL_INFO_TOKEN_CNT if prompt_config.datetime_aware else 0)
    )


# buffer just to be safe so that we don't overflow the token limit due to
# a small miscalculation
_MISC_BUFFER = 40


def compute_max_document_tokens(
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
    actual_user_input: str | None = None,
    tool_token_count: int = 0,
    max_llm_token_override: int | None = None,
) -> int:
    """Estimates the number of tokens available for context documents. Formula is roughly:

    (
        model_context_window - reserved_output_tokens - prompt_tokens
        - (actual_user_input OR reserved_user_message_tokens)
        - buffer (just to be safe)
    )

    The actual_user_input is used at query time. If we are calculating this before
    knowing the exact input (e.g. if we're trying to determine whether the user should
    be able to select another document) then we just set an arbitrary "upper bound".
    """
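    # Worked example (hypothetical round numbers, for illustration only): with an
    # 8192-token input window, ~500 prompt tokens, ~300 user-input tokens, and no
    # tool tokens, documents get 8192 - 500 - 300 - 0 - 40 (_MISC_BUFFER) = 7352 tokens.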
""" # if we can't find a number of tokens, just assume some common default max_input_tokens = ( max_llm_token_override if max_llm_token_override else get_max_input_tokens( model_name=llm_config.model_name, model_provider=llm_config.model_provider ) ) prompt_tokens = get_prompt_tokens(prompt_config) user_input_tokens = ( check_number_of_tokens(actual_user_input) if actual_user_input is not None else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS ) return ( max_input_tokens - prompt_tokens - user_input_tokens - tool_token_count - _MISC_BUFFER ) def compute_max_document_tokens_for_persona( db_session: Session, persona: Persona, actual_user_input: str | None = None, max_llm_token_override: int | None = None, ) -> int: prompt = persona.prompts[0] if persona.prompts else get_default_prompt(db_session) return compute_max_document_tokens( prompt_config=PromptConfig.from_model(prompt), llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config, actual_user_input=actual_user_input, max_llm_token_override=max_llm_token_override, ) def compute_max_llm_input_tokens(llm_config: LLMConfig) -> int: """Maximum tokens allows in the input to the LLM (of any type).""" input_tokens = get_max_input_tokens( model_name=llm_config.model_name, model_provider=llm_config.model_provider ) return input_tokens - _MISC_BUFFER def build_citations_system_message( prompt_config: PromptConfig, ) -> SystemMessage: system_prompt = prompt_config.system_prompt.strip() if prompt_config.include_citations: system_prompt += REQUIRE_CITATION_STATEMENT tag_handled_prompt = handle_onyx_date_awareness( system_prompt, prompt_config, add_additional_info_if_no_tag=True ) return SystemMessage(content=tag_handled_prompt) def build_citations_user_message( message: HumanMessage, prompt_config: PromptConfig, context_docs: list[LlmDoc] | list[InferenceChunk], all_doc_useful: bool, history_message: str = "", context_type: str = "context documents", ) -> HumanMessage: multilingual_expansion = get_multilingual_expansion() task_prompt_with_reminder = build_task_prompt_reminders( prompt=prompt_config, use_language_hint=bool(multilingual_expansion) ) history_block = ( HISTORY_BLOCK.format(history_str=history_message) if history_message else "" ) query, img_urls = message_to_prompt_and_imgs(message) if context_docs: context_docs_str = build_complete_context_str(context_docs) optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT user_prompt = CITATIONS_PROMPT.format( context_type=context_type, optional_ignore_statement=optional_ignore, context_docs_str=context_docs_str, task_prompt=task_prompt_with_reminder, user_query=query, history_block=history_block, ) else: # if no context docs provided, assume we're in the tool calling flow user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format( context_type=context_type, task_prompt=task_prompt_with_reminder, user_query=query, history_block=history_block, ) user_prompt = user_prompt.strip() user_msg = HumanMessage( content=build_content_with_imgs(user_prompt, img_urls=img_urls) if img_urls else user_prompt ) return user_msg