Proper slack message batching

This commit is contained in:
Weves
2023-07-06 18:52:32 -07:00
committed by Chris Weaver
parent 6978573a07
commit 7874862902

View File

@@ -1,6 +1,7 @@
import json
import os
from collections.abc import Callable
from collections.abc import Generator
from pathlib import Path
from typing import Any
from typing import cast
@@ -55,7 +56,7 @@ def get_channel_messages(
channel: dict[str, Any],
oldest: str | None = None,
latest: str | None = None,
) -> list[MessageType]:
) -> Generator[list[MessageType], None, None]:
"""Get all messages in a channel"""
# join so that the bot can access messages
if not channel["is_member"]:
@@ -63,15 +64,13 @@ def get_channel_messages(
channel=channel["id"], is_private=channel["is_private"]
)
messages: list[MessageType] = []
for result in _make_slack_api_call(
client.conversations_history,
channel=channel["id"],
oldest=oldest,
latest=latest,
):
messages.extend(result["messages"])
return messages
yield cast(list[MessageType], result["messages"])
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
@@ -130,47 +129,39 @@ def get_all_docs(
oldest: str | None = None,
latest: str | None = None,
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
) -> list[Document]:
"""Get all documents in the workspace"""
) -> Generator[Document, None, None]:
"""Get all documents in the workspace, channel by channel"""
channels = get_channels(client)
channel_id_to_channel_info = {channel["id"]: channel for channel in channels}
channel_id_to_messages: dict[str, list[dict[str, Any]]] = {}
for channel in channels:
channel_id_to_messages[channel["id"]] = get_channel_messages(
channel_docs = 0
channel_message_batches = get_channel_messages(
client=client, channel=channel, oldest=oldest, latest=latest
)
channel_id_to_threads: dict[str, list[ThreadType]] = {}
for channel_id, messages in channel_id_to_messages.items():
final_threads: list[ThreadType] = []
for message in messages:
thread_ts = message.get("thread_ts")
if thread_ts:
thread = get_thread(
client=client, channel_id=channel_id, thread_id=thread_ts
)
filtered_thread = [
message for message in thread if not msg_filter_func(message)
]
if filtered_thread:
final_threads.append(filtered_thread)
elif not msg_filter_func(message):
final_threads.append([message])
channel_id_to_threads[channel_id] = final_threads
for message_batch in channel_message_batches:
for message in message_batch:
filtered_thread: ThreadType | None = None
thread_ts = message.get("thread_ts")
if thread_ts:
thread = get_thread(
client=client, channel_id=channel["id"], thread_id=thread_ts
)
filtered_thread = [
message for message in thread if not msg_filter_func(message)
]
elif not msg_filter_func(message):
filtered_thread = [message]
docs: list[Document] = []
for channel_id, threads in channel_id_to_threads.items():
docs.extend(
thread_to_doc(
workspace=workspace,
channel=channel_id_to_channel_info[channel_id],
thread=thread,
)
for thread in threads
if filtered_thread:
channel_docs += 1
yield thread_to_doc(
workspace=workspace, channel=channel, thread=filtered_thread
)
logger.info(
f"Pulled {channel_docs} documents from slack channel {channel['name']}"
)
logger.info(f"Pulled {len(docs)} documents from slack")
return docs
class SlackLoadConnector(LoadConnector):
@@ -286,7 +277,9 @@ class SlackPollConnector(PollConnector):
raise PermissionError(
"Slack Client is not set up, was load_credentials called?"
)
all_docs = get_all_docs(
documents: list[Document] = []
for document in get_all_docs(
client=self.client,
workspace=self.workspace,
# NOTE: need to impute to `None` instead of using 0.0, since Slack will
@@ -294,6 +287,11 @@ class SlackPollConnector(PollConnector):
# retention
oldest=str(start) if start else None,
latest=str(end),
)
for i in range(0, len(all_docs), self.batch_size):
yield all_docs[i : i + self.batch_size]
):
documents.append(document)
if len(documents) >= self.batch_size:
yield documents
documents = []
if documents:
yield documents