welcome to onyx
backend/onyx/chat/prompt_builder/build.py (Normal file, 159 lines)
@@ -0,0 +1,159 @@
from collections.abc import Callable
from typing import cast

from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic.v1 import BaseModel as BaseModel__v1

from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLMConfig
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_message_tokens
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.prompt_utils import add_date_time_to_prompt
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool


def default_build_system_message(
    prompt_config: PromptConfig,
) -> SystemMessage | None:
    system_prompt = prompt_config.system_prompt.strip()
    if prompt_config.datetime_aware:
        system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)

    if not system_prompt:
        return None

    system_msg = SystemMessage(content=system_prompt)

    return system_msg


def default_build_user_message(
    user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
) -> HumanMessage:
    user_prompt = (
        CHAT_USER_CONTEXT_FREE_PROMPT.format(
            task_prompt=prompt_config.task_prompt, user_query=user_query
        )
        if prompt_config.task_prompt
        else user_query
    )
    user_prompt = user_prompt.strip()
    user_msg = HumanMessage(
        content=build_content_with_imgs(user_prompt, files) if files else user_prompt
    )
    return user_msg


class AnswerPromptBuilder:
    def __init__(
        self,
        user_message: HumanMessage,
        message_history: list[PreviousMessage],
        llm_config: LLMConfig,
        raw_user_text: str,
        single_message_history: str | None = None,
    ) -> None:
        self.max_tokens = compute_max_llm_input_tokens(llm_config)

        llm_tokenizer = get_tokenizer(
            provider_type=llm_config.model_provider,
            model_name=llm_config.model_name,
        )
        self.llm_tokenizer_encode_func = cast(
            Callable[[str], list[int]], llm_tokenizer.encode
        )

        self.raw_message_history = message_history
        (
            self.message_history,
            self.history_token_cnts,
        ) = translate_history_to_basemessages(message_history)

        # for cases like the QA flow where we want to condense the chat history
        # into a single message rather than a sequence of User / Assistant messages
        self.single_message_history = single_message_history

        self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
        self.user_message_and_token_cnt = (
            user_message,
            check_message_tokens(user_message, self.llm_tokenizer_encode_func),
        )

        self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []

        self.raw_user_message = raw_user_text

    def update_system_prompt(self, system_message: SystemMessage | None) -> None:
        if not system_message:
            self.system_message_and_token_cnt = None
            return

        self.system_message_and_token_cnt = (
            system_message,
            check_message_tokens(system_message, self.llm_tokenizer_encode_func),
        )

    def update_user_prompt(self, user_message: HumanMessage) -> None:
        self.user_message_and_token_cnt = (
            user_message,
            check_message_tokens(user_message, self.llm_tokenizer_encode_func),
        )

    def append_message(self, message: BaseMessage) -> None:
        """Append a new message to the message history."""
        token_count = check_message_tokens(message, self.llm_tokenizer_encode_func)
        self.new_messages_and_token_cnts.append((message, token_count))

    def get_user_message_content(self) -> str:
        query, _ = message_to_prompt_and_imgs(self.user_message_and_token_cnt[0])
        return query

    def build(self) -> list[BaseMessage]:
        if not self.user_message_and_token_cnt:
            raise ValueError("User message must be set before building prompt")

        final_messages_with_tokens: list[tuple[BaseMessage, int]] = []
        if self.system_message_and_token_cnt:
            final_messages_with_tokens.append(self.system_message_and_token_cnt)

        final_messages_with_tokens.extend(
            [
                (self.message_history[i], self.history_token_cnts[i])
                for i in range(len(self.message_history))
            ]
        )

        final_messages_with_tokens.append(self.user_message_and_token_cnt)

        if self.new_messages_and_token_cnts:
            final_messages_with_tokens.extend(self.new_messages_and_token_cnts)

        return drop_messages_history_overflow(
            final_messages_with_tokens, self.max_tokens
        )


class LLMCall(BaseModel__v1):
    prompt_builder: AnswerPromptBuilder
    tools: list[Tool]
    force_use_tool: ForceUseTool
    files: list[InMemoryChatFile]
    tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
    using_tool_calling_llm: bool

    class Config:
        arbitrary_types_allowed = True
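A minimal usage sketch (not part of this commit) of how the pieces in build.py fit together; the wrapper function build_prompt_for_llm and the idea that the caller supplies prompt_config, llm_config, and history are illustrative assumptions:

# Illustrative sketch only: wiring the helpers above into a single prompt build.
# `prompt_config`, `llm_config`, and `history` are assumed to come from the
# chat pipeline; only signatures shown in this commit are used.
from langchain_core.messages import BaseMessage

from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.build import (
    AnswerPromptBuilder,
    default_build_system_message,
    default_build_user_message,
)
from onyx.llm.interfaces import LLMConfig
from onyx.llm.models import PreviousMessage


def build_prompt_for_llm(
    user_query: str,
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
    history: list[PreviousMessage],
) -> list[BaseMessage]:
    user_msg = default_build_user_message(user_query, prompt_config)
    builder = AnswerPromptBuilder(
        user_message=user_msg,
        message_history=history,
        llm_config=llm_config,
        raw_user_text=user_query,
    )
    # The system prompt is optional; update_system_prompt(None) simply clears it.
    builder.update_system_prompt(default_build_system_message(prompt_config))
    # build() orders system -> history -> user -> appended messages and drops
    # older history if the token budget (max_tokens) would be exceeded.
    return builder.build()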
backend/onyx/chat/prompt_builder/citations_prompt.py (Normal file, 179 lines)
@@ -0,0 +1,179 @@
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
from onyx.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from onyx.context.search.models import InferenceChunk
from onyx.db.models import Persona
from onyx.db.persona import get_default_prompt__read_only
from onyx.db.search_settings import get_multilingual_expansion
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.llm.interfaces import LLMConfig
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_number_of_tokens
from onyx.llm.utils import get_max_input_tokens
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from onyx.prompts.constants import DEFAULT_IGNORE_STATEMENT
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import add_date_time_to_prompt
from onyx.prompts.prompt_utils import build_complete_context_str
from onyx.prompts.prompt_utils import build_task_prompt_reminders
from onyx.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
from onyx.prompts.token_counts import (
    CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
)
from onyx.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from onyx.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from onyx.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from onyx.utils.logger import setup_logger


logger = setup_logger()


def get_prompt_tokens(prompt_config: PromptConfig) -> int:
    # Note: currently only default prompts (not custom prompts) can be datetime aware
    return (
        check_number_of_tokens(prompt_config.system_prompt)
        + check_number_of_tokens(prompt_config.task_prompt)
        + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
        + CITATION_STATEMENT_TOKEN_CNT
        + CITATION_REMINDER_TOKEN_CNT
        + (LANGUAGE_HINT_TOKEN_CNT if get_multilingual_expansion() else 0)
        + (ADDITIONAL_INFO_TOKEN_CNT if prompt_config.datetime_aware else 0)
    )


# buffer just to be safe so that we don't overflow the token limit due to
# a small miscalculation
_MISC_BUFFER = 40


def compute_max_document_tokens(
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
    actual_user_input: str | None = None,
    tool_token_count: int = 0,
    max_llm_token_override: int | None = None,
) -> int:
    """Estimates the number of tokens available for context documents. Formula is roughly:

    (
        model_context_window - reserved_output_tokens - prompt_tokens
        - (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe)
    )

    The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g.
    if we're trying to determine if the user should be able to select another document) then we just set an
    arbitrary "upper bound".
    """
    # if we can't find a number of tokens, just assume some common default
    max_input_tokens = (
        max_llm_token_override
        if max_llm_token_override
        else get_max_input_tokens(
            model_name=llm_config.model_name, model_provider=llm_config.model_provider
        )
    )
    prompt_tokens = get_prompt_tokens(prompt_config)

    user_input_tokens = (
        check_number_of_tokens(actual_user_input)
        if actual_user_input is not None
        else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
    )

    return (
        max_input_tokens
        - prompt_tokens
        - user_input_tokens
        - tool_token_count
        - _MISC_BUFFER
    )


def compute_max_document_tokens_for_persona(
    persona: Persona,
    actual_user_input: str | None = None,
    max_llm_token_override: int | None = None,
) -> int:
    prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only()
    return compute_max_document_tokens(
        prompt_config=PromptConfig.from_model(prompt),
        llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config,
        actual_user_input=actual_user_input,
        max_llm_token_override=max_llm_token_override,
    )


def compute_max_llm_input_tokens(llm_config: LLMConfig) -> int:
    """Maximum tokens allowed in the input to the LLM (of any type)."""

    input_tokens = get_max_input_tokens(
        model_name=llm_config.model_name, model_provider=llm_config.model_provider
    )
    return input_tokens - _MISC_BUFFER


def build_citations_system_message(
    prompt_config: PromptConfig,
) -> SystemMessage:
    system_prompt = prompt_config.system_prompt.strip()
    if prompt_config.include_citations:
        system_prompt += REQUIRE_CITATION_STATEMENT
    if prompt_config.datetime_aware:
        system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)

    return SystemMessage(content=system_prompt)


def build_citations_user_message(
    message: HumanMessage,
    prompt_config: PromptConfig,
    context_docs: list[LlmDoc] | list[InferenceChunk],
    all_doc_useful: bool,
    history_message: str = "",
) -> HumanMessage:
    multilingual_expansion = get_multilingual_expansion()
    task_prompt_with_reminder = build_task_prompt_reminders(
        prompt=prompt_config, use_language_hint=bool(multilingual_expansion)
    )

    history_block = (
        HISTORY_BLOCK.format(history_str=history_message) + "\n"
        if history_message
        else ""
    )
    query, img_urls = message_to_prompt_and_imgs(message)

    if context_docs:
        context_docs_str = build_complete_context_str(context_docs)
        optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT

        user_prompt = CITATIONS_PROMPT.format(
            optional_ignore_statement=optional_ignore,
            context_docs_str=context_docs_str,
            task_prompt=task_prompt_with_reminder,
            user_query=query,
            history_block=history_block,
        )
    else:
        # if no context docs provided, assume we're in the tool calling flow
        user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
            task_prompt=task_prompt_with_reminder,
            user_query=query,
            history_block=history_block,
        )

    user_prompt = user_prompt.strip()
    user_msg = HumanMessage(
        content=build_content_with_imgs(user_prompt, img_urls=img_urls)
        if img_urls
        else user_prompt
    )

    return user_msg
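To make the token-budget formula in compute_max_document_tokens concrete, here is a hedged worked example; every number below is made up for illustration and does not reflect any real model limit or Onyx constant:

# Worked example of the document-token budget (illustrative numbers only;
# real values come from get_max_input_tokens and the token-count constants).
max_input_tokens = 16_000   # model context window minus reserved output tokens
prompt_tokens = 1_200       # system + task prompt + citation/reminder overhead
user_input_tokens = 512     # actual query, or the reserved upper bound
tool_token_count = 300      # tokens consumed by tool definitions / results
_MISC_BUFFER = 40           # small safety margin

max_document_tokens = (
    max_input_tokens - prompt_tokens - user_input_tokens - tool_token_count - _MISC_BUFFER
)
print(max_document_tokens)  # 13948 tokens left for retrieved context documents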
backend/onyx/chat/prompt_builder/quotes_prompt.py (Normal file, 61 lines)
@@ -0,0 +1,61 @@
from langchain.schema.messages import HumanMessage

from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
from onyx.configs.chat_configs import LANGUAGE_HINT
from onyx.context.search.models import InferenceChunk
from onyx.db.search_settings import get_multilingual_expansion
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.prompts.direct_qa_prompts import CONTEXT_BLOCK
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.direct_qa_prompts import JSON_PROMPT
from onyx.prompts.prompt_utils import add_date_time_to_prompt
from onyx.prompts.prompt_utils import build_complete_context_str


def _build_strong_llm_quotes_prompt(
    question: str,
    context_docs: list[LlmDoc] | list[InferenceChunk],
    history_str: str,
    prompt: PromptConfig,
) -> HumanMessage:
    use_language_hint = bool(get_multilingual_expansion())

    context_block = ""
    if context_docs:
        context_docs_str = build_complete_context_str(context_docs)
        context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str)

    history_block = ""
    if history_str:
        history_block = HISTORY_BLOCK.format(history_str=history_str)

    full_prompt = JSON_PROMPT.format(
        system_prompt=prompt.system_prompt,
        context_block=context_block,
        history_block=history_block,
        task_prompt=prompt.task_prompt,
        user_query=question,
        language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
    ).strip()

    if prompt.datetime_aware:
        full_prompt = add_date_time_to_prompt(prompt_str=full_prompt)

    return HumanMessage(content=full_prompt)


def build_quotes_user_message(
    message: HumanMessage,
    context_docs: list[LlmDoc] | list[InferenceChunk],
    history_str: str,
    prompt: PromptConfig,
) -> HumanMessage:
    query, _ = message_to_prompt_and_imgs(message)

    return _build_strong_llm_quotes_prompt(
        question=query,
        context_docs=context_docs,
        history_str=history_str,
        prompt=prompt,
    )
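A short illustrative sketch (not from this commit) of calling build_quotes_user_message from a QA flow; the wrapper name make_quotes_message and the assumption that docs, history_str, and prompt_config are supplied by the caller are hypothetical:

# Illustrative sketch only: building the quotes-style user message from
# retrieved docs. Only signatures present in this commit are used.
from langchain.schema.messages import HumanMessage

from onyx.chat.models import LlmDoc, PromptConfig
from onyx.chat.prompt_builder.quotes_prompt import build_quotes_user_message


def make_quotes_message(
    user_msg: HumanMessage,
    docs: list[LlmDoc],
    history_str: str,
    prompt_config: PromptConfig,
) -> HumanMessage:
    # Images on the incoming message are dropped here; only the text query is
    # folded into the JSON_PROMPT template, with context and history blocks
    # added only when they are non-empty.
    return build_quotes_user_message(
        message=user_msg,
        context_docs=docs,
        history_str=history_str,
        prompt=prompt_config,
    )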
backend/onyx/chat/prompt_builder/utils.py (Normal file, 60 lines)
@@ -0,0 +1,60 @@
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage

from onyx.configs.constants import MessageType
from onyx.db.models import ChatMessage
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT


def build_dummy_prompt(
    system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
    if retrieval_disabled:
        return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
            user_query="<USER_QUERY>",
            system_prompt=system_prompt,
            task_prompt=task_prompt,
        ).strip()

    return PARAMATERIZED_PROMPT.format(
        context_docs_str="<CONTEXT_DOCS>",
        user_query="<USER_QUERY>",
        system_prompt=system_prompt,
        task_prompt=task_prompt,
    ).strip()


def translate_onyx_msg_to_langchain(
    msg: ChatMessage | PreviousMessage,
) -> BaseMessage:
    files: list[InMemoryChatFile] = []

    # If the message is a `ChatMessage`, it doesn't have the downloaded files
    # attached. Just ignore them for now.
    if not isinstance(msg, ChatMessage):
        files = msg.files
    content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)

    if msg.message_type == MessageType.SYSTEM:
        raise ValueError("System messages are not currently part of history")
    if msg.message_type == MessageType.ASSISTANT:
        return AIMessage(content=content)
    if msg.message_type == MessageType.USER:
        return HumanMessage(content=content)

    raise ValueError(f"New message type {msg.message_type} not handled")


def translate_history_to_basemessages(
    history: list[ChatMessage] | list["PreviousMessage"],
) -> tuple[list[BaseMessage], list[int]]:
    history_basemessages = [
        translate_onyx_msg_to_langchain(msg) for msg in history if msg.token_count != 0
    ]
    history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
    return history_basemessages, history_token_counts
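A brief illustrative sketch (not part of this commit) of consuming translate_history_to_basemessages; the helper history_for_llm and the idea that previous_messages is loaded elsewhere by the chat backend are assumptions:

# Illustrative sketch only: converting stored chat history into LangChain
# messages. Messages with token_count == 0 are skipped by the translator.
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
from onyx.llm.models import PreviousMessage


def history_for_llm(previous_messages: list[PreviousMessage]) -> None:
    base_messages, token_counts = translate_history_to_basemessages(previous_messages)
    # The two lists stay aligned: token_counts[i] is the stored count for
    # base_messages[i], which AnswerPromptBuilder uses for overflow trimming.
    for msg, tokens in zip(base_messages, token_counts):
        print(type(msg).__name__, tokens)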