From 8a4d762798027fbb13f80a3c8bd7896ca8e68827 Mon Sep 17 00:00:00 2001
From: Chris Weaver <25087905+Weves@users.noreply.github.com>
Date: Wed, 15 Jan 2025 18:40:25 -0800
Subject: [PATCH] Fix follow-ups in thread + fix user name (#3686)

* Fix follow-ups in thread + fix user name

* Add back single history str

* Remove newline
---
 ...add_chat_message__standard_answer_table.py | 36 +++++++++++++
 .../slack/handlers/handle_standard_answers.py | 15 ++++--
 backend/onyx/chat/answer.py                   | 25 +++------
 backend/onyx/chat/llm_response_handler.py     |  2 +-
 backend/onyx/chat/models.py                   | 22 +-------
 backend/onyx/chat/process_message.py          |  6 +++
 .../{build.py => answer_prompt_builder.py}    | 28 +++++++---
 .../chat/prompt_builder/citations_prompt.py   |  4 +-
 .../tool_handling/tool_response_handler.py    |  8 +--
 .../onyxbot/slack/handlers/handle_buttons.py  |  2 +-
 .../onyxbot/slack/handlers/handle_message.py  |  8 +--
 .../slack/handlers/handle_regular_answer.py   | 16 ++++--
 backend/onyx/onyxbot/slack/listener.py        | 25 ++++++---
 backend/onyx/onyxbot/slack/models.py          |  2 +-
 backend/onyx/prompts/chat_prompts.py          |  3 +-
 backend/onyx/tools/base_tool.py               |  2 +-
 backend/onyx/tools/tool.py                    |  2 +-
 .../custom/custom_tool.py                     |  2 +-
 .../images/image_generation_tool.py           |  2 +-
 .../internet_search/internet_search_tool.py   |  2 +-
 .../search/search_tool.py                     |  2 +-
 .../search_like_tool_utils.py                 | 54 ++++++++-----------
 backend/tests/unit/onyx/chat/conftest.py      |  2 +-
 23 files changed, 153 insertions(+), 117 deletions(-)
 create mode 100644 backend/alembic/versions/c5eae4a75a1b_add_chat_message__standard_answer_table.py
 rename backend/onyx/chat/prompt_builder/{build.py => answer_prompt_builder.py} (88%)

diff --git a/backend/alembic/versions/c5eae4a75a1b_add_chat_message__standard_answer_table.py b/backend/alembic/versions/c5eae4a75a1b_add_chat_message__standard_answer_table.py
new file mode 100644
index 000000000..ae743655a
--- /dev/null
+++ b/backend/alembic/versions/c5eae4a75a1b_add_chat_message__standard_answer_table.py
@@ -0,0 +1,36 @@
+"""Add chat_message__standard_answer table
+
+Revision ID: c5eae4a75a1b
+Revises: 0f7ff6d75b57
+Create Date: 2025-01-15 14:08:49.688998
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = "c5eae4a75a1b"
+down_revision = "0f7ff6d75b57"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        "chat_message__standard_answer",
+        sa.Column("chat_message_id", sa.Integer(), nullable=False),
+        sa.Column("standard_answer_id", sa.Integer(), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["chat_message_id"],
+            ["chat_message.id"],
+        ),
+        sa.ForeignKeyConstraint(
+            ["standard_answer_id"],
+            ["standard_answer.id"],
+        ),
+        sa.PrimaryKeyConstraint("chat_message_id", "standard_answer_id"),
+    )
+
+
+def downgrade() -> None:
+    op.drop_table("chat_message__standard_answer")
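The migration above only creates the join table; for orientation, here is a minimal SQLAlchemy sketch of the ORM mapping it implies. The model bodies are illustrative stand-ins rather than code from this patch; only the table and column names come from the migration.

from sqlalchemy import Column, ForeignKey, Integer, Table
from sqlalchemy.orm import declarative_base, relationship

Base = declarative_base()

# mirrors chat_message__standard_answer as created in the migration
chat_message__standard_answer = Table(
    "chat_message__standard_answer",
    Base.metadata,
    Column("chat_message_id", Integer, ForeignKey("chat_message.id"), primary_key=True),
    Column("standard_answer_id", Integer, ForeignKey("standard_answer.id"), primary_key=True),
)


class ChatMessage(Base):  # stand-in; the real model has many more columns
    __tablename__ = "chat_message"
    id = Column(Integer, primary_key=True)
    # enables the `chat_message.standard_answers = [...]` assignment seen below
    standard_answers = relationship(
        "StandardAnswer", secondary=chat_message__standard_answer
    )


class StandardAnswer(Base):  # stand-in for the existing model
    __tablename__ = "standard_answer"
    id = Column(Integer, primary_key=True)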
diff --git a/backend/ee/onyx/onyxbot/slack/handlers/handle_standard_answers.py b/backend/ee/onyx/onyxbot/slack/handlers/handle_standard_answers.py
index 478713377..5b994c126 100644
--- a/backend/ee/onyx/onyxbot/slack/handlers/handle_standard_answers.py
+++ b/backend/ee/onyx/onyxbot/slack/handlers/handle_standard_answers.py
@@ -150,9 +150,9 @@ def _handle_standard_answers(
         db_session=db_session,
         description="",
         user_id=None,
-        persona_id=slack_channel_config.persona.id
-        if slack_channel_config.persona
-        else 0,
+        persona_id=(
+            slack_channel_config.persona.id if slack_channel_config.persona else 0
+        ),
         onyxbot_flow=True,
         slack_thread_id=slack_thread_id,
     )
@@ -182,7 +182,7 @@ def _handle_standard_answers(
             formatted_answers.append(formatted_answer)
         answer_message = "\n\n".join(formatted_answers)

-        _ = create_new_chat_message(
+        chat_message = create_new_chat_message(
             chat_session_id=chat_session.id,
             parent_message=new_user_message,
             prompt_id=prompt.id if prompt else None,
@@ -191,8 +191,13 @@ def _handle_standard_answers(
             message_type=MessageType.ASSISTANT,
             error=None,
             db_session=db_session,
-            commit=True,
+            commit=False,
         )
+        # attach the standard answers to the chat message
+        chat_message.standard_answers = [
+            standard_answer for standard_answer, _ in matching_standard_answers
+        ]
+        db_session.commit()

         update_emote_react(
             emoji=DANSWER_REACT_EMOJI,
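The switch from commit=True to commit=False above is what makes the new association rows possible: the message is created but not committed, the matched standard answers are attached, and a single commit persists both together. The pattern in isolation, with stand-ins for the real session and objects:

def save_message_with_standard_answers(db_session, chat_message, matching_standard_answers):
    # matching_standard_answers is a list of (standard_answer, match) pairs
    chat_message.standard_answers = [
        standard_answer for standard_answer, _ in matching_standard_answers
    ]
    # one commit persists the message and its association rows atomically
    db_session.commit()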
diff --git a/backend/onyx/chat/answer.py b/backend/onyx/chat/answer.py
index 14b9c227a..ff211cbf3 100644
--- a/backend/onyx/chat/answer.py
+++ b/backend/onyx/chat/answer.py
@@ -12,10 +12,10 @@ from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationInfo
 from onyx.chat.models import OnyxAnswerPiece
 from onyx.chat.models import PromptConfig
-from onyx.chat.prompt_builder.build import AnswerPromptBuilder
-from onyx.chat.prompt_builder.build import default_build_system_message
-from onyx.chat.prompt_builder.build import default_build_user_message
-from onyx.chat.prompt_builder.build import LLMCall
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
+from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
+from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
+from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
 from onyx.chat.stream_processing.answer_response_handler import (
     CitationResponseHandler,
 )
@@ -212,19 +212,6 @@ class Answer:
             current_llm_call
         ) or ([], [])

-        # Quotes are no longer supported
-        # answer_handler: AnswerResponseHandler
-        # if self.answer_style_config.citation_config:
-        #     answer_handler = CitationResponseHandler(
-        #         context_docs=search_result,
-        #         doc_id_to_rank_map=map_document_id_order(search_result),
-        #     )
-        # elif self.answer_style_config.quotes_config:
-        #     answer_handler = QuotesResponseHandler(
-        #         context_docs=search_result,
-        #     )
-        # else:
-        #     raise ValueError("No answer style config provided")
         answer_handler = CitationResponseHandler(
             context_docs=final_search_results,
             final_doc_id_to_rank_map=map_document_id_order(final_search_results),
@@ -265,11 +252,13 @@ class Answer:
                 user_query=self.question,
                 prompt_config=self.prompt_config,
                 files=self.latest_query_files,
+                single_message_history=self.single_message_history,
             ),
             message_history=self.message_history,
             llm_config=self.llm.config,
+            raw_user_query=self.question,
+            raw_user_uploaded_files=self.latest_query_files or [],
             single_message_history=self.single_message_history,
-            raw_user_text=self.question,
         )
         prompt_builder.update_system_prompt(
             default_build_system_message(self.prompt_config)
diff --git a/backend/onyx/chat/llm_response_handler.py b/backend/onyx/chat/llm_response_handler.py
index 612ce5dd5..7c9c8ee71 100644
--- a/backend/onyx/chat/llm_response_handler.py
+++ b/backend/onyx/chat/llm_response_handler.py
@@ -7,7 +7,7 @@ from langchain_core.messages import BaseMessage
 from onyx.chat.models import ResponsePart
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
-from onyx.chat.prompt_builder.build import LLMCall
+from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
 from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
 from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler

diff --git a/backend/onyx/chat/models.py b/backend/onyx/chat/models.py
index 44973446f..2c5426045 100644
--- a/backend/onyx/chat/models.py
+++ b/backend/onyx/chat/models.py
@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING
 from pydantic import BaseModel
 from pydantic import ConfigDict
 from pydantic import Field
-from pydantic import model_validator

 from onyx.configs.constants import DocumentSource
 from onyx.configs.constants import MessageType
@@ -261,13 +260,8 @@ class CitationConfig(BaseModel):
     all_docs_useful: bool = False


-class QuotesConfig(BaseModel):
-    pass
-
-
 class AnswerStyleConfig(BaseModel):
-    citation_config: CitationConfig | None = None
-    quotes_config: QuotesConfig | None = None
+    citation_config: CitationConfig
     document_pruning_config: DocumentPruningConfig = Field(
         default_factory=DocumentPruningConfig
     )
@@ -276,20 +270,6 @@ class AnswerStyleConfig(BaseModel):
     # right now, only used by the simple chat API
     structured_response_format: dict | None = None

-    @model_validator(mode="after")
-    def check_quotes_and_citation(self) -> "AnswerStyleConfig":
-        if self.citation_config is None and self.quotes_config is None:
-            raise ValueError(
-                "One of `citation_config` or `quotes_config` must be provided"
-            )
-
-        if self.citation_config is not None and self.quotes_config is not None:
-            raise ValueError(
-                "Only one of `citation_config` or `quotes_config` must be provided"
-            )
-
-        return self
-

 class PromptConfig(BaseModel):
     """Final representation of the Prompt configuration passed
diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py
index 28d67d032..b874ea7ef 100644
--- a/backend/onyx/chat/process_message.py
+++ b/backend/onyx/chat/process_message.py
@@ -302,6 +302,11 @@ def stream_chat_message_objects(
     enforce_chat_session_id_for_search_docs: bool = True,
     bypass_acl: bool = False,
     include_contexts: bool = False,
+    # a string which represents the history of a conversation. Used in cases like
+    # Slack threads where the conversation cannot be represented by a chain of User/Assistant
+    # messages.
+    # NOTE: is not stored in the database at all.
+    single_message_history: str | None = None,
 ) -> ChatPacketStream:
     """Streams in order:
     1. [conditional] Retrieved documents if a search needs to be run
@@ -707,6 +712,7 @@ def stream_chat_message_objects(
         ],
         tools=tools,
         force_use_tool=_get_force_search_settings(new_msg_req, tools),
+        single_message_history=single_message_history,
     )

     reference_db_search_docs = None
diff --git a/backend/onyx/chat/prompt_builder/build.py b/backend/onyx/chat/prompt_builder/answer_prompt_builder.py
similarity index 88%
rename from backend/onyx/chat/prompt_builder/build.py
rename to backend/onyx/chat/prompt_builder/answer_prompt_builder.py
index affca0e93..b1beb211b 100644
--- a/backend/onyx/chat/prompt_builder/build.py
+++ b/backend/onyx/chat/prompt_builder/answer_prompt_builder.py
@@ -17,6 +17,7 @@ from onyx.llm.utils import check_message_tokens
 from onyx.llm.utils import message_to_prompt_and_imgs
 from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
+from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
 from onyx.prompts.prompt_utils import add_date_time_to_prompt
 from onyx.prompts.prompt_utils import drop_messages_history_overflow
 from onyx.tools.force import ForceUseTool
@@ -42,11 +43,22 @@ def default_build_system_message(


 def default_build_user_message(
-    user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
+    user_query: str,
+    prompt_config: PromptConfig,
+    files: list[InMemoryChatFile] = [],
+    single_message_history: str | None = None,
 ) -> HumanMessage:
+    history_block = (
+        HISTORY_BLOCK.format(history_str=single_message_history)
+        if single_message_history
+        else ""
+    )
+
     user_prompt = (
         CHAT_USER_CONTEXT_FREE_PROMPT.format(
-            task_prompt=prompt_config.task_prompt, user_query=user_query
+            history_block=history_block,
+            task_prompt=prompt_config.task_prompt,
+            user_query=user_query,
         )
         if prompt_config.task_prompt
         else user_query
@@ -64,7 +76,8 @@ class AnswerPromptBuilder:
         user_message: HumanMessage,
         message_history: list[PreviousMessage],
         llm_config: LLMConfig,
-        raw_user_text: str,
+        raw_user_query: str,
+        raw_user_uploaded_files: list[InMemoryChatFile],
         single_message_history: str | None = None,
     ) -> None:
         self.max_tokens = compute_max_llm_input_tokens(llm_config)
@@ -83,10 +96,6 @@ class AnswerPromptBuilder:
             self.history_token_cnts,
         ) = translate_history_to_basemessages(message_history)

-        # for cases where like the QA flow where we want to condense the chat history
-        # into a single message rather than a sequence of User / Assistant messages
-        self.single_message_history = single_message_history
-
         self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
         self.user_message_and_token_cnt = (
             user_message,
@@ -95,7 +104,10 @@ class AnswerPromptBuilder:

         self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []

-        self.raw_user_message = raw_user_text
+        # used for building a new prompt after a tool-call
+        self.raw_user_query = raw_user_query
+        self.raw_user_uploaded_files = raw_user_uploaded_files
+        self.single_message_history = single_message_history

     def update_system_prompt(self, system_message: SystemMessage | None) -> None:
         if not system_message:
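To make the new history_block plumbing concrete, here is a self-contained rendering of what default_build_user_message now produces. The two template strings are simplified stand-ins; the real ones live under onyx/prompts:

# stand-in templates, not the real prompt text
HISTORY_BLOCK = "CHAT HISTORY:\n{history_str}\n\n"
CHAT_USER_CONTEXT_FREE_PROMPT = "{history_block}{task_prompt}\n\nQUESTION:\n{user_query}"


def render_user_prompt(
    user_query: str, task_prompt: str, single_message_history: str | None
) -> str:
    # mirrors the conditional in default_build_user_message above
    history_block = (
        HISTORY_BLOCK.format(history_str=single_message_history)
        if single_message_history
        else ""
    )
    return CHAT_USER_CONTEXT_FREE_PROMPT.format(
        history_block=history_block,
        task_prompt=task_prompt,
        user_query=user_query,
    )


# without history the prompt is unchanged; with a Slack thread the condensed
# history is prepended ahead of the task prompt
print(render_user_prompt("What changed?", "Answer concisely.", "alice: we shipped v2"))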
diff --git a/backend/onyx/chat/prompt_builder/citations_prompt.py b/backend/onyx/chat/prompt_builder/citations_prompt.py
index 03ff8a657..f2d88cc12 100644
--- a/backend/onyx/chat/prompt_builder/citations_prompt.py
+++ b/backend/onyx/chat/prompt_builder/citations_prompt.py
@@ -144,9 +144,7 @@ def build_citations_user_message(
     )

     history_block = (
-        HISTORY_BLOCK.format(history_str=history_message) + "\n"
-        if history_message
-        else ""
+        HISTORY_BLOCK.format(history_str=history_message) if history_message else ""
     )

     query, img_urls = message_to_prompt_and_imgs(message)
diff --git a/backend/onyx/chat/tool_handling/tool_response_handler.py b/backend/onyx/chat/tool_handling/tool_response_handler.py
index 1a39e5c8d..0c17693a2 100644
--- a/backend/onyx/chat/tool_handling/tool_response_handler.py
+++ b/backend/onyx/chat/tool_handling/tool_response_handler.py
@@ -5,7 +5,7 @@ from langchain_core.messages import BaseMessage
 from langchain_core.messages import ToolCall

 from onyx.chat.models import ResponsePart
-from onyx.chat.prompt_builder.build import LLMCall
+from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
 from onyx.llm.interfaces import LLM
 from onyx.tools.force import ForceUseTool
 from onyx.tools.message import build_tool_message
@@ -62,7 +62,7 @@ class ToolResponseHandler:
                 llm_call.force_use_tool.args
                 if llm_call.force_use_tool.args is not None
                 else tool.get_args_for_non_tool_calling_llm(
-                    query=llm_call.prompt_builder.raw_user_message,
+                    query=llm_call.prompt_builder.raw_user_query,
                     history=llm_call.prompt_builder.raw_message_history,
                     llm=llm,
                     force_run=True,
@@ -76,7 +76,7 @@ class ToolResponseHandler:
         else:
             tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
                 tools=llm_call.tools,
-                query=llm_call.prompt_builder.raw_user_message,
+                query=llm_call.prompt_builder.raw_user_query,
                 history=llm_call.prompt_builder.raw_message_history,
                 llm=llm,
             )
@@ -95,7 +95,7 @@ class ToolResponseHandler:
                 select_single_tool_for_non_tool_calling_llm(
                     tools_and_args=available_tools_and_args,
                     history=llm_call.prompt_builder.raw_message_history,
-                    query=llm_call.prompt_builder.raw_user_message,
+                    query=llm_call.prompt_builder.raw_user_query,
                     llm=llm,
                 )
                 if available_tools_and_args
diff --git a/backend/onyx/onyxbot/slack/handlers/handle_buttons.py b/backend/onyx/onyxbot/slack/handlers/handle_buttons.py
index 7b26e0bfe..78eb1783d 100644
--- a/backend/onyx/onyxbot/slack/handlers/handle_buttons.py
+++ b/backend/onyx/onyxbot/slack/handlers/handle_buttons.py
@@ -127,7 +127,7 @@ def handle_generate_answer_button(
         channel_to_respond=channel_id,
         msg_to_respond=cast(str, message_ts or thread_ts),
         thread_to_respond=cast(str, thread_ts or message_ts),
-        sender=user_id or None,
+        sender_id=user_id or None,
         email=email or None,
         bypass_filters=True,
         is_bot_msg=False,
diff --git a/backend/onyx/onyxbot/slack/handlers/handle_message.py b/backend/onyx/onyxbot/slack/handlers/handle_message.py
index 78e58b85f..ed6a79a96 100644
--- a/backend/onyx/onyxbot/slack/handlers/handle_message.py
+++ b/backend/onyx/onyxbot/slack/handlers/handle_message.py
@@ -28,12 +28,12 @@ logger_base = setup_logger()


 def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
-    if details.is_bot_msg and details.sender:
+    if details.is_bot_msg and details.sender_id:
         respond_in_thread(
             client=client,
             channel=details.channel_to_respond,
             thread_ts=details.msg_to_respond,
-            receiver_ids=[details.sender],
+            receiver_ids=[details.sender_id],
             text="Hi, we're evaluating your query :face_with_monocle:",
         )
         return
@@ -70,7 +70,7 @@ def schedule_feedback_reminder(

     try:
         response = client.chat_scheduleMessage(
-            channel=details.sender,  # type:ignore
+            channel=details.sender_id,  # type:ignore
             post_at=int(future.timestamp()),
             blocks=[
                 get_feedback_reminder_blocks(
@@ -123,7 +123,7 @@ def handle_message(
     logger = setup_logger(extra={SLACK_CHANNEL_ID: channel})

     messages = message_info.thread_messages
-    sender_id = message_info.sender
+    sender_id = message_info.sender_id
     bypass_filters = message_info.bypass_filters
     is_bot_msg = message_info.is_bot_msg
     is_bot_dm = message_info.is_bot_dm
diff --git a/backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py b/backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py
index 19e3e2a36..cad549a76 100644
--- a/backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py
+++ b/backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py
@@ -126,7 +126,12 @@ def handle_regular_answer(
     #     messages, max_tokens=max_history_tokens, llm_tokenizer=llm_tokenizer
     # )

-    combined_message = slackify_message_thread(messages)
+    # NOTE: only the message history will contain the person asking. This is likely
+    # fine since the most common use case for this info is when referring to a user
+    # who previously posted in the thread.
+    user_message = messages[-1]
+    history_messages = messages[:-1]
+    single_message_history = slackify_message_thread(history_messages) or None

     bypass_acl = False
     if (
@@ -159,6 +164,7 @@ def handle_regular_answer(
             user=onyx_user,
             db_session=db_session,
             bypass_acl=bypass_acl,
+            single_message_history=single_message_history,
         )

         answer = gather_stream_for_slack(packets)
@@ -198,7 +204,7 @@ def handle_regular_answer(

         with get_session_with_tenant(tenant_id) as db_session:
             answer_request = prepare_chat_message_request(
-                message_text=combined_message,
+                message_text=user_message.message,
                 user=user,
                 persona_id=persona.id,
                 # This is not used in the Slack flow, only in the answer API
@@ -312,7 +318,7 @@ def handle_regular_answer(
     top_docs = retrieval_info.top_documents
     if not top_docs and not should_respond_even_with_no_docs:
         logger.error(
-            f"Unable to answer question: '{combined_message}' - no documents found"
+            f"Unable to answer question: '{user_message.message}' - no documents found"
         )
         # Optionally, respond in thread with the error message
         # Used primarily for debugging purposes
@@ -371,8 +377,8 @@ def handle_regular_answer(
         respond_in_thread(
             client=client,
             channel=channel,
-            receiver_ids=[message_info.sender]
-            if message_info.is_bot_msg and message_info.sender
+            receiver_ids=[message_info.sender_id]
+            if message_info.is_bot_msg and message_info.sender_id
             else receiver_ids,
             text="Hello! Onyx has some results for you!",
             blocks=all_blocks,
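The core behavioral fix above: instead of flattening the whole thread into one query, the last message becomes the question and everything before it becomes the condensed history string. Isolated, with a stand-in ThreadMessage:

from dataclasses import dataclass


@dataclass
class ThreadMessage:  # stand-in for the real model
    message: str
    sender: str | None


def split_thread(messages: list[ThreadMessage]) -> tuple[str, str | None]:
    # the last message is the actual question
    question = messages[-1].message
    # earlier messages get condensed into a single history string
    history = "\n".join(
        f"{m.sender or 'Unknown'}: {m.message}" for m in messages[:-1]
    )
    # `or None` matches the patch: a single-message thread yields no history
    return question, history or None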
diff --git a/backend/onyx/onyxbot/slack/listener.py b/backend/onyx/onyxbot/slack/listener.py
index 4cb3065ae..7624b35c8 100644
--- a/backend/onyx/onyxbot/slack/listener.py
+++ b/backend/onyx/onyxbot/slack/listener.py
@@ -540,9 +540,9 @@ def build_request_details(
         tagged = event.get("type") == "app_mention"
         message_ts = event.get("ts")
         thread_ts = event.get("thread_ts")
-        sender = event.get("user") or None
+        sender_id = event.get("user") or None
         expert_info = expert_info_from_slack_id(
-            sender, client.web_client, user_cache={}
+            sender_id, client.web_client, user_cache={}
         )
         email = expert_info.email if expert_info else None

@@ -566,8 +566,21 @@ def build_request_details(
             channel=channel, thread=thread_ts, client=client.web_client
         )
     else:
+        sender_display_name = None
+        if expert_info:
+            sender_display_name = expert_info.display_name
+            if sender_display_name is None:
+                sender_display_name = (
+                    f"{expert_info.first_name} {expert_info.last_name}"
+                    if expert_info.last_name
+                    else expert_info.first_name
+                )
+            if sender_display_name is None:
+                sender_display_name = expert_info.email
         thread_messages = [
-            ThreadMessage(message=msg, sender=None, role=MessageType.USER)
+            ThreadMessage(
+                message=msg, sender=sender_display_name, role=MessageType.USER
+            )
         ]

     return SlackMessageInfo(
@@ -575,7 +588,7 @@ def build_request_details(
         channel_to_respond=channel,
         msg_to_respond=cast(str, message_ts or thread_ts),
         thread_to_respond=cast(str, thread_ts or message_ts),
-        sender=sender,
+        sender_id=sender_id,
         email=email,
         bypass_filters=tagged,
         is_bot_msg=False,
@@ -598,7 +611,7 @@ def build_request_details(
         channel_to_respond=channel,
         msg_to_respond=None,
         thread_to_respond=None,
-        sender=sender,
+        sender_id=sender,
         email=email,
         bypass_filters=True,
         is_bot_msg=True,
@@ -687,7 +700,7 @@ def process_message(
                 if feedback_reminder_id:
                     remove_scheduled_feedback_reminder(
                         client=client.web_client,
-                        channel=details.sender,
+                        channel=details.sender_id,
                         msg_id=feedback_reminder_id,
                     )
                 # Skipping answering due to pre-filtering is not considered a failure
diff --git a/backend/onyx/onyxbot/slack/models.py b/backend/onyx/onyxbot/slack/models.py
index 3921e0b9b..f3cb6add2 100644
--- a/backend/onyx/onyxbot/slack/models.py
+++ b/backend/onyx/onyxbot/slack/models.py
@@ -8,7 +8,7 @@ class SlackMessageInfo(BaseModel):
     channel_to_respond: str
     msg_to_respond: str | None
    thread_to_respond: str | None
-    sender: str | None
+    sender_id: str | None
     email: str | None
     bypass_filters: bool  # User has tagged @OnyxBot
     is_bot_msg: bool  # User is using /OnyxBot
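The sender name that now rides along on ThreadMessage is resolved through a fallback chain (Slack display name, then first/last name, then email); the nesting above keeps every attribute access guarded by the expert_info check. The same logic as a standalone function, with a stand-in for the Slack user-info object:

from dataclasses import dataclass


@dataclass
class ExpertInfo:  # stand-in for the real Slack user-info model
    display_name: str | None = None
    first_name: str | None = None
    last_name: str | None = None
    email: str | None = None


def resolve_display_name(expert_info: ExpertInfo | None) -> str | None:
    if expert_info is None:
        return None
    if expert_info.display_name:
        return expert_info.display_name
    if expert_info.first_name:
        # "First Last" if both parts exist, otherwise just "First"
        return (
            f"{expert_info.first_name} {expert_info.last_name}"
            if expert_info.last_name
            else expert_info.first_name
        )
    return expert_info.email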
diff --git a/backend/onyx/prompts/chat_prompts.py b/backend/onyx/prompts/chat_prompts.py
index 56cd279fd..93e467550 100644
--- a/backend/onyx/prompts/chat_prompts.py
+++ b/backend/onyx/prompts/chat_prompts.py
@@ -31,8 +31,9 @@ CONTEXT:
 """.strip()


+# History block is optional.
 CHAT_USER_CONTEXT_FREE_PROMPT = f"""
-{{task_prompt}}
+{{history_block}}{{task_prompt}}

 {QUESTION_PAT.upper()}
 {{user_query}}
diff --git a/backend/onyx/tools/base_tool.py b/backend/onyx/tools/base_tool.py
index 5a434c33e..16ec5d92a 100644
--- a/backend/onyx/tools/base_tool.py
+++ b/backend/onyx/tools/base_tool.py
@@ -7,7 +7,7 @@ from onyx.llm.utils import message_to_prompt_and_imgs
 from onyx.tools.tool import Tool

 if TYPE_CHECKING:
-    from onyx.chat.prompt_builder.build import AnswerPromptBuilder
+    from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
     from onyx.tools.tool_implementations.custom.custom_tool import (
         CustomToolCallSummary,
     )
diff --git a/backend/onyx/tools/tool.py b/backend/onyx/tools/tool.py
index 814ccd9b3..4a8ba8099 100644
--- a/backend/onyx/tools/tool.py
+++ b/backend/onyx/tools/tool.py
@@ -9,7 +9,7 @@
 from onyx.utils.special_types import JSON_ro

 if TYPE_CHECKING:
-    from onyx.chat.prompt_builder.build import AnswerPromptBuilder
+    from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
     from onyx.tools.message import ToolCallSummary
     from onyx.tools.models import ToolResponse

diff --git a/backend/onyx/tools/tool_implementations/custom/custom_tool.py b/backend/onyx/tools/tool_implementations/custom/custom_tool.py
index 892bea972..9a0f8c3d1 100644
--- a/backend/onyx/tools/tool_implementations/custom/custom_tool.py
+++ b/backend/onyx/tools/tool_implementations/custom/custom_tool.py
@@ -15,7 +15,7 @@ from langchain_core.messages import SystemMessage
 from pydantic import BaseModel
 from requests import JSONDecodeError

-from onyx.chat.prompt_builder.build import AnswerPromptBuilder
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.configs.constants import FileOrigin
 from onyx.db.engine import get_session_with_default_tenant
 from onyx.file_store.file_store import get_default_file_store
diff --git a/backend/onyx/tools/tool_implementations/images/image_generation_tool.py b/backend/onyx/tools/tool_implementations/images/image_generation_tool.py
index 7f2863bd7..0deee37b1 100644
--- a/backend/onyx/tools/tool_implementations/images/image_generation_tool.py
+++ b/backend/onyx/tools/tool_implementations/images/image_generation_tool.py
@@ -9,7 +9,7 @@ from litellm import image_generation  # type: ignore
 from pydantic import BaseModel

 from onyx.chat.chat_utils import combine_message_chain
-from onyx.chat.prompt_builder.build import AnswerPromptBuilder
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.configs.model_configs import GEN_AI_HISTORY_CUTOFF
 from onyx.configs.tool_configs import IMAGE_GENERATION_OUTPUT_FORMAT
 from onyx.llm.interfaces import LLM
diff --git a/backend/onyx/tools/tool_implementations/internet_search/internet_search_tool.py b/backend/onyx/tools/tool_implementations/internet_search/internet_search_tool.py
index 0229eef4a..084c03853 100644
--- a/backend/onyx/tools/tool_implementations/internet_search/internet_search_tool.py
+++ b/backend/onyx/tools/tool_implementations/internet_search/internet_search_tool.py
@@ -10,7 +10,7 @@ from onyx.chat.chat_utils import combine_message_chain
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import LlmDoc
 from onyx.chat.models import PromptConfig
-from onyx.chat.prompt_builder.build import AnswerPromptBuilder
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.configs.constants import DocumentSource
 from onyx.configs.model_configs import GEN_AI_HISTORY_CUTOFF
 from onyx.context.search.models import SearchDoc
diff --git a/backend/onyx/tools/tool_implementations/search/search_tool.py b/backend/onyx/tools/tool_implementations/search/search_tool.py
index da8bc6a1a..ed8af4c9c 100644
--- a/backend/onyx/tools/tool_implementations/search/search_tool.py
+++ b/backend/onyx/tools/tool_implementations/search/search_tool.py
@@ -16,7 +16,7 @@ from onyx.chat.models import OnyxContext
 from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PromptConfig
 from onyx.chat.models import SectionRelevancePiece
-from onyx.chat.prompt_builder.build import AnswerPromptBuilder
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
 from onyx.chat.prune_and_merge import prune_and_merge_sections
 from onyx.chat.prune_and_merge import prune_sections
diff --git a/backend/onyx/tools/tool_implementations/search_like_tool_utils.py b/backend/onyx/tools/tool_implementations/search_like_tool_utils.py
index 44dc8f2c3..cf4dfda08 100644
--- a/backend/onyx/tools/tool_implementations/search_like_tool_utils.py
+++ b/backend/onyx/tools/tool_implementations/search_like_tool_utils.py
@@ -5,12 +5,12 @@ from langchain_core.messages import HumanMessage
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import LlmDoc
 from onyx.chat.models import PromptConfig
-from onyx.chat.prompt_builder.build import AnswerPromptBuilder
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.prompt_builder.citations_prompt import (
     build_citations_system_message,
 )
 from onyx.chat.prompt_builder.citations_prompt import build_citations_user_message
-from onyx.chat.prompt_builder.quotes_prompt import build_quotes_user_message
+from onyx.llm.utils import build_content_with_imgs
 from onyx.tools.message import ToolCallSummary
 from onyx.tools.models import ToolResponse

@@ -40,37 +40,27 @@ def build_next_prompt_for_search_like_tool(
         # if using tool calling llm, then the final context documents are the tool responses
         final_context_documents = []

-    if answer_style_config.citation_config:
-        prompt_builder.update_system_prompt(
-            build_citations_system_message(prompt_config)
-        )
-        prompt_builder.update_user_prompt(
-            build_citations_user_message(
-                message=prompt_builder.user_message_and_token_cnt[0],
-                prompt_config=prompt_config,
-                context_docs=final_context_documents,
-                all_doc_useful=(
-                    answer_style_config.citation_config.all_docs_useful
-                    if answer_style_config.citation_config
-                    else False
-                ),
-                history_message=prompt_builder.single_message_history or "",
-            )
-        )
-    elif answer_style_config.quotes_config:
-        # For Quotes, the system prompt is included in the user prompt
-        prompt_builder.update_system_prompt(None)
-
-        human_message = HumanMessage(content=prompt_builder.raw_user_message)
-
-        prompt_builder.update_user_prompt(
-            build_quotes_user_message(
-                message=human_message,
-                context_docs=final_context_documents,
-                history_str=prompt_builder.single_message_history or "",
-                prompt=prompt_config,
-            )
+    prompt_builder.update_system_prompt(build_citations_system_message(prompt_config))
+    prompt_builder.update_user_prompt(
+        build_citations_user_message(
+            # make sure to use the original user query here in order to avoid duplication
+            # of the task prompt
+            message=HumanMessage(
+                content=build_content_with_imgs(
+                    prompt_builder.raw_user_query,
+                    prompt_builder.raw_user_uploaded_files,
+                )
+            ),
+            prompt_config=prompt_config,
+            context_docs=final_context_documents,
+            all_doc_useful=(
+                answer_style_config.citation_config.all_docs_useful
+                if answer_style_config.citation_config
+                else False
+            ),
+            history_message=prompt_builder.single_message_history or "",
         )
+    )

     if using_tool_calling_llm:
         prompt_builder.append_message(tool_call_summary.tool_call_request)
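The comment in the hunk above is the crux: after a tool call the user prompt is rebuilt, so it must start from raw_user_query rather than the already-templated user message, or the task prompt would be applied twice. Schematically (the strings here are illustrative, not the real templates):

task_prompt = "Answer concisely."
raw_user_query = "What changed?"

# first LLM call: the query is templated once
templated = f"{task_prompt}\n\nQUESTION:\n{raw_user_query}"

# rebuilding from the templated text would nest one prompt inside another
duplicated = f"{task_prompt}\n\nQUESTION:\n{templated}"

# rebuilding from the raw query, as the patch does, keeps a single task prompt
rebuilt = f"{task_prompt}\n\nQUESTION:\n{raw_user_query}"
assert rebuilt == templated and rebuilt != duplicated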
diff --git a/backend/tests/unit/onyx/chat/conftest.py b/backend/tests/unit/onyx/chat/conftest.py
index 5ee156012..138668b19 100644
--- a/backend/tests/unit/onyx/chat/conftest.py
+++ b/backend/tests/unit/onyx/chat/conftest.py
@@ -9,7 +9,7 @@ from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationConfig
 from onyx.chat.models import LlmDoc
 from onyx.chat.models import PromptConfig
-from onyx.chat.prompt_builder.build import AnswerPromptBuilder
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.configs.constants import DocumentSource
 from onyx.llm.interfaces import LLMConfig
 from onyx.tools.models import ToolResponse
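One consequence for downstream callers, visible in this conftest and in the backend/onyx/chat/models.py hunks above: AnswerStyleConfig now requires citation_config and no longer accepts quotes_config. A minimal sketch of the narrowed model and its construction:

from pydantic import BaseModel


class CitationConfig(BaseModel):  # mirrors the model in onyx/chat/models.py
    all_docs_useful: bool = False


class AnswerStyleConfig(BaseModel):  # quotes_config and its validator are gone
    citation_config: CitationConfig


# citations are now the only answer style; construction always supplies one
config = AnswerStyleConfig(citation_config=CitationConfig(all_docs_useful=True))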