danswer/backend/onyx/chat/prompt_builder/answer_prompt_builder.py
evan-danswer a7125662f1
Fix gpt o-series code block formatting (#4089)
* prompt addition for gpt o-series to encourage markdown formatting of code blocks

* fix to match https://simonwillison.net/tags/markdown/

* chris comment

* chris comment
2025-02-24 00:59:48 +00:00

201 lines
7.0 KiB
Python

from collections.abc import Callable
from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModel__v1
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLMConfig
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_message_tokens
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
def default_build_system_message(
prompt_config: PromptConfig,
llm_config: LLMConfig,
) -> SystemMessage | None:
system_prompt = prompt_config.system_prompt.strip()
# See https://simonwillison.net/tags/markdown/ for context on this temporary fix
# for o-series markdown generation
if (
llm_config.model_provider == OPENAI_PROVIDER_NAME
and llm_config.model_name.startswith("o")
):
system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
tag_handled_prompt = handle_onyx_date_awareness(
system_prompt,
prompt_config,
add_additional_info_if_no_tag=prompt_config.datetime_aware,
)
if not tag_handled_prompt:
return None
return SystemMessage(content=tag_handled_prompt)
def default_build_user_message(
user_query: str,
prompt_config: PromptConfig,
files: list[InMemoryChatFile] = [],
single_message_history: str | None = None,
) -> HumanMessage:
history_block = (
HISTORY_BLOCK.format(history_str=single_message_history)
if single_message_history
else ""
)
user_prompt = (
CHAT_USER_CONTEXT_FREE_PROMPT.format(
history_block=history_block,
task_prompt=prompt_config.task_prompt,
user_query=user_query,
)
if prompt_config.task_prompt
else user_query
)
user_prompt = user_prompt.strip()
tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
user_msg = HumanMessage(
content=build_content_with_imgs(tag_handled_prompt, files)
if files
else tag_handled_prompt
)
return user_msg
class AnswerPromptBuilder:
def __init__(
self,
user_message: HumanMessage,
message_history: list[PreviousMessage],
llm_config: LLMConfig,
raw_user_query: str,
raw_user_uploaded_files: list[InMemoryChatFile],
single_message_history: str | None = None,
system_message: SystemMessage | None = None,
) -> None:
self.max_tokens = compute_max_llm_input_tokens(llm_config)
llm_tokenizer = get_tokenizer(
provider_type=llm_config.model_provider,
model_name=llm_config.model_name,
)
self.llm_config = llm_config
self.llm_tokenizer_encode_func = cast(
Callable[[str], list[int]], llm_tokenizer.encode
)
self.raw_message_history = message_history
(
self.message_history,
self.history_token_cnts,
) = translate_history_to_basemessages(
message_history,
exclude_images=not model_supports_image_input(
self.llm_config.model_name,
self.llm_config.model_provider,
),
)
self.update_system_prompt(system_message)
self.update_user_prompt(user_message)
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
# used for building a new prompt after a tool-call
self.raw_user_query = raw_user_query
self.raw_user_uploaded_files = raw_user_uploaded_files
self.single_message_history = single_message_history
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
if not system_message:
self.system_message_and_token_cnt = None
return
self.system_message_and_token_cnt = (
system_message,
check_message_tokens(system_message, self.llm_tokenizer_encode_func),
)
def update_user_prompt(self, user_message: HumanMessage) -> None:
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
)
def append_message(self, message: BaseMessage) -> None:
"""Append a new message to the message history."""
token_count = check_message_tokens(message, self.llm_tokenizer_encode_func)
self.new_messages_and_token_cnts.append((message, token_count))
def get_user_message_content(self) -> str:
query, _ = message_to_prompt_and_imgs(self.user_message_and_token_cnt[0])
return query
def build(self) -> list[BaseMessage]:
if not self.user_message_and_token_cnt:
raise ValueError("User message must be set before building prompt")
final_messages_with_tokens: list[tuple[BaseMessage, int]] = []
if self.system_message_and_token_cnt:
final_messages_with_tokens.append(self.system_message_and_token_cnt)
final_messages_with_tokens.extend(
[
(self.message_history[i], self.history_token_cnts[i])
for i in range(len(self.message_history))
]
)
final_messages_with_tokens.append(self.user_message_and_token_cnt)
if self.new_messages_and_token_cnts:
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
return drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens
)
# Stores some parts of a prompt builder as needed for tool calls
class PromptSnapshot(BaseModel):
raw_message_history: list[PreviousMessage]
raw_user_query: str
built_prompt: list[BaseMessage]
# TODO: rename this? AnswerConfig maybe?
class LLMCall(BaseModel__v1):
prompt_builder: AnswerPromptBuilder
tools: list[Tool]
force_use_tool: ForceUseTool
files: list[InMemoryChatFile]
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
using_tool_calling_llm: bool
class Config:
arbitrary_types_allowed = True