mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 03:58:30 +02:00
welcome to onyx
This commit is contained in:
0
backend/onyx/connectors/slack/__init__.py
Normal file
0
backend/onyx/connectors/slack/__init__.py
Normal file
444
backend/onyx/connectors/slack/connector.py
Normal file
444
backend/onyx/connectors/slack/connector.py
Normal file
@@ -0,0 +1,444 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.slack.utils import expert_info_from_slack_id
|
||||
from onyx.connectors.slack.utils import get_message_link
|
||||
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.utils import make_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.utils import SlackTextCleaner
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()


# Slack API payloads are plain dicts; these aliases only document intent.
ChannelType = dict[str, Any]
MessageType = dict[str, Any]
# list of messages in a thread
ThreadType = list[MessageType]
|
||||
|
||||
|
||||
def _collect_paginated_channels(
    client: WebClient,
    exclude_archived: bool,
    channel_types: list[str],
) -> list[ChannelType]:
    """Fetch every page of `conversations_list` and flatten into one list."""
    collected: list[dict[str, Any]] = []
    for page in make_paginated_slack_api_call_w_retries(
        client.conversations_list,
        exclude_archived=exclude_archived,
        # also get private channels the bot is added to
        types=channel_types,
    ):
        collected += page["channels"]
    return collected
|
||||
|
||||
|
||||
def get_channels(
    client: WebClient,
    exclude_archived: bool = True,
    get_public: bool = True,
    get_private: bool = True,
) -> list[ChannelType]:
    """Get all channels in the workspace

    First attempts to fetch all requested channel types at once; if Slack
    rejects the call (e.g. the token lacks the private-channel scope), falls
    back to fetching public channels only. Returns [] when there is nothing
    left to fetch after the fallback.
    """
    channels: list[dict[str, Any]] = []
    channel_types: list[str] = []
    if get_public:
        channel_types.append("public_channel")
    if get_private:
        channel_types.append("private_channel")
    # try getting private channels as well at first
    try:
        channels = _collect_paginated_channels(
            client=client,
            exclude_archived=exclude_archived,
            channel_types=channel_types,
        )
    except SlackApiError as e:
        logger.info(f"Unable to fetch private channels due to - {e}")
        logger.info("trying again without private channels")
        if get_public:
            channel_types = ["public_channel"]
        else:
            # only private channels were requested and they failed — nothing to do
            logger.warning("No channels to fetch")
            return []
        channels = _collect_paginated_channels(
            client=client,
            exclude_archived=exclude_archived,
            channel_types=channel_types,
        )

    return channels
|
||||
|
||||
|
||||
def get_channel_messages(
    client: WebClient,
    channel: dict[str, Any],
    oldest: str | None = None,
    latest: str | None = None,
) -> Generator[list[MessageType], None, None]:
    """Get all messages in a channel

    Yields one page of messages per iteration. Side effect: if the bot is not
    yet a member of the channel, it joins first so the history is readable.
    oldest/latest are Slack timestamp strings bounding the history window.
    """
    # join so that the bot can access messages
    if not channel["is_member"]:
        make_slack_api_call_w_retries(
            client.conversations_join,
            channel=channel["id"],
            is_private=channel["is_private"],
        )
        logger.info(f"Successfully joined '{channel['name']}'")

    for result in make_paginated_slack_api_call_w_retries(
        client.conversations_history,
        channel=channel["id"],
        oldest=oldest,
        latest=latest,
    ):
        yield cast(list[MessageType], result["messages"])
|
||||
|
||||
|
||||
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
    """Get all messages in a thread"""
    messages: list[MessageType] = []
    for page in make_paginated_slack_api_call_w_retries(
        client.conversations_replies, channel=channel_id, ts=thread_id
    ):
        messages += page["messages"]
    return messages
|
||||
|
||||
|
||||
def get_latest_message_time(thread: ThreadType) -> datetime:
    """Return the newest message timestamp in *thread* as an aware UTC datetime.

    Messages missing a "ts" field are treated as epoch 0. Raises ValueError
    on an empty thread (same as before).
    """
    # generator form — no need to materialize an intermediate list for max()
    max_ts = max(float(msg.get("ts", 0)) for msg in thread)
    return datetime.fromtimestamp(max_ts, tz=timezone.utc)
|
||||
|
||||
|
||||
def thread_to_doc(
    channel: ChannelType,
    thread: ThreadType,
    slack_cleaner: SlackTextCleaner,
    client: WebClient,
    user_cache: dict[str, BasicExpertInfo | None],
) -> Document:
    """Convert one Slack thread (list of messages) into a single Document.

    One Section per message, each with a permalink and cleaned text. The
    semantic id has the form "<sender> in #<channel>: <snippet>". user_cache
    is caller-owned so each Slack user is resolved via the API at most once.
    """
    channel_id = channel["id"]

    initial_sender_expert_info = expert_info_from_slack_id(
        user_id=thread[0].get("user"), client=client, user_cache=user_cache
    )
    initial_sender_name = (
        initial_sender_expert_info.get_semantic_name()
        if initial_sender_expert_info
        else "Unknown"
    )

    # resolving every participant costs one API call per uncached user,
    # so it is gated behind a config flag
    valid_experts = None
    if ENABLE_EXPENSIVE_EXPERT_CALLS:
        all_sender_ids = [m.get("user") for m in thread]
        experts = [
            expert_info_from_slack_id(
                user_id=sender_id, client=client, user_cache=user_cache
            )
            for sender_id in all_sender_ids
            if sender_id
        ]
        valid_experts = [expert for expert in experts if expert]

    first_message = slack_cleaner.index_clean(cast(str, thread[0]["text"]))
    # first 50 chars of the first message serve as a human-readable snippet
    snippet = (
        first_message[:50].rstrip() + "..."
        if len(first_message) > 50
        else first_message
    )

    doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace(
        "\n", " "
    )

    return Document(
        id=f"{channel_id}__{thread[0]['ts']}",
        sections=[
            Section(
                link=get_message_link(event=m, client=client, channel_id=channel_id),
                text=slack_cleaner.index_clean(cast(str, m["text"])),
            )
            for m in thread
        ],
        source=DocumentSource.SLACK,
        semantic_identifier=doc_sem_id,
        doc_updated_at=get_latest_message_time(thread),
        title="",  # slack docs don't really have a "title"
        primary_owners=valid_experts,
        metadata={"Channel": channel["name"]},
    )
|
||||
|
||||
|
||||
# list of subtypes can be found here: https://api.slack.com/events/message
# Messages with these subtypes are channel housekeeping noise (joins, leaves,
# renames, archiving, pins, ...) and carry no indexable content.
# NOTE: the original literal listed "channel_join" and "channel_leave" twice;
# duplicates are removed here (set contents are unchanged).
_DISALLOWED_MSG_SUBTYPES = {
    "channel_join",
    "channel_leave",
    "channel_archive",
    "channel_unarchive",
    "channel_name",
    "channel_posting_permissions",
    "pinned_item",
    "unpinned_item",
    "ekm_access_denied",
    "group_join",
    "group_leave",
    "group_archive",
    "group_unarchive",
}
|
||||
|
||||
|
||||
def default_msg_filter(message: MessageType) -> bool:
    """Return True for messages that should be excluded from indexing."""
    # Don't keep messages from bots — except those posted by the
    # "OnyxConnector" bot itself, which are kept
    if message.get("bot_id") or message.get("app_id"):
        bot_name = message.get("bot_profile", {}).get("name")
        return bot_name != "OnyxConnector"

    # Uninformative channel-event subtypes are dropped as well
    return message.get("subtype", "") in _DISALLOWED_MSG_SUBTYPES
|
||||
|
||||
|
||||
def filter_channels(
|
||||
all_channels: list[dict[str, Any]],
|
||||
channels_to_connect: list[str] | None,
|
||||
regex_enabled: bool,
|
||||
) -> list[dict[str, Any]]:
|
||||
if not channels_to_connect:
|
||||
return all_channels
|
||||
|
||||
if regex_enabled:
|
||||
return [
|
||||
channel
|
||||
for channel in all_channels
|
||||
if any(
|
||||
re.fullmatch(channel_to_connect, channel["name"])
|
||||
for channel_to_connect in channels_to_connect
|
||||
)
|
||||
]
|
||||
|
||||
# validate that all channels in `channels_to_connect` are valid
|
||||
# fail loudly in the case of an invalid channel so that the user
|
||||
# knows that one of the channels they've specified is typo'd or private
|
||||
all_channel_names = {channel["name"] for channel in all_channels}
|
||||
for channel in channels_to_connect:
|
||||
if channel not in all_channel_names:
|
||||
raise ValueError(
|
||||
f"Channel '{channel}' not found in workspace. "
|
||||
f"Available channels: {all_channel_names}"
|
||||
)
|
||||
|
||||
return [
|
||||
channel for channel in all_channels if channel["name"] in channels_to_connect
|
||||
]
|
||||
|
||||
|
||||
def _get_all_docs(
    client: WebClient,
    channels: list[str] | None = None,
    channel_name_regex_enabled: bool = False,
    oldest: str | None = None,
    latest: str | None = None,
    msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> Generator[Document, None, None]:
    """Get all documents in the workspace, channel by channel

    Threaded messages are grouped into one Document per thread; standalone
    messages become single-message Documents. msg_filter_func returns True
    for messages that should be DROPPED.
    """
    slack_cleaner = SlackTextCleaner(client=client)

    # Cache to prevent refetching users via the API; shared across all threads
    user_cache: dict[str, BasicExpertInfo | None] = {}

    all_channels = get_channels(client)
    filtered_channels = filter_channels(
        all_channels, channels, channel_name_regex_enabled
    )

    for channel in filtered_channels:
        channel_docs = 0
        channel_message_batches = get_channel_messages(
            client=client, channel=channel, oldest=oldest, latest=latest
        )

        seen_thread_ts: set[str] = set()
        for message_batch in channel_message_batches:
            for message in message_batch:
                filtered_thread: ThreadType | None = None
                thread_ts = message.get("thread_ts")
                if thread_ts:
                    # skip threads we've already seen, since we've already processed all
                    # messages in that thread
                    if thread_ts in seen_thread_ts:
                        continue
                    seen_thread_ts.add(thread_ts)
                    # fetch the full thread and drop filtered messages from it
                    thread = get_thread(
                        client=client, channel_id=channel["id"], thread_id=thread_ts
                    )
                    filtered_thread = [
                        message for message in thread if not msg_filter_func(message)
                    ]
                elif not msg_filter_func(message):
                    # standalone message becomes a one-message "thread"
                    filtered_thread = [message]

                if filtered_thread:
                    channel_docs += 1
                    yield thread_to_doc(
                        channel=channel,
                        thread=filtered_thread,
                        slack_cleaner=slack_cleaner,
                        client=client,
                        user_cache=user_cache,
                    )

        logger.info(
            f"Pulled {channel_docs} documents from slack channel {channel['name']}"
        )
|
||||
|
||||
|
||||
def _get_all_doc_ids(
    client: WebClient,
    channels: list[str] | None = None,
    channel_name_regex_enabled: bool = False,
    msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> GenerateSlimDocumentOutput:
    """
    Get all document ids in the workspace, channel by channel

    This is pretty identical to get_all_docs, but it returns a set of ids instead of documents
    This makes it an order of magnitude faster than get_all_docs
    """

    all_channels = get_channels(client)
    filtered_channels = filter_channels(
        all_channels, channels, channel_name_regex_enabled
    )

    for channel in filtered_channels:
        channel_id = channel["id"]
        channel_message_batches = get_channel_messages(
            client=client,
            channel=channel,
        )

        message_ts_set: set[str] = set()
        for message_batch in channel_message_batches:
            for message in message_batch:
                # msg_filter_func returns True for messages to drop
                if msg_filter_func(message):
                    continue

                # The document id is the channel id and the ts of the first message in the thread
                # Since we already have the first message of the thread, we dont have to
                # fetch the thread for id retrieval, saving time and API calls
                # NOTE(review): every kept message's ts is added here, which
                # appears to include thread replies as well — confirm this is
                # the intended id set relative to _get_all_docs.
                message_ts_set.add(message["ts"])

        channel_metadata_list: list[SlimDocument] = []
        for message_ts in message_ts_set:
            channel_metadata_list.append(
                SlimDocument(
                    id=f"{channel_id}__{message_ts}",
                    perm_sync_data={"channel_id": channel_id},
                )
            )

        # one list per channel so downstream can process channel-by-channel
        yield channel_metadata_list
|
||||
|
||||
|
||||
class SlackPollConnector(PollConnector, SlimConnector):
    """Indexes Slack messages via the Slack Web API.

    Implements PollConnector (full Documents for a time window) and
    SlimConnector (document ids only).
    """

    def __init__(
        self,
        channels: list[str] | None = None,
        # if specified, will treat the specified channel strings as
        # regexes, and will only index channels that fully match the regexes
        channel_regex_enabled: bool = False,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        # channels: explicit channel names (or regexes) to index; None = all
        self.channels = channels
        self.channel_regex_enabled = channel_regex_enabled
        self.batch_size = batch_size
        # set by load_credentials; None until then
        self.client: WebClient | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Create the Slack client from the bot token; returns None since no
        updated credentials need to be persisted."""
        bot_token = credentials["slack_bot_token"]
        self.client = WebClient(token=bot_token)
        return None

    def retrieve_all_slim_documents(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateSlimDocumentOutput:
        """Yield per-channel lists of SlimDocuments (ids only).

        NOTE(review): start/end are accepted but not forwarded to
        _get_all_doc_ids — the full history is scanned regardless; confirm
        this is intended.
        """
        if self.client is None:
            raise ConnectorMissingCredentialError("Slack")

        return _get_all_doc_ids(
            client=self.client,
            channels=self.channels,
            channel_name_regex_enabled=self.channel_regex_enabled,
        )

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        """Yield batches (of up to self.batch_size) of Documents for messages
        between start and end (seconds since epoch)."""
        if self.client is None:
            raise ConnectorMissingCredentialError("Slack")

        documents: list[Document] = []
        for document in _get_all_docs(
            client=self.client,
            channels=self.channels,
            channel_name_regex_enabled=self.channel_regex_enabled,
            # NOTE: need to impute to `None` instead of using 0.0, since Slack will
            # throw an error if we use 0.0 on an account without infinite data
            # retention
            oldest=str(start) if start else None,
            latest=str(end),
        ):
            documents.append(document)
            if len(documents) >= self.batch_size:
                yield documents
                documents = []

        # flush the final partial batch
        if documents:
            yield documents
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc smoke test: poll the last day of messages from one channel
    # (SLACK_CHANNEL env var; all channels if unset) and print the first
    # document batch. Requires SLACK_BOT_TOKEN in the environment.
    import os
    import time

    slack_channel = os.environ.get("SLACK_CHANNEL")
    connector = SlackPollConnector(
        channels=[slack_channel] if slack_channel else None,
    )
    connector.load_credentials({"slack_bot_token": os.environ["SLACK_BOT_TOKEN"]})

    current = time.time()
    one_day_ago = current - 24 * 60 * 60  # 1 day

    document_batches = connector.poll_source(one_day_ago, current)

    print(next(document_batches))
|
140
backend/onyx/connectors/slack/load_connector.py
Normal file
140
backend/onyx/connectors/slack/load_connector.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.slack.connector import filter_channels
|
||||
from onyx.connectors.slack.utils import get_message_link
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_event_time(event: dict[str, Any]) -> datetime | None:
|
||||
ts = event.get("ts")
|
||||
if not ts:
|
||||
return None
|
||||
return datetime.fromtimestamp(float(ts), tz=timezone.utc)
|
||||
|
||||
|
||||
class SlackLoadConnector(LoadConnector):
    """Indexes a Slack workspace from an on-disk export: a channels.json file
    plus one directory of JSON event files per channel."""

    # WARNING: DEPRECATED, DO NOT USE
    def __init__(
        self,
        workspace: str,
        export_path_str: str,
        channels: list[str] | None = None,
        # if specified, will treat the specified channel strings as
        # regexes, and will only index channels that fully match the regexes
        channel_regex_enabled: bool = False,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        # workspace: Slack workspace identifier used to build message links
        self.workspace = workspace
        self.channels = channels
        self.channel_regex_enabled = channel_regex_enabled
        # root of the unpacked Slack export on disk
        self.export_path_str = export_path_str
        self.batch_size = batch_size

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Export-based connector needs no credentials; warn if any are given."""
        if credentials:
            logger.warning("Unexpected credentials provided for Slack Load Connector")
        return None

    @staticmethod
    def _process_batch_event(
        slack_event: dict[str, Any],
        channel: dict[str, Any],
        matching_doc: Document | None,
        workspace: str,
    ) -> Document | None:
        """Turn one exported Slack event into a Document.

        If matching_doc is given (the Document already built for this event's
        thread), a new Document is returned with this message appended as an
        extra Section; otherwise a fresh single-Section Document is created.
        Non-message events and "channel_join" messages return None.
        """
        if (
            slack_event["type"] == "message"
            and slack_event.get("subtype") != "channel_join"
        ):
            if matching_doc:
                # extend the thread's existing doc with this reply
                return Document(
                    id=matching_doc.id,
                    sections=matching_doc.sections
                    + [
                        Section(
                            link=get_message_link(
                                event=slack_event,
                                workspace=workspace,
                                channel_id=channel["id"],
                            ),
                            text=slack_event["text"],
                        )
                    ],
                    source=matching_doc.source,
                    semantic_identifier=matching_doc.semantic_identifier,
                    title="",  # slack docs don't really have a "title"
                    doc_updated_at=get_event_time(slack_event),
                    metadata=matching_doc.metadata,
                )

            return Document(
                id=slack_event["ts"],
                sections=[
                    Section(
                        link=get_message_link(
                            event=slack_event,
                            workspace=workspace,
                            channel_id=channel["id"],
                        ),
                        text=slack_event["text"],
                    )
                ],
                source=DocumentSource.SLACK,
                semantic_identifier=channel["name"],
                title="",  # slack docs don't really have a "title"
                doc_updated_at=get_event_time(slack_event),
                metadata={},
            )

        return None

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Walk the export directory and yield batches of Documents."""
        export_path = Path(self.export_path_str)

        with open(export_path / "channels.json") as f:
            all_channels = json.load(f)

        filtered_channels = filter_channels(
            all_channels, self.channels, self.channel_regex_enabled
        )

        # keyed by doc id (message/thread ts) so later events in the same
        # thread can locate and extend the existing Document
        document_batch: dict[str, Document] = {}
        for channel_info in filtered_channels:
            channel_dir_path = export_path / cast(str, channel_info["name"])
            channel_file_paths = [
                channel_dir_path / file_name
                for file_name in os.listdir(channel_dir_path)
            ]
            for path in channel_file_paths:
                with open(path) as f:
                    events = cast(list[dict[str, Any]], json.load(f))
                    for slack_event in events:
                        doc = self._process_batch_event(
                            slack_event=slack_event,
                            channel=channel_info,
                            matching_doc=document_batch.get(
                                slack_event.get("thread_ts", "")
                            ),
                            workspace=self.workspace,
                        )
                        if doc:
                            document_batch[doc.id] = doc
                            # NOTE(review): document_batch is not cleared after
                            # yielding, so earlier docs are re-yielded in later
                            # batches — presumably so thread replies can still
                            # find their root doc; confirm downstream upserts
                            # are idempotent.
                            if len(document_batch) >= self.batch_size:
                                yield list(document_batch.values())

        yield list(document_batch.values())
|
297
backend/onyx/connectors/slack/utils.py
Normal file
297
backend/onyx/connectors/slack/utils.py
Normal file
@@ -0,0 +1,297 @@
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from functools import lru_cache
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web import SlackResponse
|
||||
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()

# shared retry decorator applied on top of the rate-limit handling below
basic_retry_wrapper = retry_builder()
# number of messages we request per page when fetching paginated slack messages
_SLACK_LIMIT = 900
|
||||
|
||||
|
||||
@lru_cache()
def get_base_url(token: str) -> str:
    """Retrieve and cache the base URL of the Slack workspace based on the client token."""
    # auth.test reports the workspace URL for the authed token
    return WebClient(token=token).auth_test()["url"]
|
||||
|
||||
|
||||
def get_message_link(
    event: dict[str, Any], client: WebClient, channel_id: str | None = None
) -> str:
    """Resolve the permalink URL for a single Slack message event."""
    # fall back to the channel recorded on the event itself
    target_channel = channel_id or event["channel"]
    response = client.chat_getPermalink(
        channel=target_channel, message_ts=event["ts"]
    )
    return response["permalink"]
|
||||
|
||||
|
||||
def _make_slack_api_call_logged(
    call: Callable[..., SlackResponse],
) -> Callable[..., SlackResponse]:
    """Wrap a Slack API call so its arguments and result are debug-logged."""

    @wraps(call)
    def wrapped(**kwargs: Any) -> SlackResponse:
        logger.debug(f"Making call to Slack API '{call.__name__}' with args '{kwargs}'")
        result = call(**kwargs)
        logger.debug(f"Call to Slack API '{call.__name__}' returned '{result}'")
        return result

    return wrapped
|
||||
|
||||
|
||||
def _make_slack_api_call_paginated(
    call: Callable[..., SlackResponse],
) -> Callable[..., Generator[dict[str, Any], None, None]]:
    """Wraps calls to slack API so that they automatically handle pagination

    The wrapped callable yields one validated response page at a time,
    following Slack's cursor-based pagination until next_cursor is empty.
    """

    @wraps(call)
    def paginated_call(**kwargs: Any) -> Generator[dict[str, Any], None, None]:
        # cursor=None on the first request fetches the first page
        cursor: str | None = None
        has_more = True
        while has_more:
            response = call(cursor=cursor, limit=_SLACK_LIMIT, **kwargs)
            # validate() raises SlackApiError if the response is an error
            yield cast(dict[str, Any], response.validate())
            cursor = cast(dict[str, Any], response.get("response_metadata", {})).get(
                "next_cursor", ""
            )
            # an empty next_cursor means the final page has been reached
            has_more = bool(cursor)

    return paginated_call
|
||||
|
||||
|
||||
def make_slack_api_rate_limited(
    call: Callable[..., SlackResponse], max_retries: int = 7
) -> Callable[..., SlackResponse]:
    """Wraps calls to slack API so that they automatically handle rate limiting

    Retries up to max_retries times on "ratelimited" errors, sleeping for
    the server-provided Retry-After duration each time. "already_reacted" /
    "no_reaction" errors are treated as success; any other SlackApiError
    propagates immediately. Raises Exception once retries are exhausted.
    """

    @wraps(call)
    def rate_limited_call(**kwargs: Any) -> SlackResponse:
        last_exception = None
        for _ in range(max_retries):
            try:
                # Make the API call
                response = call(**kwargs)

                # Check for errors in the response, will raise `SlackApiError`
                # if anything went wrong
                response.validate()
                return response

            except SlackApiError as e:
                last_exception = e
                try:
                    error = e.response["error"]
                except KeyError:
                    error = "unknown error"

                if error == "ratelimited":
                    # Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
                    retry_after = int(e.response.headers.get("Retry-After", 1))
                    logger.info(
                        f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
                    )
                    time.sleep(retry_after)
                elif error in ["already_reacted", "no_reaction"]:
                    # The response isn't used for reactions, this is basically just a pass
                    return e.response
                else:
                    # Raise the error for non-transient errors
                    raise

        # If the code reaches this point, all retries have been exhausted
        msg = f"Max retries ({max_retries}) exceeded"
        if last_exception:
            raise Exception(msg) from last_exception
        else:
            raise Exception(msg)

    return rate_limited_call
|
||||
|
||||
|
||||
def make_slack_api_call_w_retries(
    call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
    """Invoke a Slack API call with logging, rate-limit handling, and retries."""
    rate_limited = make_slack_api_rate_limited(_make_slack_api_call_logged(call))
    return basic_retry_wrapper(rate_limited)(**kwargs)
|
||||
|
||||
|
||||
def make_paginated_slack_api_call_w_retries(
    call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
    """Paginated variant of make_slack_api_call_w_retries: yields one
    validated response page per iteration."""
    rate_limited = basic_retry_wrapper(
        make_slack_api_rate_limited(_make_slack_api_call_logged(call))
    )
    return _make_slack_api_call_paginated(rate_limited)(**kwargs)
|
||||
|
||||
|
||||
def expert_info_from_slack_id(
    user_id: str | None,
    client: WebClient,
    user_cache: dict[str, BasicExpertInfo | None],
) -> BasicExpertInfo | None:
    """Resolve a Slack user id to a BasicExpertInfo via `users_info`.

    Results — including failed lookups, stored as None — are memoized in the
    caller-owned user_cache so each id costs at most one API call.
    """
    if not user_id:
        return None

    if user_id in user_cache:
        return user_cache[user_id]

    response = make_slack_api_rate_limited(client.users_info)(user=user_id)

    if not response["ok"]:
        # cache the failure too so a bad id is not retried repeatedly
        user_cache[user_id] = None
        return None

    user: dict = cast(dict[Any, dict], response.data).get("user", {})
    profile = user.get("profile", {})

    expert = BasicExpertInfo(
        display_name=user.get("real_name") or profile.get("display_name"),
        first_name=profile.get("first_name"),
        last_name=profile.get("last_name"),
        email=profile.get("email"),
    )

    user_cache[user_id] = expert

    return expert
|
||||
|
||||
|
||||
class SlackTextCleaner:
    """Rewrites Slack markup in message text into plain, indexable text.

    User-ID -> name lookups go through the Slack API and are cached per
    instance, so the same user is fetched at most once.
    """

    def __init__(self, client: WebClient) -> None:
        self._client = client
        # user_id -> resolved display/real name
        self._id_to_name_map: dict[str, str] = {}

    def _get_slack_name(self, user_id: str) -> str:
        """Look up (and cache) the name shown in Slack for *user_id*."""
        if user_id not in self._id_to_name_map:
            try:
                response = make_slack_api_rate_limited(self._client.users_info)(
                    user=user_id
                )
                # prefer display name if set, since that is what is shown in Slack
                profile = response["user"]["profile"]
                self._id_to_name_map[user_id] = (
                    profile["display_name"] or profile["real_name"]
                )
            except SlackApiError as e:
                logger.exception(
                    f"Error fetching data for user {user_id}: {e.response['error']}"
                )
                raise
        return self._id_to_name_map[user_id]

    def _replace_user_ids_with_names(self, message: str) -> str:
        """Swap every `<@USER_ID>` mention for `@<name>` (best effort)."""
        for user_id in re.findall("<@(.*?)>", message):
            try:
                if user_id in self._id_to_name_map:
                    resolved_name = self._id_to_name_map[user_id]
                else:
                    resolved_name = self._get_slack_name(user_id)
                message = message.replace(f"<@{user_id}>", f"@{resolved_name}")
            except Exception:
                # best effort — an unresolvable id is left as-is
                logger.exception(
                    f"Unable to replace user ID with username for user_id '{user_id}'"
                )
        return message

    def index_clean(self, message: str) -> str:
        """During indexing, replace pattern sets that may cause confusion to the model

        Some special patterns are left in as they can provide information
        ie. links that contain format text|link, both the text and the link may be informative
        """
        for transform in (
            self._replace_user_ids_with_names,
            self.replace_tags_basic,
            self.replace_channels_basic,
            self.replace_special_mentions,
            self.replace_special_catchall,
        ):
            message = transform(message)
        return message

    @staticmethod
    def replace_tags_basic(message: str) -> str:
        """Neutralize user tags: `<@USER_ID>` -> `@USER_ID`, so indexed text
        cannot re-tag users in Slack."""
        for user_id in re.findall("<@(.*?)>", message):
            message = message.replace(f"<@{user_id}>", f"@{user_id}")
        return message

    @staticmethod
    def replace_channels_basic(message: str) -> str:
        """Rewrite channel mentions: `<#CHANNEL_ID|name>` -> `#name`."""
        for channel_id, channel_name in re.findall(r"<#(.*?)\|(.*?)>", message):
            message = message.replace(
                f"<#{channel_id}|{channel_name}>", f"#{channel_name}"
            )
        return message

    @staticmethod
    def replace_special_mentions(message: str) -> str:
        """Defuse the @channel / @here / @everyone escapes so indexed text
        cannot mass-tag people."""
        mention_map = {
            "<!channel>": "@channel",
            "<!here>": "@here",
            "<!everyone>": "@everyone",
        }
        for raw, plain in mention_map.items():
            message = message.replace(raw, plain)
        return message

    @staticmethod
    def replace_links(message: str) -> str:
        """Replaces slack links e.g. `<URL>` -> `URL` and `<URL|DISPLAY>` -> `DISPLAY`"""
        for candidate in re.findall(r"<(.*?)>", message):
            if not candidate:
                continue
            # Special slack patterns that aren't for links
            if candidate[0] in ("#", "@", "!"):
                continue
            display = candidate.split("|")[1] if "|" in candidate else candidate
            message = message.replace(f"<{candidate}>", display)
        return message

    @staticmethod
    def replace_special_catchall(message: str) -> str:
        """Replaces pattern of <!something|another-thing> with another-thing

        This is added for <!subteam^TEAM-ID|@team-name> but may match other cases as well
        """
        return re.sub(r"<!([^|]+)\|([^>]+)>", r"\2", message)

    @staticmethod
    def add_zero_width_whitespace_after_tag(message: str) -> str:
        """Add a 0 width whitespace after every @"""
        return message.replace("@", "@\u200B")
|
Reference in New Issue
Block a user