From 3466451d5171364916a2fe4addb9f9bc1aa3b572 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sun, 24 Nov 2024 14:16:57 -0800 Subject: [PATCH] Fix Prompt for Non Function Calling LLMs (#3241) --- backend/danswer/llm/answering/answer.py | 2 ++ backend/danswer/llm/answering/prompts/build.py | 8 ++------ .../danswer/llm/answering/tool/tool_response_handler.py | 6 +++--- .../tools/tool_implementations/search_like_tool_utils.py | 6 +++++- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index f9c9dbcb100c..170a3d5cd1c1 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -233,6 +233,8 @@ class Answer: # DEBUG: good breakpoint stream = self.llm.stream( + # For tool calling LLMs, we want to insert the task prompt as part of this flow because the LLM + # may choose not to call any tools and just generate the answer, in which case the task prompt is needed. 
prompt=current_llm_call.prompt_builder.build(), tools=[tool.tool_definition() for tool in current_llm_call.tools] or None, tool_choice=( diff --git a/backend/danswer/llm/answering/prompts/build.py b/backend/danswer/llm/answering/prompts/build.py index ac9ce6f1abd3..fd44adbe381c 100644 --- a/backend/danswer/llm/answering/prompts/build.py +++ b/backend/danswer/llm/answering/prompts/build.py @@ -58,8 +58,8 @@ class AnswerPromptBuilder: user_message: HumanMessage, message_history: list[PreviousMessage], llm_config: LLMConfig, + raw_user_text: str, single_message_history: str | None = None, - raw_user_text: str | None = None, ) -> None: self.max_tokens = compute_max_llm_input_tokens(llm_config) @@ -89,11 +89,7 @@ class AnswerPromptBuilder: self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = [] - self.raw_user_message = ( - HumanMessage(content=raw_user_text) - if raw_user_text is not None - else user_message - ) + self.raw_user_message = raw_user_text def update_system_prompt(self, system_message: SystemMessage | None) -> None: if not system_message: diff --git a/backend/danswer/llm/answering/tool/tool_response_handler.py b/backend/danswer/llm/answering/tool/tool_response_handler.py index 08e7284f7907..db35663c487f 100644 --- a/backend/danswer/llm/answering/tool/tool_response_handler.py +++ b/backend/danswer/llm/answering/tool/tool_response_handler.py @@ -62,7 +62,7 @@ class ToolResponseHandler: llm_call.force_use_tool.args if llm_call.force_use_tool.args is not None else tool.get_args_for_non_tool_calling_llm( - query=llm_call.prompt_builder.get_user_message_content(), + query=llm_call.prompt_builder.raw_user_message, history=llm_call.prompt_builder.raw_message_history, llm=llm, force_run=True, @@ -76,7 +76,7 @@ class ToolResponseHandler: else: tool_options = check_which_tools_should_run_for_non_tool_calling_llm( tools=llm_call.tools, - query=llm_call.prompt_builder.get_user_message_content(), + query=llm_call.prompt_builder.raw_user_message, 
history=llm_call.prompt_builder.raw_message_history, llm=llm, ) @@ -95,7 +95,7 @@ class ToolResponseHandler: select_single_tool_for_non_tool_calling_llm( tools_and_args=available_tools_and_args, history=llm_call.prompt_builder.raw_message_history, - query=llm_call.prompt_builder.get_user_message_content(), + query=llm_call.prompt_builder.raw_user_message, llm=llm, ) if available_tools_and_args diff --git a/backend/danswer/tools/tool_implementations/search_like_tool_utils.py b/backend/danswer/tools/tool_implementations/search_like_tool_utils.py index 121841d0ba3c..55890188d7e7 100644 --- a/backend/danswer/tools/tool_implementations/search_like_tool_utils.py +++ b/backend/danswer/tools/tool_implementations/search_like_tool_utils.py @@ -1,5 +1,7 @@ from typing import cast +from langchain_core.messages import HumanMessage + from danswer.chat.models import LlmDoc from danswer.llm.answering.models import AnswerStyleConfig from danswer.llm.answering.models import PromptConfig @@ -58,9 +60,11 @@ def build_next_prompt_for_search_like_tool( # For Quotes, the system prompt is included in the user prompt prompt_builder.update_system_prompt(None) + human_message = HumanMessage(content=prompt_builder.raw_user_message) + prompt_builder.update_user_prompt( build_quotes_user_message( - message=prompt_builder.raw_user_message, + message=human_message, context_docs=final_context_documents, history_str=prompt_builder.single_message_history or "", prompt=prompt_config,