From 7874862902cbe4e3956a267c1545917901b2d17b Mon Sep 17 00:00:00 2001 From: Weves Date: Thu, 6 Jul 2023 18:52:32 -0700 Subject: [PATCH] Proper slack message batching --- backend/danswer/connectors/slack/connector.py | 80 +++++++++---------- 1 file changed, 39 insertions(+), 41 deletions(-) diff --git a/backend/danswer/connectors/slack/connector.py b/backend/danswer/connectors/slack/connector.py index 1dbb0ae21595..83d82538b0ba 100644 --- a/backend/danswer/connectors/slack/connector.py +++ b/backend/danswer/connectors/slack/connector.py @@ -1,6 +1,7 @@ import json import os from collections.abc import Callable +from collections.abc import Generator from pathlib import Path from typing import Any from typing import cast @@ -55,7 +56,7 @@ def get_channel_messages( channel: dict[str, Any], oldest: str | None = None, latest: str | None = None, -) -> list[MessageType]: +) -> Generator[list[MessageType], None, None]: """Get all messages in a channel""" # join so that the bot can access messages if not channel["is_member"]: @@ -63,15 +64,13 @@ def get_channel_messages( channel=channel["id"], is_private=channel["is_private"] ) - messages: list[MessageType] = [] for result in _make_slack_api_call( client.conversations_history, channel=channel["id"], oldest=oldest, latest=latest, ): - messages.extend(result["messages"]) - return messages + yield cast(list[MessageType], result["messages"]) def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType: @@ -130,47 +129,39 @@ def get_all_docs( oldest: str | None = None, latest: str | None = None, msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter, -) -> list[Document]: - """Get all documents in the workspace""" +) -> Generator[Document, None, None]: + """Get all documents in the workspace, channel by channel""" channels = get_channels(client) - channel_id_to_channel_info = {channel["id"]: channel for channel in channels} - channel_id_to_messages: dict[str, list[dict[str, Any]]] = {} for channel in channels: - channel_id_to_messages[channel["id"]] = get_channel_messages( + channel_docs = 0 + channel_message_batches = get_channel_messages( client=client, channel=channel, oldest=oldest, latest=latest ) - channel_id_to_threads: dict[str, list[ThreadType]] = {} - for channel_id, messages in channel_id_to_messages.items(): - final_threads: list[ThreadType] = [] - for message in messages: - thread_ts = message.get("thread_ts") - if thread_ts: - thread = get_thread( - client=client, channel_id=channel_id, thread_id=thread_ts - ) - filtered_thread = [ - message for message in thread if not msg_filter_func(message) - ] - if filtered_thread: - final_threads.append(filtered_thread) - elif not msg_filter_func(message): - final_threads.append([message]) - channel_id_to_threads[channel_id] = final_threads + for message_batch in channel_message_batches: + for message in message_batch: + filtered_thread: ThreadType | None = None + thread_ts = message.get("thread_ts") + if thread_ts: + thread = get_thread( + client=client, channel_id=channel["id"], thread_id=thread_ts + ) + filtered_thread = [ + message for message in thread if not msg_filter_func(message) + ] + elif not msg_filter_func(message): + filtered_thread = [message] - docs: list[Document] = [] - for channel_id, threads in channel_id_to_threads.items(): - docs.extend( - thread_to_doc( - workspace=workspace, - channel=channel_id_to_channel_info[channel_id], - thread=thread, - ) - for thread in threads + if filtered_thread: + channel_docs += 1 + yield thread_to_doc( + workspace=workspace, channel=channel, thread=filtered_thread + ) + + logger.info( + f"Pulled {channel_docs} documents from slack channel {channel['name']}" ) - logger.info(f"Pulled {len(docs)} documents from slack") - return docs class SlackLoadConnector(LoadConnector): @@ -286,7 +277,9 @@ class SlackPollConnector(PollConnector): raise PermissionError( "Slack Client is not set up, was load_credentials called?" ) - all_docs = get_all_docs( + + documents: list[Document] = [] + for document in get_all_docs( client=self.client, workspace=self.workspace, # NOTE: need to impute to `None` instead of using 0.0, since Slack will @@ -294,6 +287,11 @@ class SlackPollConnector(PollConnector): # retention oldest=str(start) if start else None, latest=str(end), - ) - for i in range(0, len(all_docs), self.batch_size): - yield all_docs[i : i + self.batch_size] + ): + documents.append(document) + if len(documents) >= self.batch_size: + yield documents + documents = [] + + if documents: + yield documents