welcome to onyx

pablodanswer
2024-12-13 09:48:43 -08:00
parent 54dcbfa288
commit 21ec5ed795
813 changed files with 7021 additions and 6824 deletions


@@ -0,0 +1,444 @@
import re
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import get_message_link
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.utils import make_slack_api_call_w_retries
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.utils.logger import setup_logger
logger = setup_logger()
ChannelType = dict[str, Any]
MessageType = dict[str, Any]
# list of messages in a thread
ThreadType = list[MessageType]
def _collect_paginated_channels(
client: WebClient,
exclude_archived: bool,
channel_types: list[str],
) -> list[ChannelType]:
channels: list[dict[str, Any]] = []
for result in make_paginated_slack_api_call_w_retries(
client.conversations_list,
exclude_archived=exclude_archived,
# also get private channels the bot is added to
types=channel_types,
):
channels.extend(result["channels"])
return channels
def get_channels(
client: WebClient,
exclude_archived: bool = True,
get_public: bool = True,
get_private: bool = True,
) -> list[ChannelType]:
"""Get all channels in the workspace"""
channels: list[dict[str, Any]] = []
channel_types = []
if get_public:
channel_types.append("public_channel")
if get_private:
channel_types.append("private_channel")
    # first try to fetch private channels as well
try:
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
except SlackApiError as e:
logger.info(f"Unable to fetch private channels due to - {e}")
logger.info("trying again without private channels")
if get_public:
channel_types = ["public_channel"]
else:
logger.warning("No channels to fetch")
return []
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
return channels
def get_channel_messages(
client: WebClient,
channel: dict[str, Any],
oldest: str | None = None,
latest: str | None = None,
) -> Generator[list[MessageType], None, None]:
"""Get all messages in a channel"""
# join so that the bot can access messages
if not channel["is_member"]:
make_slack_api_call_w_retries(
client.conversations_join,
channel=channel["id"],
is_private=channel["is_private"],
)
logger.info(f"Successfully joined '{channel['name']}'")
for result in make_paginated_slack_api_call_w_retries(
client.conversations_history,
channel=channel["id"],
oldest=oldest,
latest=latest,
):
yield cast(list[MessageType], result["messages"])
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
"""Get all messages in a thread"""
threads: list[MessageType] = []
for result in make_paginated_slack_api_call_w_retries(
client.conversations_replies, channel=channel_id, ts=thread_id
):
threads.extend(result["messages"])
return threads
def get_latest_message_time(thread: ThreadType) -> datetime:
max_ts = max([float(msg.get("ts", 0)) for msg in thread])
return datetime.fromtimestamp(max_ts, tz=timezone.utc)
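# Illustrative example (values are hypothetical, not from the original commit):
#   get_latest_message_time([{"ts": "1700000000.000100"}, {"ts": "1700000400.000200"}])
#   -> datetime(2023, 11, 14, 22, 20, 0, 200, tzinfo=timezone.utc)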
def thread_to_doc(
channel: ChannelType,
thread: ThreadType,
slack_cleaner: SlackTextCleaner,
client: WebClient,
user_cache: dict[str, BasicExpertInfo | None],
) -> Document:
channel_id = channel["id"]
initial_sender_expert_info = expert_info_from_slack_id(
user_id=thread[0].get("user"), client=client, user_cache=user_cache
)
initial_sender_name = (
initial_sender_expert_info.get_semantic_name()
if initial_sender_expert_info
else "Unknown"
)
valid_experts = None
if ENABLE_EXPENSIVE_EXPERT_CALLS:
all_sender_ids = [m.get("user") for m in thread]
experts = [
expert_info_from_slack_id(
user_id=sender_id, client=client, user_cache=user_cache
)
for sender_id in all_sender_ids
if sender_id
]
valid_experts = [expert for expert in experts if expert]
first_message = slack_cleaner.index_clean(cast(str, thread[0]["text"]))
snippet = (
first_message[:50].rstrip() + "..."
if len(first_message) > 50
else first_message
)
doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace(
"\n", " "
)
return Document(
id=f"{channel_id}__{thread[0]['ts']}",
sections=[
Section(
link=get_message_link(event=m, client=client, channel_id=channel_id),
text=slack_cleaner.index_clean(cast(str, m["text"])),
)
for m in thread
],
source=DocumentSource.SLACK,
semantic_identifier=doc_sem_id,
doc_updated_at=get_latest_message_time(thread),
title="", # slack docs don't really have a "title"
primary_owners=valid_experts,
metadata={"Channel": channel["name"]},
)
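# Illustrative: a thread started by "Jane Doe" (hypothetical) in #general whose first
# message exceeds 50 characters yields a semantic_identifier like
# "Jane Doe in #general: Has anyone seen the deploy job failing on main si..."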
# list of subtypes can be found here: https://api.slack.com/events/message
_DISALLOWED_MSG_SUBTYPES = {
"channel_join",
"channel_leave",
"channel_archive",
"channel_unarchive",
"pinned_item",
"unpinned_item",
"ekm_access_denied",
"channel_posting_permissions",
"group_join",
"group_leave",
"group_archive",
"group_unarchive",
"channel_leave",
"channel_name",
"channel_join",
}
def default_msg_filter(message: MessageType) -> bool:
# Don't keep messages from bots
if message.get("bot_id") or message.get("app_id"):
if message.get("bot_profile", {}).get("name") == "OnyxConnector":
return False
return True
# Uninformative
if message.get("subtype", "") in _DISALLOWED_MSG_SUBTYPES:
return True
return False
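# Illustrative behavior (True means the message gets filtered out; hypothetical inputs):
#   default_msg_filter({"subtype": "channel_join"}) -> True   (uninformative)
#   default_msg_filter({"bot_id": "B01", "bot_profile": {"name": "OnyxConnector"}}) -> False
#   default_msg_filter({"user": "U01", "text": "hello"}) -> False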
def filter_channels(
all_channels: list[dict[str, Any]],
channels_to_connect: list[str] | None,
regex_enabled: bool,
) -> list[dict[str, Any]]:
if not channels_to_connect:
return all_channels
if regex_enabled:
return [
channel
for channel in all_channels
if any(
re.fullmatch(channel_to_connect, channel["name"])
for channel_to_connect in channels_to_connect
)
]
# validate that all channels in `channels_to_connect` are valid
# fail loudly in the case of an invalid channel so that the user
# knows that one of the channels they've specified is typo'd or private
all_channel_names = {channel["name"] for channel in all_channels}
for channel in channels_to_connect:
if channel not in all_channel_names:
raise ValueError(
f"Channel '{channel}' not found in workspace. "
f"Available channels: {all_channel_names}"
)
return [
channel for channel in all_channels if channel["name"] in channels_to_connect
]
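# Illustrative sketch with hypothetical channels: with regex_enabled=True a channel is
# kept only if its name fully matches one of the patterns, e.g.
#   filter_channels(
#       [{"name": "general"}, {"name": "eng-alerts"}, {"name": "eng-oncall"}],
#       channels_to_connect=["eng-.*"],
#       regex_enabled=True,
#   ) -> [{"name": "eng-alerts"}, {"name": "eng-oncall"}]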
def _get_all_docs(
client: WebClient,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
oldest: str | None = None,
latest: str | None = None,
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> Generator[Document, None, None]:
"""Get all documents in the workspace, channel by channel"""
slack_cleaner = SlackTextCleaner(client=client)
    # Cache to prevent refetching the same users via the API
user_cache: dict[str, BasicExpertInfo | None] = {}
all_channels = get_channels(client)
filtered_channels = filter_channels(
all_channels, channels, channel_name_regex_enabled
)
for channel in filtered_channels:
channel_docs = 0
channel_message_batches = get_channel_messages(
client=client, channel=channel, oldest=oldest, latest=latest
)
seen_thread_ts: set[str] = set()
for message_batch in channel_message_batches:
for message in message_batch:
filtered_thread: ThreadType | None = None
thread_ts = message.get("thread_ts")
if thread_ts:
# skip threads we've already seen, since we've already processed all
# messages in that thread
if thread_ts in seen_thread_ts:
continue
seen_thread_ts.add(thread_ts)
thread = get_thread(
client=client, channel_id=channel["id"], thread_id=thread_ts
)
filtered_thread = [
message for message in thread if not msg_filter_func(message)
]
elif not msg_filter_func(message):
filtered_thread = [message]
if filtered_thread:
channel_docs += 1
yield thread_to_doc(
channel=channel,
thread=filtered_thread,
slack_cleaner=slack_cleaner,
client=client,
user_cache=user_cache,
)
        logger.info(
            f"Pulled {channel_docs} documents from Slack channel {channel['name']}"
        )
def _get_all_doc_ids(
client: WebClient,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> GenerateSlimDocumentOutput:
"""
    Get all document ids in the workspace, channel by channel
    This is nearly identical to _get_all_docs, but it yields only document ids
    (as SlimDocuments) instead of full documents, making it an order of magnitude
    faster than _get_all_docs
"""
all_channels = get_channels(client)
filtered_channels = filter_channels(
all_channels, channels, channel_name_regex_enabled
)
for channel in filtered_channels:
channel_id = channel["id"]
channel_message_batches = get_channel_messages(
client=client,
channel=channel,
)
message_ts_set: set[str] = set()
for message_batch in channel_message_batches:
for message in message_batch:
if msg_filter_func(message):
continue
                # The document id is the channel id plus the ts of the first message
                # in the thread. Since we already have the first message of the thread,
                # we don't have to fetch the whole thread for id retrieval, saving time
                # and API calls
message_ts_set.add(message["ts"])
channel_metadata_list: list[SlimDocument] = []
for message_ts in message_ts_set:
channel_metadata_list.append(
SlimDocument(
id=f"{channel_id}__{message_ts}",
perm_sync_data={"channel_id": channel_id},
)
)
yield channel_metadata_list
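# Note: the slim ids share the f"{channel_id}__{thread_ts}" format of the full Document
# ids above, so they can be compared directly against previously indexed documents.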
class SlackPollConnector(PollConnector, SlimConnector):
def __init__(
self,
channels: list[str] | None = None,
# if specified, will treat the specified channel strings as
# regexes, and will only index channels that fully match the regexes
channel_regex_enabled: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.channels = channels
self.channel_regex_enabled = channel_regex_enabled
self.batch_size = batch_size
self.client: WebClient | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
bot_token = credentials["slack_bot_token"]
self.client = WebClient(token=bot_token)
return None
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")
return _get_all_doc_ids(
client=self.client,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")
documents: list[Document] = []
for document in _get_all_docs(
client=self.client,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
# NOTE: need to impute to `None` instead of using 0.0, since Slack will
# throw an error if we use 0.0 on an account without infinite data
# retention
oldest=str(start) if start else None,
latest=str(end),
):
documents.append(document)
if len(documents) >= self.batch_size:
yield documents
documents = []
if documents:
yield documents
if __name__ == "__main__":
import os
import time
slack_channel = os.environ.get("SLACK_CHANNEL")
connector = SlackPollConnector(
channels=[slack_channel] if slack_channel else None,
)
connector.load_credentials({"slack_bot_token": os.environ["SLACK_BOT_TOKEN"]})
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
document_batches = connector.poll_source(one_day_ago, current)
print(next(document_batches))
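    # Each yielded batch is a list[Document]; with SLACK_CHANNEL set this prints the
    # first batch of thread documents pulled from that channel over the last day.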


@@ -0,0 +1,140 @@
import json
import os
from datetime import datetime
from datetime import timezone
from pathlib import Path
from typing import Any
from typing import cast
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.slack.connector import filter_channels
from onyx.connectors.slack.utils import get_message_link
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_event_time(event: dict[str, Any]) -> datetime | None:
ts = event.get("ts")
if not ts:
return None
return datetime.fromtimestamp(float(ts), tz=timezone.utc)
class SlackLoadConnector(LoadConnector):
# WARNING: DEPRECATED, DO NOT USE
def __init__(
self,
workspace: str,
export_path_str: str,
channels: list[str] | None = None,
# if specified, will treat the specified channel strings as
# regexes, and will only index channels that fully match the regexes
channel_regex_enabled: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.workspace = workspace
self.channels = channels
self.channel_regex_enabled = channel_regex_enabled
self.export_path_str = export_path_str
self.batch_size = batch_size
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
if credentials:
logger.warning("Unexpected credentials provided for Slack Load Connector")
return None
@staticmethod
def _process_batch_event(
slack_event: dict[str, Any],
channel: dict[str, Any],
matching_doc: Document | None,
workspace: str,
) -> Document | None:
if (
slack_event["type"] == "message"
and slack_event.get("subtype") != "channel_join"
):
if matching_doc:
return Document(
id=matching_doc.id,
sections=matching_doc.sections
+ [
Section(
link=get_message_link(
event=slack_event,
workspace=workspace,
channel_id=channel["id"],
),
text=slack_event["text"],
)
],
source=matching_doc.source,
semantic_identifier=matching_doc.semantic_identifier,
title="", # slack docs don't really have a "title"
doc_updated_at=get_event_time(slack_event),
metadata=matching_doc.metadata,
)
return Document(
id=slack_event["ts"],
sections=[
Section(
link=get_message_link(
event=slack_event,
workspace=workspace,
channel_id=channel["id"],
),
text=slack_event["text"],
)
],
source=DocumentSource.SLACK,
semantic_identifier=channel["name"],
title="", # slack docs don't really have a "title"
doc_updated_at=get_event_time(slack_event),
metadata={},
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
export_path = Path(self.export_path_str)
with open(export_path / "channels.json") as f:
all_channels = json.load(f)
filtered_channels = filter_channels(
all_channels, self.channels, self.channel_regex_enabled
)
document_batch: dict[str, Document] = {}
for channel_info in filtered_channels:
channel_dir_path = export_path / cast(str, channel_info["name"])
channel_file_paths = [
channel_dir_path / file_name
for file_name in os.listdir(channel_dir_path)
]
for path in channel_file_paths:
with open(path) as f:
events = cast(list[dict[str, Any]], json.load(f))
for slack_event in events:
doc = self._process_batch_event(
slack_event=slack_event,
channel=channel_info,
matching_doc=document_batch.get(
slack_event.get("thread_ts", "")
),
workspace=self.workspace,
)
if doc:
document_batch[doc.id] = doc
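                            # Note: document_batch is kept across yields so later thread
                            # replies can still find their parent via the thread_ts lookup
                            # above; yielded batches are therefore cumulative.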
if len(document_batch) >= self.batch_size:
yield list(document_batch.values())
yield list(document_batch.values())


@@ -0,0 +1,297 @@
import re
import time
from collections.abc import Callable
from collections.abc import Generator
from functools import lru_cache
from functools import wraps
from typing import Any
from typing import cast
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse
from onyx.connectors.models import BasicExpertInfo
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
basic_retry_wrapper = retry_builder()
# number of results we request per page when making paginated Slack API calls
_SLACK_LIMIT = 900
@lru_cache()
def get_base_url(token: str) -> str:
"""Retrieve and cache the base URL of the Slack workspace based on the client token."""
client = WebClient(token=token)
return client.auth_test()["url"]
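# Illustrative (hypothetical token and workspace):
#   get_base_url("xoxb-...") -> "https://my-workspace.slack.com/"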
def get_message_link(
event: dict[str, Any], client: WebClient, channel_id: str | None = None
) -> str:
channel_id = channel_id or event["channel"]
message_ts = event["ts"]
response = client.chat_getPermalink(channel=channel_id, message_ts=message_ts)
permalink = response["permalink"]
return permalink
def _make_slack_api_call_logged(
call: Callable[..., SlackResponse],
) -> Callable[..., SlackResponse]:
@wraps(call)
def logged_call(**kwargs: Any) -> SlackResponse:
logger.debug(f"Making call to Slack API '{call.__name__}' with args '{kwargs}'")
result = call(**kwargs)
logger.debug(f"Call to Slack API '{call.__name__}' returned '{result}'")
return result
return logged_call
def _make_slack_api_call_paginated(
call: Callable[..., SlackResponse],
) -> Callable[..., Generator[dict[str, Any], None, None]]:
    """Wraps Slack API calls so that they automatically handle pagination"""
@wraps(call)
def paginated_call(**kwargs: Any) -> Generator[dict[str, Any], None, None]:
cursor: str | None = None
has_more = True
while has_more:
response = call(cursor=cursor, limit=_SLACK_LIMIT, **kwargs)
yield cast(dict[str, Any], response.validate())
cursor = cast(dict[str, Any], response.get("response_metadata", {})).get(
"next_cursor", ""
)
has_more = bool(cursor)
return paginated_call
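# Illustrative usage (hypothetical kwargs); each yielded page is one validated response:
#   for page in _make_slack_api_call_paginated(client.conversations_list)(
#       types="public_channel"
#   ):
#       channels = page["channels"]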
def make_slack_api_rate_limited(
call: Callable[..., SlackResponse], max_retries: int = 7
) -> Callable[..., SlackResponse]:
    """Wraps Slack API calls so that they automatically handle rate limiting"""
@wraps(call)
def rate_limited_call(**kwargs: Any) -> SlackResponse:
last_exception = None
for _ in range(max_retries):
try:
# Make the API call
response = call(**kwargs)
# Check for errors in the response, will raise `SlackApiError`
# if anything went wrong
response.validate()
return response
except SlackApiError as e:
last_exception = e
try:
error = e.response["error"]
except KeyError:
error = "unknown error"
if error == "ratelimited":
# Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
retry_after = int(e.response.headers.get("Retry-After", 1))
logger.info(
f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
)
time.sleep(retry_after)
elif error in ["already_reacted", "no_reaction"]:
                    # The response isn't used for reactions; treat these as a no-op
return e.response
else:
# Raise the error for non-transient errors
raise
# If the code reaches this point, all retries have been exhausted
msg = f"Max retries ({max_retries}) exceeded"
if last_exception:
raise Exception(msg) from last_exception
else:
raise Exception(msg)
return rate_limited_call
def make_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
)(**kwargs)
def make_paginated_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return _make_slack_api_call_paginated(
basic_retry_wrapper(
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
)
)(**kwargs)
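# The wrappers compose outermost-in as pagination -> retries -> rate-limit handling ->
# debug logging -> raw call. Illustrative usage (hypothetical channel id):
#   for page in make_paginated_slack_api_call_w_retries(
#       client.conversations_history, channel="C0123456789"
#   ):
#       messages = page["messages"]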
def expert_info_from_slack_id(
user_id: str | None,
client: WebClient,
user_cache: dict[str, BasicExpertInfo | None],
) -> BasicExpertInfo | None:
if not user_id:
return None
if user_id in user_cache:
return user_cache[user_id]
response = make_slack_api_rate_limited(client.users_info)(user=user_id)
if not response["ok"]:
user_cache[user_id] = None
return None
user: dict = cast(dict[Any, dict], response.data).get("user", {})
profile = user.get("profile", {})
expert = BasicExpertInfo(
display_name=user.get("real_name") or profile.get("display_name"),
first_name=profile.get("first_name"),
last_name=profile.get("last_name"),
email=profile.get("email"),
)
user_cache[user_id] = expert
return expert
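# Illustrative (hypothetical user id): repeated lookups hit the cache, not the API
#   cache: dict[str, BasicExpertInfo | None] = {}
#   expert_info_from_slack_id("U0123", client, cache)  # users_info API call
#   expert_info_from_slack_id("U0123", client, cache)  # served from cache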
class SlackTextCleaner:
"""Utility class to replace user IDs with usernames in a message.
Handles caching, so the same request is not made multiple times
for the same user ID"""
def __init__(self, client: WebClient) -> None:
self._client = client
self._id_to_name_map: dict[str, str] = {}
def _get_slack_name(self, user_id: str) -> str:
if user_id not in self._id_to_name_map:
try:
response = make_slack_api_rate_limited(self._client.users_info)(
user=user_id
)
# prefer display name if set, since that is what is shown in Slack
self._id_to_name_map[user_id] = (
response["user"]["profile"]["display_name"]
or response["user"]["profile"]["real_name"]
)
except SlackApiError as e:
logger.exception(
f"Error fetching data for user {user_id}: {e.response['error']}"
)
raise
return self._id_to_name_map[user_id]
def _replace_user_ids_with_names(self, message: str) -> str:
# Find user IDs in the message
user_ids = re.findall("<@(.*?)>", message)
# Iterate over each user ID found
for user_id in user_ids:
try:
if user_id in self._id_to_name_map:
user_name = self._id_to_name_map[user_id]
else:
user_name = self._get_slack_name(user_id)
# Replace the user ID with the username in the message
message = message.replace(f"<@{user_id}>", f"@{user_name}")
except Exception:
logger.exception(
f"Unable to replace user ID with username for user_id '{user_id}'"
)
return message
    def index_clean(self, message: str) -> str:
        """During indexing, replace pattern sets that may cause confusion to the model.
        Some special patterns are left in, as they can provide information;
        e.g. links in the `text|link` format, where both the text and the link
        may be informative
        """
message = self._replace_user_ids_with_names(message)
message = self.replace_tags_basic(message)
message = self.replace_channels_basic(message)
message = self.replace_special_mentions(message)
message = self.replace_special_catchall(message)
return message
@staticmethod
    def replace_tags_basic(message: str) -> str:
        """Replaces all `<@USER_ID>` tags with plain `@USER_ID` in order to prevent
        us from tagging users in Slack when we don't want to"""
# Find user IDs in the message
user_ids = re.findall("<@(.*?)>", message)
for user_id in user_ids:
message = message.replace(f"<@{user_id}>", f"@{user_id}")
return message
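    # e.g. (hypothetical id): replace_tags_basic("thanks <@U123>!") -> "thanks @U123!"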
@staticmethod
    def replace_channels_basic(message: str) -> str:
        """Replaces all channel mentions like `<#CHANNEL_ID|CHANNEL_NAME>` with
        plain `#CHANNEL_NAME`"""
        # Find channel mentions in the message
channel_matches = re.findall(r"<#(.*?)\|(.*?)>", message)
for channel_id, channel_name in channel_matches:
message = message.replace(
f"<#{channel_id}|{channel_name}>", f"#{channel_name}"
)
return message
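    # e.g. (hypothetical ids): replace_channels_basic("see <#C123|general>") -> "see #general"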
@staticmethod
def replace_special_mentions(message: str) -> str:
"""Simply replaces @channel, @here, and @everyone so we don't tag
a bunch of people in Slack when we don't want to"""
        # These are fixed tokens, so replace them directly
message = message.replace("<!channel>", "@channel")
message = message.replace("<!here>", "@here")
message = message.replace("<!everyone>", "@everyone")
return message
@staticmethod
def replace_links(message: str) -> str:
"""Replaces slack links e.g. `<URL>` -> `URL` and `<URL|DISPLAY>` -> `DISPLAY`"""
        # Find candidate link tokens in the message
possible_link_matches = re.findall(r"<(.*?)>", message)
for possible_link in possible_link_matches:
if not possible_link:
continue
# Special slack patterns that aren't for links
if possible_link[0] not in ["#", "@", "!"]:
link_display = (
possible_link
if "|" not in possible_link
else possible_link.split("|")[1]
)
message = message.replace(f"<{possible_link}>", link_display)
return message
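    # e.g. replace_links("docs: <https://example.com|here>") -> "docs: here"
    #      replace_links("see <https://example.com>") -> "see https://example.com"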
@staticmethod
    def replace_special_catchall(message: str) -> str:
        """Replaces the pattern <!something|another-thing> with another-thing.
        This was added for <!subteam^TEAM-ID|@team-name> but may match other cases as well
        """
pattern = r"<!([^|]+)\|([^>]+)>"
return re.sub(pattern, r"\2", message)
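    # e.g. (hypothetical team): replace_special_catchall("<!subteam^S123|@eng-team> ping")
    #      -> "@eng-team ping"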
@staticmethod
    def add_zero_width_whitespace_after_tag(message: str) -> str:
        """Add a zero-width space after every @ so the text cannot trigger a mention"""
return message.replace("@", "@\u200B")
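    # e.g. add_zero_width_whitespace_after_tag("@here") -> "@\u200bhere", which renders
    # like "@here" but is no longer parsed as a mention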