From 584e9e6da5694609d36ebb9c2fcfb64b37b2f7d7 Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Tue, 31 Dec 2024 00:51:43 -0800 Subject: [PATCH] refac: threads --- backend/open_webui/models/messages.py | 24 +++- backend/open_webui/routers/channels.py | 50 ++++++++ backend/open_webui/socket/main.py | 1 + src/lib/apis/channels/index.ts | 42 +++++++ src/lib/components/channel/Channel.svelte | 5 +- .../components/channel/MessageInput.svelte | 6 +- src/lib/components/channel/Messages.svelte | 7 +- .../Messages/Message/ReactionPicker.svelte | 2 + .../components/channel/Messages/Thread.svelte | 28 ----- src/lib/components/channel/Thread.svelte | 108 ++++++++++++++++++ src/lib/components/chat/MessageInput.svelte | 2 +- 11 files changed, 238 insertions(+), 37 deletions(-) delete mode 100644 src/lib/components/channel/Messages/Thread.svelte create mode 100644 src/lib/components/channel/Thread.svelte diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index 68e396bc2..03a24abf3 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -140,7 +140,7 @@ class MessageTable: with get_db() as db: all_messages = ( db.query(Message) - .filter_by(channel_id=channel_id) + .filter_by(channel_id=channel_id, parent_id=None) .order_by(Message.created_at.desc()) .offset(skip) .limit(limit) @@ -148,6 +148,28 @@ class MessageTable: ) return [MessageModel.model_validate(message) for message in all_messages] + def get_messages_by_parent_id( + self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50 + ) -> list[MessageModel]: + with get_db() as db: + message = db.get(Message, parent_id) + + if not message: + return [] + + all_messages = ( + db.query(Message) + .filter_by(channel_id=channel_id, parent_id=parent_id) + .order_by(Message.created_at.desc()) + .offset(skip) + .limit(limit) + .all() + ) + + return [MessageModel.model_validate(message)] + [ + MessageModel.model_validate(message) for message in all_messages + ] + def update_message_by_id( self, id: str, form_data: MessageForm ) -> Optional[MessageModel]: diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index c79bb46a2..9e12911cc 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -275,6 +275,56 @@ async def post_new_message( ) +############################ +# GetChannelThreadMessages +############################ + + +@router.get( + "/{id}/messages/{message_id}/thread", response_model=list[MessageUserResponse] +) +async def get_channel_thread_messages( + id: str, + message_id: str, + skip: int = 0, + limit: int = 50, + user=Depends(get_verified_user), +): + channel = Channels.get_channel_by_id(id) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) + + if user.role != "admin" and not has_access( + user.id, type="read", access_control=channel.access_control + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + + message_list = Messages.get_messages_by_parent_id(id, message_id, skip, limit) + users = {} + + messages = [] + for message in message_list: + if message.user_id not in users: + user = Users.get_user_by_id(message.user_id) + users[message.user_id] = user + + messages.append( + MessageUserResponse( + **{ + **message.model_dump(), + "reactions": Messages.get_reactions_by_message_id(message.id), + "user": UserNameResponse(**users[message.user_id].model_dump()), + } + ) + ) + + return messages + + ############################ # UpdateMessageById ############################ diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index f3e9a033e..2d12f5803 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -237,6 +237,7 @@ async def channel_events(sid, data): "channel-events", { "channel_id": data["channel_id"], + "message_id": data.get("message_id", None), "data": event_data, "user": UserNameResponse(**SESSION_POOL[sid]).model_dump(), }, diff --git a/src/lib/apis/channels/index.ts b/src/lib/apis/channels/index.ts index 1b9cd3ebf..99ea44614 100644 --- a/src/lib/apis/channels/index.ts +++ b/src/lib/apis/channels/index.ts @@ -1,4 +1,5 @@ import { WEBUI_API_BASE_URL } from '$lib/constants'; +import { t } from 'i18next'; type ChannelForm = { name: string; @@ -207,6 +208,47 @@ export const getChannelMessages = async ( return res; }; + +export const getChannelThreadMessages = async ( + token: string = '', + channel_id: string, + message_id: string, + skip: number = 0, + limit: number = 50 +) => { + let error = null; + + const res = await fetch( + `${WEBUI_API_BASE_URL}/channels/${channel_id}/messages/${message_id}/thread?skip=${skip}&limit=${limit}`, + { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + } + } + ) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +} + type MessageForm = { content: string; data?: object; diff --git a/src/lib/components/channel/Channel.svelte b/src/lib/components/channel/Channel.svelte index 65bad2746..93cc6e590 100644 --- a/src/lib/components/channel/Channel.svelte +++ b/src/lib/components/channel/Channel.svelte @@ -13,7 +13,7 @@ import Navbar from './Navbar.svelte'; import Drawer from '../common/Drawer.svelte'; import EllipsisVertical from '../icons/EllipsisVertical.svelte'; - import Thread from './Messages/Thread.svelte'; + import Thread from './Thread.svelte'; export let id = ''; @@ -147,6 +147,7 @@ const onChange = async () => { $socket?.emit('channel-events', { channel_id: id, + message_id: null, data: { type: 'typing', data: { @@ -276,7 +277,7 @@ - +
{}; const screenCaptureHandler = async () => { try { @@ -313,7 +313,7 @@ filesInputElement.value = ''; }} /> -
+
{#if files.length > 0} diff --git a/src/lib/components/channel/Messages.svelte b/src/lib/components/channel/Messages.svelte index 91b93a247..15586dc4d 100644 --- a/src/lib/components/channel/Messages.svelte +++ b/src/lib/components/channel/Messages.svelte @@ -20,9 +20,11 @@ const i18n = getContext('i18n'); + export let id = null; export let channel = null; export let messages = []; export let top = false; + export let thread = false; export let onLoad: Function = () => {}; export let onThread: Function = () => {}; @@ -60,7 +62,7 @@
Loading...
- {:else} + {:else if !thread}
- import XMark from '$lib/components/icons/XMark.svelte'; - - export let threadId = null; - export let channel = null; - - export let onClose = () => {}; - - -
-
-
Thread
- -
- -
-
- {threadId} - - {channel} -
diff --git a/src/lib/components/channel/Thread.svelte b/src/lib/components/channel/Thread.svelte new file mode 100644 index 000000000..69a658f33 --- /dev/null +++ b/src/lib/components/channel/Thread.svelte @@ -0,0 +1,108 @@ + + +{#if channel} +
+
+
Thread
+ +
+ +
+
+ + { + const newMessages = await getChannelThreadMessages( + localStorage.token, + channel.id, + threadId, + messages.length + ); + + messages = [...messages, ...newMessages]; + + if (newMessages.length < 50) { + top = true; + return; + } + }} + /> + + +
+{/if} diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index a65150ec4..694302c6d 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -544,7 +544,7 @@ }} >
{#if files.length > 0}