mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-30 09:40:50 +02:00
Tool call per message (#3025)
* single tool call per message * finalize migration * minor image generation fix * validate simplify * k * remove print * validated
This commit is contained in:
@ -0,0 +1,50 @@
|
||||
"""single tool call per message
|
||||
|
||||
Revision ID: 33cb72ea4d80
|
||||
Revises: 5b29123cd710
|
||||
Create Date: 2024-11-01 12:51:01.535003
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "33cb72ea4d80"
|
||||
down_revision = "5b29123cd710"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Enforce at most one ``tool_call`` row per chat message.

    The migration runs in two stages:

    1. Prune ``tool_call`` so that, for each non-NULL ``message_id``, only the
       row with the smallest ``id`` survives. Rows whose ``message_id`` is NULL
       never enter the keep-set built by the subquery, so they are removed too.
    2. Add the ``uq_tool_call_message_id`` unique constraint on
       ``tool_call.message_id``.

    NOTE: stage 1 is destructive; the pruned rows are not saved anywhere.
    """
    # Stage 1: the subquery yields the ids to keep (lowest id per message_id);
    # everything outside that set is deleted.
    prune_duplicates = sa.text(
        """
            DELETE FROM tool_call
            WHERE id NOT IN (
                SELECT MIN(id)
                FROM tool_call
                WHERE message_id IS NOT NULL
                GROUP BY message_id
            );
        """
    )
    op.execute(prune_duplicates)

    # Stage 2: with duplicates gone, the uniqueness rule can be applied.
    op.create_unique_constraint(
        "uq_tool_call_message_id",
        "tool_call",
        ["message_id"],
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove the ``uq_tool_call_message_id`` unique constraint from ``tool_call``.

    Only the constraint is dropped here; no rows are inserted or modified, so
    any ``tool_call`` rows deleted earlier are not brought back.
    """
    op.drop_constraint("uq_tool_call_message_id", "tool_call", type_="unique")
|
@ -864,17 +864,15 @@ def stream_chat_message_objects(
|
||||
if message_specific_citations
|
||||
else None,
|
||||
error=None,
|
||||
tool_calls=(
|
||||
[
|
||||
tool_call=(
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
if tool_result
|
||||
else []
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -388,7 +388,7 @@ def get_chat_messages_by_session(
|
||||
)
|
||||
|
||||
if prefetch_tool_calls:
|
||||
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
|
||||
stmt = stmt.options(joinedload(ChatMessage.tool_call))
|
||||
result = db_session.scalars(stmt).unique().all()
|
||||
else:
|
||||
result = db_session.scalars(stmt).all()
|
||||
@ -474,7 +474,7 @@ def create_new_chat_message(
|
||||
alternate_assistant_id: int | None = None,
|
||||
# Maps the citation number [n] to the DB SearchDoc
|
||||
citations: dict[int, int] | None = None,
|
||||
tool_calls: list[ToolCall] | None = None,
|
||||
tool_call: ToolCall | None = None,
|
||||
commit: bool = True,
|
||||
reserved_message_id: int | None = None,
|
||||
overridden_model: str | None = None,
|
||||
@ -494,7 +494,7 @@ def create_new_chat_message(
|
||||
existing_message.message_type = message_type
|
||||
existing_message.citations = citations
|
||||
existing_message.files = files
|
||||
existing_message.tool_calls = tool_calls if tool_calls else []
|
||||
existing_message.tool_call = tool_call
|
||||
existing_message.error = error
|
||||
existing_message.alternate_assistant_id = alternate_assistant_id
|
||||
existing_message.overridden_model = overridden_model
|
||||
@ -513,7 +513,7 @@ def create_new_chat_message(
|
||||
message_type=message_type,
|
||||
citations=citations,
|
||||
files=files,
|
||||
tool_calls=tool_calls if tool_calls else [],
|
||||
tool_call=tool_call,
|
||||
error=error,
|
||||
alternate_assistant_id=alternate_assistant_id,
|
||||
overridden_model=overridden_model,
|
||||
@ -749,14 +749,13 @@ def translate_db_message_to_chat_message_detail(
|
||||
time_sent=chat_message.time_sent,
|
||||
citations=chat_message.citations,
|
||||
files=chat_message.files or [],
|
||||
tool_calls=[
|
||||
ToolCallFinalResult(
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_arguments,
|
||||
tool_result=tool_call.tool_result,
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
alternate_assistant_id=chat_message.alternate_assistant_id,
|
||||
overridden_model=chat_message.overridden_model,
|
||||
)
|
||||
|
@ -918,10 +918,15 @@ class ToolCall(Base):
|
||||
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
|
||||
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
|
||||
|
||||
message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
|
||||
message_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("chat_message.id"), nullable=False
|
||||
)
|
||||
|
||||
# Update the relationship
|
||||
message: Mapped["ChatMessage"] = relationship(
|
||||
"ChatMessage", back_populates="tool_calls"
|
||||
"ChatMessage",
|
||||
back_populates="tool_call",
|
||||
uselist=False,
|
||||
)
|
||||
|
||||
|
||||
@ -1052,12 +1057,13 @@ class ChatMessage(Base):
|
||||
secondary=ChatMessage__SearchDoc.__table__,
|
||||
back_populates="chat_messages",
|
||||
)
|
||||
# NOTE: Should always be attached to the `assistant` message.
|
||||
# represents the tool calls used to generate this message
|
||||
tool_calls: Mapped[list["ToolCall"]] = relationship(
|
||||
|
||||
tool_call: Mapped["ToolCall"] = relationship(
|
||||
"ToolCall",
|
||||
back_populates="message",
|
||||
uselist=False,
|
||||
)
|
||||
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
secondary=ChatMessage__StandardAnswer.__table__,
|
||||
|
@ -8,12 +8,13 @@ import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
def load_chat_file(
|
||||
@ -52,11 +53,11 @@ def load_all_chat_files(
|
||||
return files
|
||||
|
||||
|
||||
def save_file_from_url(url: str) -> str:
|
||||
def save_file_from_url(url: str, tenant_id: str) -> str:
|
||||
"""NOTE: using multiple sessions here, since this is often called
|
||||
using multithreading. In practice, sharing a session has resulted in
|
||||
weird errors."""
|
||||
with get_session_context_manager() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
@ -75,7 +76,10 @@ def save_file_from_url(url: str) -> str:
|
||||
|
||||
|
||||
def save_files_from_urls(urls: list[str]) -> list[str]:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
|
||||
(save_file_from_url, (url,)) for url in urls
|
||||
(save_file_from_url, (url, tenant_id)) for url in urls
|
||||
]
|
||||
# Must pass in tenant_id here, since this is called by multithreading
|
||||
return run_functions_tuples_in_parallel(funcs)
|
||||
|
@ -33,7 +33,7 @@ class PreviousMessage(BaseModel):
|
||||
token_count: int
|
||||
message_type: MessageType
|
||||
files: list[InMemoryChatFile]
|
||||
tool_calls: list[ToolCallFinalResult]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
|
||||
@classmethod
|
||||
def from_chat_message(
|
||||
@ -51,14 +51,13 @@ class PreviousMessage(BaseModel):
|
||||
for file in available_files
|
||||
if str(file.file_id) in message_file_ids
|
||||
],
|
||||
tool_calls=[
|
||||
ToolCallFinalResult(
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_arguments,
|
||||
tool_result=tool_call.tool_result,
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
)
|
||||
|
||||
def to_langchain_msg(self) -> BaseMessage:
|
||||
|
@ -83,8 +83,10 @@ def _convert_litellm_message_to_langchain_message(
|
||||
"args": json.loads(tool_call.function.arguments),
|
||||
"id": tool_call.id,
|
||||
}
|
||||
for tool_call in (tool_calls if tool_calls else [])
|
||||
],
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
if tool_calls
|
||||
else [],
|
||||
)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=content)
|
||||
|
@ -188,7 +188,7 @@ class ChatMessageDetail(BaseModel):
|
||||
chat_session_id: UUID | None = None
|
||||
citations: dict[int, int] | None = None
|
||||
files: list[FileDescriptor]
|
||||
tool_calls: list[ToolCallFinalResult]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
|
@ -277,13 +277,13 @@ export function ChatPage({
|
||||
if (
|
||||
lastMessage &&
|
||||
lastMessage.type === "assistant" &&
|
||||
lastMessage.toolCalls[0] &&
|
||||
lastMessage.toolCalls[0].tool_result === undefined
|
||||
lastMessage.toolCall &&
|
||||
lastMessage.toolCall.tool_result === undefined
|
||||
) {
|
||||
const newCompleteMessageMap = new Map(
|
||||
currentMessageMap(completeMessageDetail)
|
||||
);
|
||||
const updatedMessage = { ...lastMessage, toolCalls: [] };
|
||||
const updatedMessage = { ...lastMessage, toolCall: null };
|
||||
newCompleteMessageMap.set(lastMessage.messageId, updatedMessage);
|
||||
updateCompleteMessageDetail(currentSession, newCompleteMessageMap);
|
||||
}
|
||||
@ -513,7 +513,7 @@ export function ChatPage({
|
||||
message: "",
|
||||
type: "system",
|
||||
files: [],
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId: null,
|
||||
childrenMessageIds: [firstMessageId],
|
||||
latestChildMessageId: firstMessageId,
|
||||
@ -1104,7 +1104,7 @@ export function ChatPage({
|
||||
let stackTrace: string | null = null;
|
||||
|
||||
let finalMessage: BackendMessage | null = null;
|
||||
let toolCalls: ToolCallMetadata[] = [];
|
||||
let toolCall: ToolCallMetadata | null = null;
|
||||
|
||||
let initialFetchDetails: null | {
|
||||
user_message_id: number;
|
||||
@ -1209,7 +1209,7 @@ export function ChatPage({
|
||||
message: currMessage,
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
|
||||
},
|
||||
];
|
||||
@ -1262,17 +1262,14 @@ export function ChatPage({
|
||||
setSelectedMessageForDocDisplay(user_message_id);
|
||||
}
|
||||
} else if (Object.hasOwn(packet, "tool_name")) {
|
||||
toolCalls = [
|
||||
{
|
||||
// Will only ever be one tool call per message
|
||||
toolCall = {
|
||||
tool_name: (packet as ToolCallMetadata).tool_name,
|
||||
tool_args: (packet as ToolCallMetadata).tool_args,
|
||||
tool_result: (packet as ToolCallMetadata).tool_result,
|
||||
},
|
||||
];
|
||||
if (
|
||||
!toolCalls[0].tool_result ||
|
||||
toolCalls[0].tool_result == undefined
|
||||
) {
|
||||
};
|
||||
|
||||
if (!toolCall.tool_result || toolCall.tool_result == undefined) {
|
||||
updateChatState("toolBuilding", frozenSessionId);
|
||||
} else {
|
||||
updateChatState("streaming", frozenSessionId);
|
||||
@ -1280,8 +1277,8 @@ export function ChatPage({
|
||||
|
||||
// This will be consolidated in upcoming tool calls update,
|
||||
// but for now, we need to set query as early as possible
|
||||
if (toolCalls[0].tool_name == SEARCH_TOOL_NAME) {
|
||||
query = toolCalls[0].tool_args["query"];
|
||||
if (toolCall.tool_name == SEARCH_TOOL_NAME) {
|
||||
query = toolCall.tool_args["query"];
|
||||
}
|
||||
} else if (Object.hasOwn(packet, "file_ids")) {
|
||||
aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map(
|
||||
@ -1339,7 +1336,7 @@ export function ChatPage({
|
||||
message: currMessage,
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId: error ? null : lastSuccessfulMessageId,
|
||||
childrenMessageIds: [
|
||||
...(regenerationRequest?.parentMessage?.childrenMessageIds ||
|
||||
@ -1358,7 +1355,7 @@ export function ChatPage({
|
||||
finalMessage?.context_docs?.top_documents || documents,
|
||||
citations: finalMessage?.citations || {},
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
toolCalls: finalMessage?.tool_calls || toolCalls,
|
||||
toolCall: finalMessage?.tool_call || null,
|
||||
parentMessageId: regenerationRequest
|
||||
? regenerationRequest?.parentMessage?.messageId!
|
||||
: initialFetchDetails.user_message_id,
|
||||
@ -1381,7 +1378,7 @@ export function ChatPage({
|
||||
message: currMessage,
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
|
||||
},
|
||||
{
|
||||
@ -1391,7 +1388,7 @@ export function ChatPage({
|
||||
message: errorMsg,
|
||||
type: "error",
|
||||
files: aiMessageImages || [],
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId:
|
||||
initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID,
|
||||
},
|
||||
@ -2238,10 +2235,7 @@ export function ChatPage({
|
||||
citedDocuments={getCitedDocumentsFromMessage(
|
||||
message
|
||||
)}
|
||||
toolCall={
|
||||
message.toolCalls &&
|
||||
message.toolCalls[0]
|
||||
}
|
||||
toolCall={message.toolCall}
|
||||
isComplete={
|
||||
i !== messageHistory.length - 1 ||
|
||||
(currentSessionChatState !=
|
||||
|
@ -86,7 +86,7 @@ export interface Message {
|
||||
documents?: DanswerDocument[] | null;
|
||||
citations?: CitationMap;
|
||||
files: FileDescriptor[];
|
||||
toolCalls: ToolCallMetadata[];
|
||||
toolCall: ToolCallMetadata | null;
|
||||
// for rebuilding the message tree
|
||||
parentMessageId: number | null;
|
||||
childrenMessageIds?: number[];
|
||||
@ -121,7 +121,7 @@ export interface BackendMessage {
|
||||
time_sent: string;
|
||||
citations: CitationMap;
|
||||
files: FileDescriptor[];
|
||||
tool_calls: ToolCallFinalResult[];
|
||||
tool_call: ToolCallFinalResult | null;
|
||||
alternate_assistant_id?: number | null;
|
||||
overridden_model?: string;
|
||||
}
|
||||
|
@ -428,7 +428,7 @@ export function processRawChatHistory(
|
||||
citations: messageInfo?.citations || {},
|
||||
}
|
||||
: {}),
|
||||
toolCalls: messageInfo.tool_calls,
|
||||
toolCall: messageInfo.tool_call,
|
||||
parentMessageId: messageInfo.parent_message,
|
||||
childrenMessageIds: [],
|
||||
latestChildMessageId: messageInfo.latest_child_message,
|
||||
|
@ -189,7 +189,7 @@ export const AIMessage = ({
|
||||
files?: FileDescriptor[];
|
||||
query?: string;
|
||||
citedDocuments?: [string, DanswerDocument][] | null;
|
||||
toolCall?: ToolCallMetadata;
|
||||
toolCall?: ToolCallMetadata | null;
|
||||
isComplete?: boolean;
|
||||
hasDocs?: boolean;
|
||||
handleFeedback?: (feedbackType: FeedbackType) => void;
|
||||
|
Reference in New Issue
Block a user