Tool call per message (#3025)

* single tool call per message

* finalize migration

* minor image generation fix

* validate simplify

* k

* remove print

* validated
This commit is contained in:
pablodanswer
2024-11-03 10:51:51 -08:00
committed by GitHub
parent a7002dfa1d
commit 51b79f688a
12 changed files with 126 additions and 74 deletions

View File

@ -0,0 +1,50 @@
"""single tool call per message
Revision ID: 33cb72ea4d80
Revises: 5b29123cd710
Create Date: 2024-11-01 12:51:01.535003
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "33cb72ea4d80"
down_revision = "5b29123cd710"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Step 1: Delete extraneous ToolCall entries
# Keep only the ToolCall with the smallest 'id' for each 'message_id'
op.execute(
sa.text(
"""
DELETE FROM tool_call
WHERE id NOT IN (
SELECT MIN(id)
FROM tool_call
WHERE message_id IS NOT NULL
GROUP BY message_id
);
"""
)
)
# Step 2: Add a unique constraint on message_id
op.create_unique_constraint(
constraint_name="uq_tool_call_message_id",
table_name="tool_call",
columns=["message_id"],
)
def downgrade() -> None:
# Step 1: Drop the unique constraint on message_id
op.drop_constraint(
constraint_name="uq_tool_call_message_id",
table_name="tool_call",
type_="unique",
)

View File

@ -864,17 +864,15 @@ def stream_chat_message_objects(
if message_specific_citations
else None,
error=None,
tool_calls=(
[
tool_call=(
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else []
else None
),
)

View File

@ -388,7 +388,7 @@ def get_chat_messages_by_session(
)
if prefetch_tool_calls:
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
stmt = stmt.options(joinedload(ChatMessage.tool_call))
result = db_session.scalars(stmt).unique().all()
else:
result = db_session.scalars(stmt).all()
@ -474,7 +474,7 @@ def create_new_chat_message(
alternate_assistant_id: int | None = None,
# Maps the citation number [n] to the DB SearchDoc
citations: dict[int, int] | None = None,
tool_calls: list[ToolCall] | None = None,
tool_call: ToolCall | None = None,
commit: bool = True,
reserved_message_id: int | None = None,
overridden_model: str | None = None,
@ -494,7 +494,7 @@ def create_new_chat_message(
existing_message.message_type = message_type
existing_message.citations = citations
existing_message.files = files
existing_message.tool_calls = tool_calls if tool_calls else []
existing_message.tool_call = tool_call
existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model
@ -513,7 +513,7 @@ def create_new_chat_message(
message_type=message_type,
citations=citations,
files=files,
tool_calls=tool_calls if tool_calls else [],
tool_call=tool_call,
error=error,
alternate_assistant_id=alternate_assistant_id,
overridden_model=overridden_model,
@ -749,14 +749,13 @@ def translate_db_message_to_chat_message_detail(
time_sent=chat_message.time_sent,
citations=chat_message.citations,
files=chat_message.files or [],
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
if chat_message.tool_call
else None,
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
)

View File

@ -918,10 +918,15 @@ class ToolCall(Base):
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
message_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_message.id"), nullable=False
)
# Update the relationship
message: Mapped["ChatMessage"] = relationship(
"ChatMessage", back_populates="tool_calls"
"ChatMessage",
back_populates="tool_call",
uselist=False,
)
@ -1052,12 +1057,13 @@ class ChatMessage(Base):
secondary=ChatMessage__SearchDoc.__table__,
back_populates="chat_messages",
)
# NOTE: Should always be attached to the `assistant` message.
# represents the tool calls used to generate this message
tool_calls: Mapped[list["ToolCall"]] = relationship(
tool_call: Mapped["ToolCall"] = relationship(
"ToolCall",
back_populates="message",
uselist=False,
)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=ChatMessage__StandardAnswer.__table__,

View File

@ -8,12 +8,13 @@ import requests
from sqlalchemy.orm import Session
from danswer.configs.constants import FileOrigin
from danswer.db.engine import get_session_context_manager
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import ChatMessage
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import FileDescriptor
from danswer.file_store.models import InMemoryChatFile
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
def load_chat_file(
@ -52,11 +53,11 @@ def load_all_chat_files(
return files
def save_file_from_url(url: str) -> str:
def save_file_from_url(url: str, tenant_id: str) -> str:
"""NOTE: using multiple sessions here, since this is often called
using multithreading. In practice, sharing a session has resulted in
weird errors."""
with get_session_context_manager() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
response = requests.get(url)
response.raise_for_status()
@ -75,7 +76,10 @@ def save_file_from_url(url: str) -> str:
def save_files_from_urls(urls: list[str]) -> list[str]:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
(save_file_from_url, (url,)) for url in urls
(save_file_from_url, (url, tenant_id)) for url in urls
]
# Must pass in tenant_id here, since this is called by multithreading
return run_functions_tuples_in_parallel(funcs)

View File

@ -33,7 +33,7 @@ class PreviousMessage(BaseModel):
token_count: int
message_type: MessageType
files: list[InMemoryChatFile]
tool_calls: list[ToolCallFinalResult]
tool_call: ToolCallFinalResult | None
@classmethod
def from_chat_message(
@ -51,14 +51,13 @@ class PreviousMessage(BaseModel):
for file in available_files
if str(file.file_id) in message_file_ids
],
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
if chat_message.tool_call
else None,
)
def to_langchain_msg(self) -> BaseMessage:

View File

@ -83,8 +83,10 @@ def _convert_litellm_message_to_langchain_message(
"args": json.loads(tool_call.function.arguments),
"id": tool_call.id,
}
for tool_call in (tool_calls if tool_calls else [])
],
for tool_call in tool_calls
]
if tool_calls
else [],
)
elif role == "system":
return SystemMessage(content=content)

View File

@ -188,7 +188,7 @@ class ChatMessageDetail(BaseModel):
chat_session_id: UUID | None = None
citations: dict[int, int] | None = None
files: list[FileDescriptor]
tool_calls: list[ToolCallFinalResult]
tool_call: ToolCallFinalResult | None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore

View File

@ -277,13 +277,13 @@ export function ChatPage({
if (
lastMessage &&
lastMessage.type === "assistant" &&
lastMessage.toolCalls[0] &&
lastMessage.toolCalls[0].tool_result === undefined
lastMessage.toolCall &&
lastMessage.toolCall.tool_result === undefined
) {
const newCompleteMessageMap = new Map(
currentMessageMap(completeMessageDetail)
);
const updatedMessage = { ...lastMessage, toolCalls: [] };
const updatedMessage = { ...lastMessage, toolCall: null };
newCompleteMessageMap.set(lastMessage.messageId, updatedMessage);
updateCompleteMessageDetail(currentSession, newCompleteMessageMap);
}
@ -513,7 +513,7 @@ export function ChatPage({
message: "",
type: "system",
files: [],
toolCalls: [],
toolCall: null,
parentMessageId: null,
childrenMessageIds: [firstMessageId],
latestChildMessageId: firstMessageId,
@ -1104,7 +1104,7 @@ export function ChatPage({
let stackTrace: string | null = null;
let finalMessage: BackendMessage | null = null;
let toolCalls: ToolCallMetadata[] = [];
let toolCall: ToolCallMetadata | null = null;
let initialFetchDetails: null | {
user_message_id: number;
@ -1209,7 +1209,7 @@ export function ChatPage({
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
toolCall: null,
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
},
];
@ -1262,17 +1262,14 @@ export function ChatPage({
setSelectedMessageForDocDisplay(user_message_id);
}
} else if (Object.hasOwn(packet, "tool_name")) {
toolCalls = [
{
// Will only ever be one tool call per message
toolCall = {
tool_name: (packet as ToolCallMetadata).tool_name,
tool_args: (packet as ToolCallMetadata).tool_args,
tool_result: (packet as ToolCallMetadata).tool_result,
},
];
if (
!toolCalls[0].tool_result ||
toolCalls[0].tool_result == undefined
) {
};
if (!toolCall.tool_result || toolCall.tool_result == undefined) {
updateChatState("toolBuilding", frozenSessionId);
} else {
updateChatState("streaming", frozenSessionId);
@ -1280,8 +1277,8 @@ export function ChatPage({
// This will be consolidated in upcoming tool calls udpate,
// but for now, we need to set query as early as possible
if (toolCalls[0].tool_name == SEARCH_TOOL_NAME) {
query = toolCalls[0].tool_args["query"];
if (toolCall.tool_name == SEARCH_TOOL_NAME) {
query = toolCall.tool_args["query"];
}
} else if (Object.hasOwn(packet, "file_ids")) {
aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map(
@ -1339,7 +1336,7 @@ export function ChatPage({
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
toolCall: null,
parentMessageId: error ? null : lastSuccessfulMessageId,
childrenMessageIds: [
...(regenerationRequest?.parentMessage?.childrenMessageIds ||
@ -1358,7 +1355,7 @@ export function ChatPage({
finalMessage?.context_docs?.top_documents || documents,
citations: finalMessage?.citations || {},
files: finalMessage?.files || aiMessageImages || [],
toolCalls: finalMessage?.tool_calls || toolCalls,
toolCall: finalMessage?.tool_call || null,
parentMessageId: regenerationRequest
? regenerationRequest?.parentMessage?.messageId!
: initialFetchDetails.user_message_id,
@ -1381,7 +1378,7 @@ export function ChatPage({
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
toolCall: null,
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
},
{
@ -1391,7 +1388,7 @@ export function ChatPage({
message: errorMsg,
type: "error",
files: aiMessageImages || [],
toolCalls: [],
toolCall: null,
parentMessageId:
initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID,
},
@ -2238,10 +2235,7 @@ export function ChatPage({
citedDocuments={getCitedDocumentsFromMessage(
message
)}
toolCall={
message.toolCalls &&
message.toolCalls[0]
}
toolCall={message.toolCall}
isComplete={
i !== messageHistory.length - 1 ||
(currentSessionChatState !=

View File

@ -86,7 +86,7 @@ export interface Message {
documents?: DanswerDocument[] | null;
citations?: CitationMap;
files: FileDescriptor[];
toolCalls: ToolCallMetadata[];
toolCall: ToolCallMetadata | null;
// for rebuilding the message tree
parentMessageId: number | null;
childrenMessageIds?: number[];
@ -121,7 +121,7 @@ export interface BackendMessage {
time_sent: string;
citations: CitationMap;
files: FileDescriptor[];
tool_calls: ToolCallFinalResult[];
tool_call: ToolCallFinalResult | null;
alternate_assistant_id?: number | null;
overridden_model?: string;
}

View File

@ -428,7 +428,7 @@ export function processRawChatHistory(
citations: messageInfo?.citations || {},
}
: {}),
toolCalls: messageInfo.tool_calls,
toolCall: messageInfo.tool_call,
parentMessageId: messageInfo.parent_message,
childrenMessageIds: [],
latestChildMessageId: messageInfo.latest_child_message,

View File

@ -189,7 +189,7 @@ export const AIMessage = ({
files?: FileDescriptor[];
query?: string;
citedDocuments?: [string, DanswerDocument][] | null;
toolCall?: ToolCallMetadata;
toolCall?: ToolCallMetadata | null;
isComplete?: boolean;
hasDocs?: boolean;
handleFeedback?: (feedbackType: FeedbackType) => void;