from collections.abc import Callable
from typing import cast

from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModel__v1

from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLMConfig
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_message_tokens
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool


def default_build_system_message(
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
) -> SystemMessage | None:
    """Build the system message from the configured system prompt, prefixing a
    Markdown code-block hint for OpenAI o-series models and applying date awareness.
    Returns None if the resulting prompt is empty."""
    system_prompt = prompt_config.system_prompt.strip()
    # See https://simonwillison.net/tags/markdown/ for context on this temporary fix
    # for o-series markdown generation
    if (
        llm_config.model_provider == OPENAI_PROVIDER_NAME
        and llm_config.model_name.startswith("o")
    ):
        system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
    tag_handled_prompt = handle_onyx_date_awareness(
        system_prompt,
        prompt_config,
        add_additional_info_if_no_tag=prompt_config.datetime_aware,
    )

    if not tag_handled_prompt:
        return None

    return SystemMessage(content=tag_handled_prompt)


def default_build_user_message(
    user_query: str,
    prompt_config: PromptConfig,
    files: list[InMemoryChatFile] = [],
    single_message_history: str | None = None,
) -> HumanMessage:
    """Build the user message, optionally wrapping the query with the task prompt,
    a single-message history block, and any uploaded files."""
    history_block = (
        HISTORY_BLOCK.format(history_str=single_message_history)
        if single_message_history
        else ""
    )

    user_prompt = (
        CHAT_USER_CONTEXT_FREE_PROMPT.format(
            history_block=history_block,
            task_prompt=prompt_config.task_prompt,
            user_query=user_query,
        )
        if prompt_config.task_prompt
        else user_query
    )
    user_prompt = user_prompt.strip()
    tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
    user_msg = HumanMessage(
        content=build_content_with_imgs(tag_handled_prompt, files)
        if files
        else tag_handled_prompt
    )
    return user_msg
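

# Usage sketch (illustrative, not part of the original module): pairing the two
# default builders above for a new turn. Assumes the PromptConfig and LLMConfig
# are constructed elsewhere; the helper name itself is hypothetical.
def _example_default_messages(
    user_query: str,
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
    files: list[InMemoryChatFile],
) -> tuple[SystemMessage | None, HumanMessage]:
    system_message = default_build_system_message(prompt_config, llm_config)
    user_message = default_build_user_message(user_query, prompt_config, files)
    return system_message, user_message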


class AnswerPromptBuilder:
    """Incrementally assembles the message list sent to the LLM, tracking per-message
    token counts so history can be trimmed to fit the model's input window."""

    def __init__(
        self,
        user_message: HumanMessage,
        message_history: list[PreviousMessage],
        llm_config: LLMConfig,
        raw_user_query: str,
        raw_user_uploaded_files: list[InMemoryChatFile],
        single_message_history: str | None = None,
        system_message: SystemMessage | None = None,
    ) -> None:
        self.max_tokens = compute_max_llm_input_tokens(llm_config)

        llm_tokenizer = get_tokenizer(
            provider_type=llm_config.model_provider,
            model_name=llm_config.model_name,
        )
        self.llm_config = llm_config
        self.llm_tokenizer_encode_func = cast(
            Callable[[str], list[int]], llm_tokenizer.encode
        )

        self.raw_message_history = message_history
        (
            self.message_history,
            self.history_token_cnts,
        ) = translate_history_to_basemessages(
            message_history,
            exclude_images=not model_supports_image_input(
                self.llm_config.model_name,
                self.llm_config.model_provider,
            ),
        )

        self.update_system_prompt(system_message)
        self.update_user_prompt(user_message)

        self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []

        # used for building a new prompt after a tool-call
        self.raw_user_query = raw_user_query
        self.raw_user_uploaded_files = raw_user_uploaded_files
        self.single_message_history = single_message_history

    def update_system_prompt(self, system_message: SystemMessage | None) -> None:
        if not system_message:
            self.system_message_and_token_cnt = None
            return

        self.system_message_and_token_cnt = (
            system_message,
            check_message_tokens(system_message, self.llm_tokenizer_encode_func),
        )

    def update_user_prompt(self, user_message: HumanMessage) -> None:
        self.user_message_and_token_cnt = (
            user_message,
            check_message_tokens(user_message, self.llm_tokenizer_encode_func),
        )

    def append_message(self, message: BaseMessage) -> None:
        """Append a new message to the message history."""
        token_count = check_message_tokens(message, self.llm_tokenizer_encode_func)
        self.new_messages_and_token_cnts.append((message, token_count))

    def get_user_message_content(self) -> str:
        query, _ = message_to_prompt_and_imgs(self.user_message_and_token_cnt[0])
        return query

    def build(self) -> list[BaseMessage]:
        if not self.user_message_and_token_cnt:
            raise ValueError("User message must be set before building prompt")

        final_messages_with_tokens: list[tuple[BaseMessage, int]] = []
        if self.system_message_and_token_cnt:
            final_messages_with_tokens.append(self.system_message_and_token_cnt)

        final_messages_with_tokens.extend(
            [
                (self.message_history[i], self.history_token_cnts[i])
                for i in range(len(self.message_history))
            ]
        )

        final_messages_with_tokens.append(self.user_message_and_token_cnt)

        if self.new_messages_and_token_cnts:
            final_messages_with_tokens.extend(self.new_messages_and_token_cnts)

        return drop_messages_history_overflow(
            final_messages_with_tokens, self.max_tokens
        )
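

# Lifecycle sketch (illustrative, not part of the original module): a builder is
# created from already-built system/user messages, any tool-call output is appended,
# and build() returns the final message list trimmed to the model's input budget.
# The helper name and parameters are hypothetical.
def _example_run_prompt_builder(
    system_message: SystemMessage | None,
    user_message: HumanMessage,
    message_history: list[PreviousMessage],
    llm_config: LLMConfig,
    raw_user_query: str,
    files: list[InMemoryChatFile],
    extra_messages: list[BaseMessage],
) -> list[BaseMessage]:
    builder = AnswerPromptBuilder(
        user_message=user_message,
        message_history=message_history,
        llm_config=llm_config,
        raw_user_query=raw_user_query,
        raw_user_uploaded_files=files,
        system_message=system_message,
    )
    # Messages produced mid-turn (e.g. tool results) are appended before building.
    for message in extra_messages:
        builder.append_message(message)
    return builder.build()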


# Stores some parts of a prompt builder as needed for tool calls
class PromptSnapshot(BaseModel):
    raw_message_history: list[PreviousMessage]
    raw_user_query: str
    built_prompt: list[BaseMessage]
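

# Illustrative only: based purely on the field names above, one plausible way to
# capture a snapshot from an AnswerPromptBuilder for later tool-call handling.
# The helper name is hypothetical.
def _example_snapshot(builder: AnswerPromptBuilder) -> PromptSnapshot:
    return PromptSnapshot(
        raw_message_history=builder.raw_message_history,
        raw_user_query=builder.raw_user_query,
        built_prompt=builder.build(),
    )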


# TODO: rename this? AnswerConfig maybe?
class LLMCall(BaseModel__v1):
    prompt_builder: AnswerPromptBuilder
    tools: list[Tool]
    force_use_tool: ForceUseTool
    files: list[InMemoryChatFile]
    tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
    using_tool_calling_llm: bool

    class Config:
        arbitrary_types_allowed = True
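

# Illustrative sketch only: one way an LLMCall might be assembled once a prompt
# builder is ready. The tools and ForceUseTool setting are assumed to come from the
# caller's own context; the helper name and the files mapping are hypothetical.
def _example_make_llm_call(
    prompt_builder: AnswerPromptBuilder,
    tools: list[Tool],
    force_use_tool: ForceUseTool,
    using_tool_calling_llm: bool,
) -> LLMCall:
    return LLMCall(
        prompt_builder=prompt_builder,
        tools=tools,
        force_use_tool=force_use_tool,
        files=prompt_builder.raw_user_uploaded_files,
        tool_call_info=[],  # empty until tool calls are actually executed
        using_tool_calling_llm=using_tool_calling_llm,
    )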