From a6e08b42e2fc72263a02b1e5b9ceb7f7184632fb Mon Sep 17 00:00:00 2001 From: Weves Date: Tue, 15 Aug 2023 15:57:58 -0700 Subject: [PATCH] Improve slack connector logging --- backend/danswer/connectors/slack/connector.py | 56 +++++++++++++------ backend/danswer/connectors/slack/utils.py | 25 +++++++-- 2 files changed, 59 insertions(+), 22 deletions(-) diff --git a/backend/danswer/connectors/slack/connector.py b/backend/danswer/connectors/slack/connector.py index 88bfe602d..5f1e4919d 100644 --- a/backend/danswer/connectors/slack/connector.py +++ b/backend/danswer/connectors/slack/connector.py @@ -19,6 +19,7 @@ from danswer.connectors.models import ConnectorMissingCredentialError from danswer.connectors.models import Document from danswer.connectors.models import Section from danswer.connectors.slack.utils import get_message_link +from danswer.connectors.slack.utils import make_slack_api_call_logged from danswer.connectors.slack.utils import make_slack_api_call_paginated from danswer.connectors.slack.utils import make_slack_api_rate_limited from danswer.connectors.slack.utils import UserIdReplacer @@ -33,17 +34,25 @@ MessageType = dict[str, Any] ThreadType = list[MessageType] -def _make_slack_api_call( +def _make_paginated_slack_api_call( call: Callable[..., SlackResponse], **kwargs: Any ) -> list[dict[str, Any]]: - return make_slack_api_call_paginated(make_slack_api_rate_limited(call))(**kwargs) + return make_slack_api_call_paginated( + make_slack_api_rate_limited(make_slack_api_call_logged(call)) + )(**kwargs) + + +def _make_slack_api_call( + call: Callable[..., SlackResponse], **kwargs: Any +) -> SlackResponse: + return make_slack_api_rate_limited(make_slack_api_call_logged(call))(**kwargs) def get_channel_info(client: WebClient, channel_id: str) -> ChannelType: """Get information about a channel. Needed to convert channel ID to channel name""" - return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][ - "channel" - ] + return _make_paginated_slack_api_call( + client.conversations_info, channel=channel_id + )[0]["channel"] def get_channels( @@ -52,7 +61,7 @@ def get_channels( ) -> list[ChannelType]: """Get all channels in the workspace""" channels: list[dict[str, Any]] = [] - for result in _make_slack_api_call( + for result in _make_paginated_slack_api_call( client.conversations_list, exclude_archived=exclude_archived ): channels.extend(result["channels"]) @@ -68,11 +77,14 @@ def get_channel_messages( """Get all messages in a channel""" # join so that the bot can access messages if not channel["is_member"]: - client.conversations_join( - channel=channel["id"], is_private=channel["is_private"] + _make_slack_api_call( + client.conversations_join, + channel=channel["id"], + is_private=channel["is_private"], ) + logger.info(f"Successfully joined '{channel['name']}'") - for result in _make_slack_api_call( + for result in _make_paginated_slack_api_call( client.conversations_history, channel=channel["id"], oldest=oldest, @@ -84,7 +96,7 @@ def get_channel_messages( def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType: """Get all messages in a thread""" threads: list[MessageType] = [] - for result in _make_slack_api_call( + for result in _make_paginated_slack_api_call( client.conversations_replies, channel=channel_id, ts=thread_id ): threads.extend(result["messages"]) @@ -139,13 +151,23 @@ def _default_msg_filter(message: MessageType) -> bool: def _filter_channels( all_channels: list[dict[str, Any]], channels_to_connect: list[str] | None ) -> list[dict[str, Any]]: - if channels_to_connect: - return [ - channel - for channel in all_channels - if channel["name"] in channels_to_connect - ] - return all_channels + if not channels_to_connect: + return all_channels + + # validate that all channels in `channels_to_connect` are valid + # fail loudly in the case of an invalid channel so that the user + # knows that one of the channels they've specified is typo'd or private + all_channel_names = {channel["name"] for channel in all_channels} + for channel in channels_to_connect: + if channel not in all_channel_names: + raise ValueError( + f"Channel '{channel}' not found in workspace. " + f"Available channels: {all_channel_names}" + ) + + return [ + channel for channel in all_channels if channel["name"] in channels_to_connect + ] def get_all_docs( diff --git a/backend/danswer/connectors/slack/utils.py b/backend/danswer/connectors/slack/utils.py index 0d486cfbd..2e8c750d0 100644 --- a/backend/danswer/connectors/slack/utils.py +++ b/backend/danswer/connectors/slack/utils.py @@ -1,6 +1,7 @@ import re import time from collections.abc import Callable +from functools import wraps from typing import Any from typing import cast @@ -29,11 +30,25 @@ def get_message_link( ) +def make_slack_api_call_logged( + call: Callable[..., SlackResponse], +) -> Callable[..., SlackResponse]: + @wraps(call) + def logged_call(**kwargs: Any) -> SlackResponse: + logger.debug(f"Making call to Slack API '{call.__name__}' with args '{kwargs}'") + result = call(**kwargs) + logger.debug(f"Call to Slack API '{call.__name__}' returned '{result}'") + return result + + return logged_call + + def make_slack_api_call_paginated( call: Callable[..., SlackResponse], ) -> Callable[..., list[dict[str, Any]]]: """Wraps calls to slack API so that they automatically handle pagination""" + @wraps(call) def paginated_call(**kwargs: Any) -> list[dict[str, Any]]: results: list[dict[str, Any]] = [] cursor: str | None = None @@ -53,17 +68,17 @@ def make_slack_api_rate_limited( ) -> Callable[..., SlackResponse]: """Wraps calls to slack API so that they automatically handle rate limiting""" + @wraps(call) def rate_limited_call(**kwargs: Any) -> SlackResponse: for _ in range(max_retries): try: # Make the API call response = call(**kwargs) - # Check for errors in the response - if response.get("ok"): - return response - else: - raise SlackApiError("", response) + # Check for errors in the response, will raise `SlackApiError` + # if anything went wrong + response.validate() + return response except SlackApiError as e: if e.response["error"] == "ratelimited":