diff --git a/backend/alembic/versions/33cb72ea4d80_single_tool_call_per_message.py b/backend/alembic/versions/33cb72ea4d80_single_tool_call_per_message.py new file mode 100644 index 0000000000..0cd3da444b --- /dev/null +++ b/backend/alembic/versions/33cb72ea4d80_single_tool_call_per_message.py @@ -0,0 +1,50 @@ +"""single tool call per message + +Revision ID: 33cb72ea4d80 +Revises: 5b29123cd710 +Create Date: 2024-11-01 12:51:01.535003 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "33cb72ea4d80" +down_revision = "5b29123cd710" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Step 1: Delete extraneous ToolCall entries + # Keep only the ToolCall with the smallest 'id' for each 'message_id' + op.execute( + sa.text( + """ + DELETE FROM tool_call + WHERE id NOT IN ( + SELECT MIN(id) + FROM tool_call + WHERE message_id IS NOT NULL + GROUP BY message_id + ); + """ + ) + ) + + # Step 2: Add a unique constraint on message_id + op.create_unique_constraint( + constraint_name="uq_tool_call_message_id", + table_name="tool_call", + columns=["message_id"], + ) + + +def downgrade() -> None: + # Step 1: Drop the unique constraint on message_id + op.drop_constraint( + constraint_name="uq_tool_call_message_id", + table_name="tool_call", + type_="unique", + ) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 0394f34b82..252f6df0f1 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -864,17 +864,15 @@ def stream_chat_message_objects( if message_specific_citations else None, error=None, - tool_calls=( - [ - ToolCall( - tool_id=tool_name_to_tool_id[tool_result.tool_name], - tool_name=tool_result.tool_name, - tool_arguments=tool_result.tool_args, - tool_result=tool_result.tool_result, - ) - ] + tool_call=( + ToolCall( + tool_id=tool_name_to_tool_id[tool_result.tool_name], + tool_name=tool_result.tool_name, + tool_arguments=tool_result.tool_args, + tool_result=tool_result.tool_result, + ) if tool_result - else [] + else None ), ) diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index d885e1efd6..4aaee09297 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -388,7 +388,7 @@ def get_chat_messages_by_session( ) if prefetch_tool_calls: - stmt = stmt.options(joinedload(ChatMessage.tool_calls)) + stmt = stmt.options(joinedload(ChatMessage.tool_call)) result = db_session.scalars(stmt).unique().all() else: result = db_session.scalars(stmt).all() @@ -474,7 +474,7 @@ def create_new_chat_message( alternate_assistant_id: int | None = None, # Maps the citation number [n] to the DB SearchDoc citations: dict[int, int] | None = None, - tool_calls: list[ToolCall] | None = None, + tool_call: ToolCall | None = None, commit: bool = True, reserved_message_id: int | None = None, overridden_model: str | None = None, @@ -494,7 +494,7 @@ def create_new_chat_message( existing_message.message_type = message_type existing_message.citations = citations existing_message.files = files - existing_message.tool_calls = tool_calls if tool_calls else [] + existing_message.tool_call = tool_call existing_message.error = error existing_message.alternate_assistant_id = alternate_assistant_id existing_message.overridden_model = overridden_model @@ -513,7 +513,7 @@ def create_new_chat_message( message_type=message_type, citations=citations, files=files, - tool_calls=tool_calls if tool_calls else [], + tool_call=tool_call, error=error, alternate_assistant_id=alternate_assistant_id, overridden_model=overridden_model, @@ -749,14 +749,13 @@ def translate_db_message_to_chat_message_detail( time_sent=chat_message.time_sent, citations=chat_message.citations, files=chat_message.files or [], - tool_calls=[ - ToolCallFinalResult( - tool_name=tool_call.tool_name, - tool_args=tool_call.tool_arguments, - tool_result=tool_call.tool_result, - ) - for tool_call in chat_message.tool_calls - ], + tool_call=ToolCallFinalResult( + tool_name=chat_message.tool_call.tool_name, + tool_args=chat_message.tool_call.tool_arguments, + tool_result=chat_message.tool_call.tool_result, + ) + if chat_message.tool_call + else None, alternate_assistant_id=chat_message.alternate_assistant_id, overridden_model=chat_message.overridden_model, ) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 3f63d1d0db..3ff2133155 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -918,10 +918,15 @@ class ToolCall(Base): tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB()) tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB()) - message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id")) + message_id: Mapped[int | None] = mapped_column( + ForeignKey("chat_message.id"), nullable=False + ) + # Update the relationship message: Mapped["ChatMessage"] = relationship( - "ChatMessage", back_populates="tool_calls" + "ChatMessage", + back_populates="tool_call", + uselist=False, ) @@ -1052,12 +1057,13 @@ class ChatMessage(Base): secondary=ChatMessage__SearchDoc.__table__, back_populates="chat_messages", ) - # NOTE: Should always be attached to the `assistant` message. - # represents the tool calls used to generate this message - tool_calls: Mapped[list["ToolCall"]] = relationship( + + tool_call: Mapped["ToolCall"] = relationship( "ToolCall", back_populates="message", + uselist=False, ) + standard_answers: Mapped[list["StandardAnswer"]] = relationship( "StandardAnswer", secondary=ChatMessage__StandardAnswer.__table__, diff --git a/backend/danswer/file_store/utils.py b/backend/danswer/file_store/utils.py index b71d20bbbb..e9eea2c262 100644 --- a/backend/danswer/file_store/utils.py +++ b/backend/danswer/file_store/utils.py @@ -8,12 +8,13 @@ import requests from sqlalchemy.orm import Session from danswer.configs.constants import FileOrigin -from danswer.db.engine import get_session_context_manager +from danswer.db.engine import get_session_with_tenant from danswer.db.models import ChatMessage from danswer.file_store.file_store import get_default_file_store from danswer.file_store.models import FileDescriptor from danswer.file_store.models import InMemoryChatFile from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR def load_chat_file( @@ -52,11 +53,11 @@ def load_all_chat_files( return files -def save_file_from_url(url: str) -> str: +def save_file_from_url(url: str, tenant_id: str) -> str: """NOTE: using multiple sessions here, since this is often called using multithreading. In practice, sharing a session has resulted in weird errors.""" - with get_session_context_manager() as db_session: + with get_session_with_tenant(tenant_id) as db_session: response = requests.get(url) response.raise_for_status() @@ -75,7 +76,10 @@ def save_file_from_url(url: str) -> str: def save_files_from_urls(urls: list[str]) -> list[str]: + tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() + funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [ - (save_file_from_url, (url,)) for url in urls + (save_file_from_url, (url, tenant_id)) for url in urls ] + # Must pass in tenant_id here, since this is called by multithreading return run_functions_tuples_in_parallel(funcs) diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py index 87c1297fe9..03f72a0968 100644 --- a/backend/danswer/llm/answering/models.py +++ b/backend/danswer/llm/answering/models.py @@ -33,7 +33,7 @@ class PreviousMessage(BaseModel): token_count: int message_type: MessageType files: list[InMemoryChatFile] - tool_calls: list[ToolCallFinalResult] + tool_call: ToolCallFinalResult | None @classmethod def from_chat_message( @@ -51,14 +51,13 @@ class PreviousMessage(BaseModel): for file in available_files if str(file.file_id) in message_file_ids ], - tool_calls=[ - ToolCallFinalResult( - tool_name=tool_call.tool_name, - tool_args=tool_call.tool_arguments, - tool_result=tool_call.tool_result, - ) - for tool_call in chat_message.tool_calls - ], + tool_call=ToolCallFinalResult( + tool_name=chat_message.tool_call.tool_name, + tool_args=chat_message.tool_call.tool_arguments, + tool_result=chat_message.tool_call.tool_result, + ) + if chat_message.tool_call + else None, ) def to_langchain_msg(self) -> BaseMessage: diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py index d450fff0a6..6bd4b51d40 100644 --- a/backend/danswer/llm/chat_llm.py +++ b/backend/danswer/llm/chat_llm.py @@ -83,8 +83,10 @@ def _convert_litellm_message_to_langchain_message( "args": json.loads(tool_call.function.arguments), "id": tool_call.id, } - for tool_call in (tool_calls if tool_calls else []) - ], + for tool_call in tool_calls + ] + if tool_calls + else [], ) elif role == "system": return SystemMessage(content=content) diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 1ca14f9283..13b3b1ec0a 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -188,7 +188,7 @@ class ChatMessageDetail(BaseModel): chat_session_id: UUID | None = None citations: dict[int, int] | None = None files: list[FileDescriptor] - tool_calls: list[ToolCallFinalResult] + tool_call: ToolCallFinalResult | None def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 03b53fe099..dcf1baf14d 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -277,13 +277,13 @@ export function ChatPage({ if ( lastMessage && lastMessage.type === "assistant" && - lastMessage.toolCalls[0] && - lastMessage.toolCalls[0].tool_result === undefined + lastMessage.toolCall && + lastMessage.toolCall.tool_result === undefined ) { const newCompleteMessageMap = new Map( currentMessageMap(completeMessageDetail) ); - const updatedMessage = { ...lastMessage, toolCalls: [] }; + const updatedMessage = { ...lastMessage, toolCall: null }; newCompleteMessageMap.set(lastMessage.messageId, updatedMessage); updateCompleteMessageDetail(currentSession, newCompleteMessageMap); } @@ -513,7 +513,7 @@ export function ChatPage({ message: "", type: "system", files: [], - toolCalls: [], + toolCall: null, parentMessageId: null, childrenMessageIds: [firstMessageId], latestChildMessageId: firstMessageId, @@ -1104,7 +1104,7 @@ export function ChatPage({ let stackTrace: string | null = null; let finalMessage: BackendMessage | null = null; - let toolCalls: ToolCallMetadata[] = []; + let toolCall: ToolCallMetadata | null = null; let initialFetchDetails: null | { user_message_id: number; @@ -1209,7 +1209,7 @@ export function ChatPage({ message: currMessage, type: "user", files: currentMessageFiles, - toolCalls: [], + toolCall: null, parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID, }, ]; @@ -1262,17 +1262,14 @@ export function ChatPage({ setSelectedMessageForDocDisplay(user_message_id); } } else if (Object.hasOwn(packet, "tool_name")) { - toolCalls = [ - { - tool_name: (packet as ToolCallMetadata).tool_name, - tool_args: (packet as ToolCallMetadata).tool_args, - tool_result: (packet as ToolCallMetadata).tool_result, - }, - ]; - if ( - !toolCalls[0].tool_result || - toolCalls[0].tool_result == undefined - ) { + // Will only ever be one tool call per message + toolCall = { + tool_name: (packet as ToolCallMetadata).tool_name, + tool_args: (packet as ToolCallMetadata).tool_args, + tool_result: (packet as ToolCallMetadata).tool_result, + }; + + if (!toolCall.tool_result || toolCall.tool_result == undefined) { updateChatState("toolBuilding", frozenSessionId); } else { updateChatState("streaming", frozenSessionId); @@ -1280,8 +1277,8 @@ export function ChatPage({ // This will be consolidated in upcoming tool calls udpate, // but for now, we need to set query as early as possible - if (toolCalls[0].tool_name == SEARCH_TOOL_NAME) { - query = toolCalls[0].tool_args["query"]; + if (toolCall.tool_name == SEARCH_TOOL_NAME) { + query = toolCall.tool_args["query"]; } } else if (Object.hasOwn(packet, "file_ids")) { aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map( @@ -1339,7 +1336,7 @@ export function ChatPage({ message: currMessage, type: "user", files: currentMessageFiles, - toolCalls: [], + toolCall: null, parentMessageId: error ? null : lastSuccessfulMessageId, childrenMessageIds: [ ...(regenerationRequest?.parentMessage?.childrenMessageIds || @@ -1358,7 +1355,7 @@ export function ChatPage({ finalMessage?.context_docs?.top_documents || documents, citations: finalMessage?.citations || {}, files: finalMessage?.files || aiMessageImages || [], - toolCalls: finalMessage?.tool_calls || toolCalls, + toolCall: finalMessage?.tool_call || null, parentMessageId: regenerationRequest ? regenerationRequest?.parentMessage?.messageId! : initialFetchDetails.user_message_id, @@ -1381,7 +1378,7 @@ export function ChatPage({ message: currMessage, type: "user", files: currentMessageFiles, - toolCalls: [], + toolCall: null, parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID, }, { @@ -1391,7 +1388,7 @@ export function ChatPage({ message: errorMsg, type: "error", files: aiMessageImages || [], - toolCalls: [], + toolCall: null, parentMessageId: initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID, }, @@ -2238,10 +2235,7 @@ export function ChatPage({ citedDocuments={getCitedDocumentsFromMessage( message )} - toolCall={ - message.toolCalls && - message.toolCalls[0] - } + toolCall={message.toolCall} isComplete={ i !== messageHistory.length - 1 || (currentSessionChatState != diff --git a/web/src/app/chat/interfaces.ts b/web/src/app/chat/interfaces.ts index f77bf50d6a..20ea4e7d28 100644 --- a/web/src/app/chat/interfaces.ts +++ b/web/src/app/chat/interfaces.ts @@ -86,7 +86,7 @@ export interface Message { documents?: DanswerDocument[] | null; citations?: CitationMap; files: FileDescriptor[]; - toolCalls: ToolCallMetadata[]; + toolCall: ToolCallMetadata | null; // for rebuilding the message tree parentMessageId: number | null; childrenMessageIds?: number[]; @@ -121,7 +121,7 @@ export interface BackendMessage { time_sent: string; citations: CitationMap; files: FileDescriptor[]; - tool_calls: ToolCallFinalResult[]; + tool_call: ToolCallFinalResult | null; alternate_assistant_id?: number | null; overridden_model?: string; } diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index 38fdac037a..cda037a08b 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -428,7 +428,7 @@ export function processRawChatHistory( citations: messageInfo?.citations || {}, } : {}), - toolCalls: messageInfo.tool_calls, + toolCall: messageInfo.tool_call, parentMessageId: messageInfo.parent_message, childrenMessageIds: [], latestChildMessageId: messageInfo.latest_child_message, diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index 101bcdbb67..2939c74d3f 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -189,7 +189,7 @@ export const AIMessage = ({ files?: FileDescriptor[]; query?: string; citedDocuments?: [string, DanswerDocument][] | null; - toolCall?: ToolCallMetadata; + toolCall?: ToolCallMetadata | null; isComplete?: boolean; hasDocs?: boolean; handleFeedback?: (feedbackType: FeedbackType) => void;