Improve slack connector logging

This commit is contained in:
Weves
2023-08-15 15:57:58 -07:00
committed by Chris Weaver
parent c845a91eb0
commit a6e08b42e2
2 changed files with 59 additions and 22 deletions

View File

@ -19,6 +19,7 @@ from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document from danswer.connectors.models import Document
from danswer.connectors.models import Section from danswer.connectors.models import Section
from danswer.connectors.slack.utils import get_message_link from danswer.connectors.slack.utils import get_message_link
from danswer.connectors.slack.utils import make_slack_api_call_logged
from danswer.connectors.slack.utils import make_slack_api_call_paginated from danswer.connectors.slack.utils import make_slack_api_call_paginated
from danswer.connectors.slack.utils import make_slack_api_rate_limited from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import UserIdReplacer from danswer.connectors.slack.utils import UserIdReplacer
@ -33,17 +34,25 @@ MessageType = dict[str, Any]
ThreadType = list[MessageType] ThreadType = list[MessageType]
def _make_slack_api_call( def _make_paginated_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any call: Callable[..., SlackResponse], **kwargs: Any
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
return make_slack_api_call_paginated(make_slack_api_rate_limited(call))(**kwargs) return make_slack_api_call_paginated(
make_slack_api_rate_limited(make_slack_api_call_logged(call))
)(**kwargs)
def _make_slack_api_call(
    call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
    """Invoke a single Slack API call with logging and rate-limit handling.

    The call is first wrapped by `make_slack_api_call_logged` and then by
    `make_slack_api_rate_limited`, so the logging wrapper is the innermost
    layer. `kwargs` are forwarded unchanged to `call`, and the response is
    returned as-is (no pagination wrapper is applied here).
    """
    logged_call = make_slack_api_call_logged(call)
    rate_limited_call = make_slack_api_rate_limited(logged_call)
    return rate_limited_call(**kwargs)
def get_channel_info(client: WebClient, channel_id: str) -> ChannelType: def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
"""Get information about a channel. Needed to convert channel ID to channel name""" """Get information about a channel. Needed to convert channel ID to channel name"""
return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][ return _make_paginated_slack_api_call(
"channel" client.conversations_info, channel=channel_id
] )[0]["channel"]
def get_channels( def get_channels(
@ -52,7 +61,7 @@ def get_channels(
) -> list[ChannelType]: ) -> list[ChannelType]:
"""Get all channels in the workspace""" """Get all channels in the workspace"""
channels: list[dict[str, Any]] = [] channels: list[dict[str, Any]] = []
for result in _make_slack_api_call( for result in _make_paginated_slack_api_call(
client.conversations_list, exclude_archived=exclude_archived client.conversations_list, exclude_archived=exclude_archived
): ):
channels.extend(result["channels"]) channels.extend(result["channels"])
@ -68,11 +77,14 @@ def get_channel_messages(
"""Get all messages in a channel""" """Get all messages in a channel"""
# join so that the bot can access messages # join so that the bot can access messages
if not channel["is_member"]: if not channel["is_member"]:
client.conversations_join( _make_slack_api_call(
channel=channel["id"], is_private=channel["is_private"] client.conversations_join,
channel=channel["id"],
is_private=channel["is_private"],
) )
logger.info(f"Successfully joined '{channel['name']}'")
for result in _make_slack_api_call( for result in _make_paginated_slack_api_call(
client.conversations_history, client.conversations_history,
channel=channel["id"], channel=channel["id"],
oldest=oldest, oldest=oldest,
@ -84,7 +96,7 @@ def get_channel_messages(
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType: def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
"""Get all messages in a thread""" """Get all messages in a thread"""
threads: list[MessageType] = [] threads: list[MessageType] = []
for result in _make_slack_api_call( for result in _make_paginated_slack_api_call(
client.conversations_replies, channel=channel_id, ts=thread_id client.conversations_replies, channel=channel_id, ts=thread_id
): ):
threads.extend(result["messages"]) threads.extend(result["messages"])
@ -139,14 +151,24 @@ def _default_msg_filter(message: MessageType) -> bool:
def _filter_channels( def _filter_channels(
all_channels: list[dict[str, Any]], channels_to_connect: list[str] | None all_channels: list[dict[str, Any]], channels_to_connect: list[str] | None
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
if channels_to_connect: if not channels_to_connect:
return [
channel
for channel in all_channels
if channel["name"] in channels_to_connect
]
return all_channels return all_channels
# validate that all channels in `channels_to_connect` are valid
# fail loudly in the case of an invalid channel so that the user
# knows that one of the channels they've specified is typo'd or private
all_channel_names = {channel["name"] for channel in all_channels}
for channel in channels_to_connect:
if channel not in all_channel_names:
raise ValueError(
f"Channel '{channel}' not found in workspace. "
f"Available channels: {all_channel_names}"
)
return [
channel for channel in all_channels if channel["name"] in channels_to_connect
]
def get_all_docs( def get_all_docs(
client: WebClient, client: WebClient,

View File

@ -1,6 +1,7 @@
import re import re
import time import time
from collections.abc import Callable from collections.abc import Callable
from functools import wraps
from typing import Any from typing import Any
from typing import cast from typing import cast
@ -29,11 +30,25 @@ def get_message_link(
) )
def make_slack_api_call_logged(
    call: Callable[..., SlackResponse],
) -> Callable[..., SlackResponse]:
    """Wrap a Slack API call so that every invocation is logged at DEBUG level.

    The returned callable accepts keyword arguments only, forwards them to
    `call` unchanged, and returns whatever `call` returns. One debug message
    is emitted before the call (with the arguments) and one after it (with
    the result). If `call` raises, the exception propagates and only the
    first message is logged.
    """
    # `wraps` tolerates callables without `__name__` (e.g. `functools.partial`),
    # so resolve the display name defensively rather than failing at call time.
    call_name = getattr(call, "__name__", repr(call))

    @wraps(call)
    def logged_call(**kwargs: Any) -> SlackResponse:
        # Lazy %-style arguments: the (potentially large) kwargs and response
        # are only rendered to text when DEBUG logging is actually enabled.
        logger.debug(
            "Making call to Slack API '%s' with args '%s'", call_name, kwargs
        )
        result = call(**kwargs)
        logger.debug("Call to Slack API '%s' returned '%s'", call_name, result)
        return result

    return logged_call
def make_slack_api_call_paginated( def make_slack_api_call_paginated(
call: Callable[..., SlackResponse], call: Callable[..., SlackResponse],
) -> Callable[..., list[dict[str, Any]]]: ) -> Callable[..., list[dict[str, Any]]]:
"""Wraps calls to slack API so that they automatically handle pagination""" """Wraps calls to slack API so that they automatically handle pagination"""
@wraps(call)
def paginated_call(**kwargs: Any) -> list[dict[str, Any]]: def paginated_call(**kwargs: Any) -> list[dict[str, Any]]:
results: list[dict[str, Any]] = [] results: list[dict[str, Any]] = []
cursor: str | None = None cursor: str | None = None
@ -53,17 +68,17 @@ def make_slack_api_rate_limited(
) -> Callable[..., SlackResponse]: ) -> Callable[..., SlackResponse]:
"""Wraps calls to slack API so that they automatically handle rate limiting""" """Wraps calls to slack API so that they automatically handle rate limiting"""
@wraps(call)
def rate_limited_call(**kwargs: Any) -> SlackResponse: def rate_limited_call(**kwargs: Any) -> SlackResponse:
for _ in range(max_retries): for _ in range(max_retries):
try: try:
# Make the API call # Make the API call
response = call(**kwargs) response = call(**kwargs)
# Check for errors in the response # Check for errors in the response, will raise `SlackApiError`
if response.get("ok"): # if anything went wrong
response.validate()
return response return response
else:
raise SlackApiError("", response)
except SlackApiError as e: except SlackApiError as e:
if e.response["error"] == "ratelimited": if e.response["error"] == "ratelimited":