Proper slack message batching

This commit is contained in:
Weves
2023-07-06 18:52:32 -07:00
committed by Chris Weaver
parent 6978573a07
commit 7874862902

View File

@@ -1,6 +1,7 @@
import json import json
import os import os
from collections.abc import Callable from collections.abc import Callable
from collections.abc import Generator
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from typing import cast from typing import cast
@@ -55,7 +56,7 @@ def get_channel_messages(
channel: dict[str, Any], channel: dict[str, Any],
oldest: str | None = None, oldest: str | None = None,
latest: str | None = None, latest: str | None = None,
) -> list[MessageType]: ) -> Generator[list[MessageType], None, None]:
"""Get all messages in a channel""" """Get all messages in a channel"""
# join so that the bot can access messages # join so that the bot can access messages
if not channel["is_member"]: if not channel["is_member"]:
@@ -63,15 +64,13 @@ def get_channel_messages(
channel=channel["id"], is_private=channel["is_private"] channel=channel["id"], is_private=channel["is_private"]
) )
messages: list[MessageType] = []
for result in _make_slack_api_call( for result in _make_slack_api_call(
client.conversations_history, client.conversations_history,
channel=channel["id"], channel=channel["id"],
oldest=oldest, oldest=oldest,
latest=latest, latest=latest,
): ):
messages.extend(result["messages"]) yield cast(list[MessageType], result["messages"])
return messages
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType: def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
@@ -130,47 +129,39 @@ def get_all_docs(
oldest: str | None = None, oldest: str | None = None,
latest: str | None = None, latest: str | None = None,
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter, msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
) -> list[Document]: ) -> Generator[Document, None, None]:
"""Get all documents in the workspace""" """Get all documents in the workspace, channel by channel"""
channels = get_channels(client) channels = get_channels(client)
channel_id_to_channel_info = {channel["id"]: channel for channel in channels}
channel_id_to_messages: dict[str, list[dict[str, Any]]] = {}
for channel in channels: for channel in channels:
channel_id_to_messages[channel["id"]] = get_channel_messages( channel_docs = 0
channel_message_batches = get_channel_messages(
client=client, channel=channel, oldest=oldest, latest=latest client=client, channel=channel, oldest=oldest, latest=latest
) )
channel_id_to_threads: dict[str, list[ThreadType]] = {} for message_batch in channel_message_batches:
for channel_id, messages in channel_id_to_messages.items(): for message in message_batch:
final_threads: list[ThreadType] = [] filtered_thread: ThreadType | None = None
for message in messages: thread_ts = message.get("thread_ts")
thread_ts = message.get("thread_ts") if thread_ts:
if thread_ts: thread = get_thread(
thread = get_thread( client=client, channel_id=channel["id"], thread_id=thread_ts
client=client, channel_id=channel_id, thread_id=thread_ts )
) filtered_thread = [
filtered_thread = [ message for message in thread if not msg_filter_func(message)
message for message in thread if not msg_filter_func(message) ]
] elif not msg_filter_func(message):
if filtered_thread: filtered_thread = [message]
final_threads.append(filtered_thread)
elif not msg_filter_func(message):
final_threads.append([message])
channel_id_to_threads[channel_id] = final_threads
docs: list[Document] = [] if filtered_thread:
for channel_id, threads in channel_id_to_threads.items(): channel_docs += 1
docs.extend( yield thread_to_doc(
thread_to_doc( workspace=workspace, channel=channel, thread=filtered_thread
workspace=workspace, )
channel=channel_id_to_channel_info[channel_id],
thread=thread, logger.info(
) f"Pulled {channel_docs} documents from slack channel {channel['name']}"
for thread in threads
) )
logger.info(f"Pulled {len(docs)} documents from slack")
return docs
class SlackLoadConnector(LoadConnector): class SlackLoadConnector(LoadConnector):
@@ -286,7 +277,9 @@ class SlackPollConnector(PollConnector):
raise PermissionError( raise PermissionError(
"Slack Client is not set up, was load_credentials called?" "Slack Client is not set up, was load_credentials called?"
) )
all_docs = get_all_docs(
documents: list[Document] = []
for document in get_all_docs(
client=self.client, client=self.client,
workspace=self.workspace, workspace=self.workspace,
# NOTE: need to impute to `None` instead of using 0.0, since Slack will # NOTE: need to impute to `None` instead of using 0.0, since Slack will
@@ -294,6 +287,11 @@ class SlackPollConnector(PollConnector):
# retention # retention
oldest=str(start) if start else None, oldest=str(start) if start else None,
latest=str(end), latest=str(end),
) ):
for i in range(0, len(all_docs), self.batch_size): documents.append(document)
yield all_docs[i : i + self.batch_size] if len(documents) >= self.batch_size:
yield documents
documents = []
if documents:
yield documents