mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 20:24:32 +02:00
Proper slack message batching
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -55,7 +56,7 @@ def get_channel_messages(
|
||||
channel: dict[str, Any],
|
||||
oldest: str | None = None,
|
||||
latest: str | None = None,
|
||||
) -> list[MessageType]:
|
||||
) -> Generator[list[MessageType], None, None]:
|
||||
"""Get all messages in a channel"""
|
||||
# join so that the bot can access messages
|
||||
if not channel["is_member"]:
|
||||
@@ -63,15 +64,13 @@ def get_channel_messages(
|
||||
channel=channel["id"], is_private=channel["is_private"]
|
||||
)
|
||||
|
||||
messages: list[MessageType] = []
|
||||
for result in _make_slack_api_call(
|
||||
client.conversations_history,
|
||||
channel=channel["id"],
|
||||
oldest=oldest,
|
||||
latest=latest,
|
||||
):
|
||||
messages.extend(result["messages"])
|
||||
return messages
|
||||
yield cast(list[MessageType], result["messages"])
|
||||
|
||||
|
||||
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
|
||||
@@ -130,47 +129,39 @@ def get_all_docs(
|
||||
oldest: str | None = None,
|
||||
latest: str | None = None,
|
||||
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
|
||||
) -> list[Document]:
|
||||
"""Get all documents in the workspace"""
|
||||
) -> Generator[Document, None, None]:
|
||||
"""Get all documents in the workspace, channel by channel"""
|
||||
channels = get_channels(client)
|
||||
channel_id_to_channel_info = {channel["id"]: channel for channel in channels}
|
||||
|
||||
channel_id_to_messages: dict[str, list[dict[str, Any]]] = {}
|
||||
for channel in channels:
|
||||
channel_id_to_messages[channel["id"]] = get_channel_messages(
|
||||
channel_docs = 0
|
||||
channel_message_batches = get_channel_messages(
|
||||
client=client, channel=channel, oldest=oldest, latest=latest
|
||||
)
|
||||
|
||||
channel_id_to_threads: dict[str, list[ThreadType]] = {}
|
||||
for channel_id, messages in channel_id_to_messages.items():
|
||||
final_threads: list[ThreadType] = []
|
||||
for message in messages:
|
||||
thread_ts = message.get("thread_ts")
|
||||
if thread_ts:
|
||||
thread = get_thread(
|
||||
client=client, channel_id=channel_id, thread_id=thread_ts
|
||||
)
|
||||
filtered_thread = [
|
||||
message for message in thread if not msg_filter_func(message)
|
||||
]
|
||||
if filtered_thread:
|
||||
final_threads.append(filtered_thread)
|
||||
elif not msg_filter_func(message):
|
||||
final_threads.append([message])
|
||||
channel_id_to_threads[channel_id] = final_threads
|
||||
for message_batch in channel_message_batches:
|
||||
for message in message_batch:
|
||||
filtered_thread: ThreadType | None = None
|
||||
thread_ts = message.get("thread_ts")
|
||||
if thread_ts:
|
||||
thread = get_thread(
|
||||
client=client, channel_id=channel["id"], thread_id=thread_ts
|
||||
)
|
||||
filtered_thread = [
|
||||
message for message in thread if not msg_filter_func(message)
|
||||
]
|
||||
elif not msg_filter_func(message):
|
||||
filtered_thread = [message]
|
||||
|
||||
docs: list[Document] = []
|
||||
for channel_id, threads in channel_id_to_threads.items():
|
||||
docs.extend(
|
||||
thread_to_doc(
|
||||
workspace=workspace,
|
||||
channel=channel_id_to_channel_info[channel_id],
|
||||
thread=thread,
|
||||
)
|
||||
for thread in threads
|
||||
if filtered_thread:
|
||||
channel_docs += 1
|
||||
yield thread_to_doc(
|
||||
workspace=workspace, channel=channel, thread=filtered_thread
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Pulled {channel_docs} documents from slack channel {channel['name']}"
|
||||
)
|
||||
logger.info(f"Pulled {len(docs)} documents from slack")
|
||||
return docs
|
||||
|
||||
|
||||
class SlackLoadConnector(LoadConnector):
|
||||
@@ -286,7 +277,9 @@ class SlackPollConnector(PollConnector):
|
||||
raise PermissionError(
|
||||
"Slack Client is not set up, was load_credentials called?"
|
||||
)
|
||||
all_docs = get_all_docs(
|
||||
|
||||
documents: list[Document] = []
|
||||
for document in get_all_docs(
|
||||
client=self.client,
|
||||
workspace=self.workspace,
|
||||
# NOTE: need to impute to `None` instead of using 0.0, since Slack will
|
||||
@@ -294,6 +287,11 @@ class SlackPollConnector(PollConnector):
|
||||
# retention
|
||||
oldest=str(start) if start else None,
|
||||
latest=str(end),
|
||||
)
|
||||
for i in range(0, len(all_docs), self.batch_size):
|
||||
yield all_docs[i : i + self.batch_size]
|
||||
):
|
||||
documents.append(document)
|
||||
if len(documents) >= self.batch_size:
|
||||
yield documents
|
||||
documents = []
|
||||
|
||||
if documents:
|
||||
yield documents
|
||||
|
Reference in New Issue
Block a user