Allow slack channels to be specified (#238)

Adds the ability to specify which channels to index when using the Slack connector.
This commit is contained in:
Sid Ravinutala
2023-08-08 01:09:27 -04:00
committed by GitHub
parent 3bfc72484d
commit ca72027b28
4 changed files with 75 additions and 12 deletions

View File

@@ -46,7 +46,10 @@ def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
]
def get_channels(client: WebClient, exclude_archived: bool = True) -> list[ChannelType]:
def get_channels(
client: WebClient,
exclude_archived: bool = True,
) -> list[ChannelType]:
"""Get all channels in the workspace"""
channels: list[dict[str, Any]] = []
for result in _make_slack_api_call(
@@ -133,9 +136,22 @@ def _default_msg_filter(message: MessageType) -> bool:
return message.get("subtype", "") in _DISALLOWED_MSG_SUBTYPES
def _filter_channels(
all_channels: list[dict[str, Any]], channels_to_connect: list[str] | None
) -> list[dict[str, Any]]:
if channels_to_connect:
return [
channel
for channel in all_channels
if channel["name"] in channels_to_connect
]
return all_channels
def get_all_docs(
client: WebClient,
workspace: str,
channels: list[str] | None = None,
oldest: str | None = None,
latest: str | None = None,
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
@@ -143,9 +159,10 @@ def get_all_docs(
"""Get all documents in the workspace, channel by channel"""
user_id_replacer = UserIdReplacer(client=client)
channels = get_channels(client)
all_channels = get_channels(client)
filtered_channels = _filter_channels(all_channels, channels)
for channel in channels:
for channel in filtered_channels:
channel_docs = 0
channel_message_batches = get_channel_messages(
client=client, channel=channel, oldest=oldest, latest=latest
@@ -181,9 +198,14 @@ def get_all_docs(
class SlackLoadConnector(LoadConnector):
def __init__(
self, workspace: str, export_path_str: str, batch_size: int = INDEX_BATCH_SIZE
self,
workspace: str,
export_path_str: str,
channels: list[str] | None = None,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.workspace = workspace
self.channels = channels
self.export_path_str = export_path_str
self.batch_size = batch_size
@@ -245,10 +267,12 @@ class SlackLoadConnector(LoadConnector):
export_path = Path(self.export_path_str)
with open(export_path / "channels.json") as f:
channels = json.load(f)
all_channels = json.load(f)
filtered_channels = _filter_channels(all_channels, self.channels)
document_batch: dict[str, Document] = {}
for channel_info in channels:
for channel_info in filtered_channels:
channel_dir_path = export_path / cast(str, channel_info["name"])
channel_file_paths = [
channel_dir_path / file_name
@@ -275,8 +299,14 @@ class SlackLoadConnector(LoadConnector):
class SlackPollConnector(PollConnector):
def __init__(self, workspace: str, batch_size: int = INDEX_BATCH_SIZE) -> None:
def __init__(
self,
workspace: str,
channels: list[str] | None = None,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.workspace = workspace
self.channels = channels
self.batch_size = batch_size
self.client: WebClient | None = None
@@ -295,6 +325,7 @@ class SlackPollConnector(PollConnector):
for document in get_all_docs(
client=self.client,
workspace=self.workspace,
channels=self.channels,
# NOTE: need to impute to `None` instead of using 0.0, since Slack will
# throw an error if we use 0.0 on an account without infinite data
# retention