from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy.orm import Session

from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
from onyx.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from onyx.context.search.models import InferenceChunk
from onyx.db.models import Persona
from onyx.db.prompts import get_default_prompt
from onyx.db.search_settings import get_multilingual_expansion
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.llm.interfaces import LLMConfig
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_number_of_tokens
from onyx.llm.utils import get_max_input_tokens
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from onyx.prompts.constants import DEFAULT_IGNORE_STATEMENT
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import build_complete_context_str
from onyx.prompts.prompt_utils import build_task_prompt_reminders
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
from onyx.prompts.token_counts import (
    CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
)
from onyx.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from onyx.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from onyx.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from onyx.utils.logger import setup_logger

logger = setup_logger()


def get_prompt_tokens(prompt_config: PromptConfig) -> int:
    # Note: currently only default prompts support datetime awareness; custom prompts do not
    return (
        check_number_of_tokens(prompt_config.system_prompt)
        + check_number_of_tokens(prompt_config.task_prompt)
        + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
        + CITATION_STATEMENT_TOKEN_CNT
        + CITATION_REMINDER_TOKEN_CNT
        + (LANGUAGE_HINT_TOKEN_CNT if get_multilingual_expansion() else 0)
        + (ADDITIONAL_INFO_TOKEN_CNT if prompt_config.datetime_aware else 0)
    )
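
# Illustrative sketch (added for clarity, not part of the original source): the
# estimate is the tokenized system and task prompts plus fixed overhead constants.
# With hypothetical values of a 50-token system prompt, a 30-token task prompt,
# multilingual expansion disabled, and datetime awareness enabled, this returns:
#     50 + 30
#     + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
#     + CITATION_STATEMENT_TOKEN_CNT
#     + CITATION_REMINDER_TOKEN_CNT
#     + ADDITIONAL_INFO_TOKEN_CNT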

# buffer just to be safe so that we don't overflow the token limit due to
# a small miscalculation
_MISC_BUFFER = 40


def compute_max_document_tokens(
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
    actual_user_input: str | None = None,
    tool_token_count: int = 0,
    max_llm_token_override: int | None = None,
) -> int:
    """Estimates the number of tokens available for context documents. Formula is roughly:

    (
        model_context_window - reserved_output_tokens - prompt_tokens
        - (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe)
    )

    The actual_user_input is used at query time. If we are calculating this before knowing
    the exact input (e.g. if we're trying to determine whether the user should be able to
    select another document), then we just set an arbitrary "upper bound".
    """
    # if we can't find a number of tokens, just assume some common default
    max_input_tokens = (
        max_llm_token_override
        if max_llm_token_override
        else get_max_input_tokens(
            model_name=llm_config.model_name, model_provider=llm_config.model_provider
        )
    )
    prompt_tokens = get_prompt_tokens(prompt_config)

    user_input_tokens = (
        check_number_of_tokens(actual_user_input)
        if actual_user_input is not None
        else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
    )

    return (
        max_input_tokens
        - prompt_tokens
        - user_input_tokens
        - tool_token_count
        - _MISC_BUFFER
    )
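
# Worked example (hypothetical numbers, added for illustration): for a model with
# a 128,000-token input limit, roughly 1,000 prompt tokens, a reserved user-input
# budget of, say, 512 tokens, and no tool tokens, the document budget would be
#     128_000 - 1_000 - 512 - 0 - 40 == 126_448
# tokens available for context documents.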


def compute_max_document_tokens_for_persona(
    db_session: Session,
    persona: Persona,
    actual_user_input: str | None = None,
    max_llm_token_override: int | None = None,
) -> int:
    prompt = persona.prompts[0] if persona.prompts else get_default_prompt(db_session)
    return compute_max_document_tokens(
        prompt_config=PromptConfig.from_model(prompt),
        llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config,
        actual_user_input=actual_user_input,
        max_llm_token_override=max_llm_token_override,
    )


def compute_max_llm_input_tokens(llm_config: LLMConfig) -> int:
    """Maximum tokens allowed in the input to the LLM (of any type)."""

    input_tokens = get_max_input_tokens(
        model_name=llm_config.model_name, model_provider=llm_config.model_provider
    )
    return input_tokens - _MISC_BUFFER


def build_citations_system_message(
    prompt_config: PromptConfig,
) -> SystemMessage:
    system_prompt = prompt_config.system_prompt.strip()
    if prompt_config.include_citations:
        system_prompt += REQUIRE_CITATION_STATEMENT
    tag_handled_prompt = handle_onyx_date_awareness(
        system_prompt, prompt_config, add_additional_info_if_no_tag=True
    )

    return SystemMessage(content=tag_handled_prompt)
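
# Usage sketch (hypothetical, added for illustration): given a PromptConfig with
# include_citations set, the builder appends the citation requirement and resolves
# any date-awareness tags before wrapping the result in a SystemMessage:
#     system_msg = build_citations_system_message(prompt_config)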


def build_citations_user_message(
    message: HumanMessage,
    prompt_config: PromptConfig,
    context_docs: list[LlmDoc] | list[InferenceChunk],
    all_doc_useful: bool,
    history_message: str = "",
    context_type: str = "context documents",
) -> HumanMessage:
    multilingual_expansion = get_multilingual_expansion()
    task_prompt_with_reminder = build_task_prompt_reminders(
        prompt=prompt_config, use_language_hint=bool(multilingual_expansion)
    )

    history_block = (
        HISTORY_BLOCK.format(history_str=history_message) if history_message else ""
    )
    query, img_urls = message_to_prompt_and_imgs(message)

    if context_docs:
        context_docs_str = build_complete_context_str(context_docs)
        optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT

        user_prompt = CITATIONS_PROMPT.format(
            context_type=context_type,
            optional_ignore_statement=optional_ignore,
            context_docs_str=context_docs_str,
            task_prompt=task_prompt_with_reminder,
            user_query=query,
            history_block=history_block,
        )
    else:
        # if no context docs provided, assume we're in the tool calling flow
        user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
            context_type=context_type,
            task_prompt=task_prompt_with_reminder,
            user_query=query,
            history_block=history_block,
        )

    user_prompt = user_prompt.strip()
    user_msg = HumanMessage(
        content=build_content_with_imgs(user_prompt, img_urls=img_urls)
        if img_urls
        else user_prompt
    )

    return user_msg
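
# Usage sketch (hypothetical values, added for illustration): assembling the final
# user message from retrieved chunks before sending it to the LLM. The variable
# names below are assumptions, not part of this module:
#     user_msg = build_citations_user_message(
#         message=HumanMessage(content="What is our refund policy?"),
#         prompt_config=prompt_config,
#         context_docs=retrieved_chunks,
#         all_doc_useful=False,
#     )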