Improve slack connector logging

This commit is contained in:
Weves 2023-08-15 15:57:58 -07:00 committed by Chris Weaver
parent c845a91eb0
commit a6e08b42e2
2 changed files with 59 additions and 22 deletions

View File

@ -19,6 +19,7 @@ from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.slack.utils import get_message_link
from danswer.connectors.slack.utils import make_slack_api_call_logged
from danswer.connectors.slack.utils import make_slack_api_call_paginated
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import UserIdReplacer
@ -33,17 +34,25 @@ MessageType = dict[str, Any]
ThreadType = list[MessageType]
def _make_slack_api_call(
def _make_paginated_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> list[dict[str, Any]]:
return make_slack_api_call_paginated(make_slack_api_rate_limited(call))(**kwargs)
return make_slack_api_call_paginated(
make_slack_api_rate_limited(make_slack_api_call_logged(call))
)(**kwargs)
def _make_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return make_slack_api_rate_limited(make_slack_api_call_logged(call))(**kwargs)
def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
"""Get information about a channel. Needed to convert channel ID to channel name"""
return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][
"channel"
]
return _make_paginated_slack_api_call(
client.conversations_info, channel=channel_id
)[0]["channel"]
def get_channels(
@ -52,7 +61,7 @@ def get_channels(
) -> list[ChannelType]:
"""Get all channels in the workspace"""
channels: list[dict[str, Any]] = []
for result in _make_slack_api_call(
for result in _make_paginated_slack_api_call(
client.conversations_list, exclude_archived=exclude_archived
):
channels.extend(result["channels"])
@ -68,11 +77,14 @@ def get_channel_messages(
"""Get all messages in a channel"""
# join so that the bot can access messages
if not channel["is_member"]:
client.conversations_join(
channel=channel["id"], is_private=channel["is_private"]
_make_slack_api_call(
client.conversations_join,
channel=channel["id"],
is_private=channel["is_private"],
)
logger.info(f"Successfully joined '{channel['name']}'")
for result in _make_slack_api_call(
for result in _make_paginated_slack_api_call(
client.conversations_history,
channel=channel["id"],
oldest=oldest,
@ -84,7 +96,7 @@ def get_channel_messages(
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
"""Get all messages in a thread"""
threads: list[MessageType] = []
for result in _make_slack_api_call(
for result in _make_paginated_slack_api_call(
client.conversations_replies, channel=channel_id, ts=thread_id
):
threads.extend(result["messages"])
@ -139,13 +151,23 @@ def _default_msg_filter(message: MessageType) -> bool:
def _filter_channels(
all_channels: list[dict[str, Any]], channels_to_connect: list[str] | None
) -> list[dict[str, Any]]:
if channels_to_connect:
return [
channel
for channel in all_channels
if channel["name"] in channels_to_connect
]
return all_channels
if not channels_to_connect:
return all_channels
# validate that all channels in `channels_to_connect` are valid
# fail loudly in the case of an invalid channel so that the user
# knows that one of the channels they've specified is typo'd or private
all_channel_names = {channel["name"] for channel in all_channels}
for channel in channels_to_connect:
if channel not in all_channel_names:
raise ValueError(
f"Channel '{channel}' not found in workspace. "
f"Available channels: {all_channel_names}"
)
return [
channel for channel in all_channels if channel["name"] in channels_to_connect
]
def get_all_docs(

View File

@ -1,6 +1,7 @@
import re
import time
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import cast
@ -29,11 +30,25 @@ def get_message_link(
)
def make_slack_api_call_logged(
call: Callable[..., SlackResponse],
) -> Callable[..., SlackResponse]:
@wraps(call)
def logged_call(**kwargs: Any) -> SlackResponse:
logger.debug(f"Making call to Slack API '{call.__name__}' with args '{kwargs}'")
result = call(**kwargs)
logger.debug(f"Call to Slack API '{call.__name__}' returned '{result}'")
return result
return logged_call
def make_slack_api_call_paginated(
call: Callable[..., SlackResponse],
) -> Callable[..., list[dict[str, Any]]]:
"""Wraps calls to slack API so that they automatically handle pagination"""
@wraps(call)
def paginated_call(**kwargs: Any) -> list[dict[str, Any]]:
results: list[dict[str, Any]] = []
cursor: str | None = None
@ -53,17 +68,17 @@ def make_slack_api_rate_limited(
) -> Callable[..., SlackResponse]:
"""Wraps calls to slack API so that they automatically handle rate limiting"""
@wraps(call)
def rate_limited_call(**kwargs: Any) -> SlackResponse:
for _ in range(max_retries):
try:
# Make the API call
response = call(**kwargs)
# Check for errors in the response
if response.get("ok"):
return response
else:
raise SlackApiError("", response)
# Check for errors in the response, will raise `SlackApiError`
# if anything went wrong
response.validate()
return response
except SlackApiError as e:
if e.response["error"] == "ratelimited":