mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-12 22:09:36 +02:00
Improve slack connector logging
This commit is contained in:
parent
c845a91eb0
commit
a6e08b42e2
@ -19,6 +19,7 @@ from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.slack.utils import get_message_link
|
||||
from danswer.connectors.slack.utils import make_slack_api_call_logged
|
||||
from danswer.connectors.slack.utils import make_slack_api_call_paginated
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import UserIdReplacer
|
||||
@ -33,17 +34,25 @@ MessageType = dict[str, Any]
|
||||
ThreadType = list[MessageType]
|
||||
|
||||
|
||||
def _make_slack_api_call(
|
||||
def _make_paginated_slack_api_call(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> list[dict[str, Any]]:
|
||||
return make_slack_api_call_paginated(make_slack_api_rate_limited(call))(**kwargs)
|
||||
return make_slack_api_call_paginated(
|
||||
make_slack_api_rate_limited(make_slack_api_call_logged(call))
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def _make_slack_api_call(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> SlackResponse:
|
||||
return make_slack_api_rate_limited(make_slack_api_call_logged(call))(**kwargs)
|
||||
|
||||
|
||||
def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
|
||||
"""Get information about a channel. Needed to convert channel ID to channel name"""
|
||||
return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][
|
||||
"channel"
|
||||
]
|
||||
return _make_paginated_slack_api_call(
|
||||
client.conversations_info, channel=channel_id
|
||||
)[0]["channel"]
|
||||
|
||||
|
||||
def get_channels(
|
||||
@ -52,7 +61,7 @@ def get_channels(
|
||||
) -> list[ChannelType]:
|
||||
"""Get all channels in the workspace"""
|
||||
channels: list[dict[str, Any]] = []
|
||||
for result in _make_slack_api_call(
|
||||
for result in _make_paginated_slack_api_call(
|
||||
client.conversations_list, exclude_archived=exclude_archived
|
||||
):
|
||||
channels.extend(result["channels"])
|
||||
@ -68,11 +77,14 @@ def get_channel_messages(
|
||||
"""Get all messages in a channel"""
|
||||
# join so that the bot can access messages
|
||||
if not channel["is_member"]:
|
||||
client.conversations_join(
|
||||
channel=channel["id"], is_private=channel["is_private"]
|
||||
_make_slack_api_call(
|
||||
client.conversations_join,
|
||||
channel=channel["id"],
|
||||
is_private=channel["is_private"],
|
||||
)
|
||||
logger.info(f"Successfully joined '{channel['name']}'")
|
||||
|
||||
for result in _make_slack_api_call(
|
||||
for result in _make_paginated_slack_api_call(
|
||||
client.conversations_history,
|
||||
channel=channel["id"],
|
||||
oldest=oldest,
|
||||
@ -84,7 +96,7 @@ def get_channel_messages(
|
||||
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
|
||||
"""Get all messages in a thread"""
|
||||
threads: list[MessageType] = []
|
||||
for result in _make_slack_api_call(
|
||||
for result in _make_paginated_slack_api_call(
|
||||
client.conversations_replies, channel=channel_id, ts=thread_id
|
||||
):
|
||||
threads.extend(result["messages"])
|
||||
@ -139,13 +151,23 @@ def _default_msg_filter(message: MessageType) -> bool:
|
||||
def _filter_channels(
|
||||
all_channels: list[dict[str, Any]], channels_to_connect: list[str] | None
|
||||
) -> list[dict[str, Any]]:
|
||||
if channels_to_connect:
|
||||
return [
|
||||
channel
|
||||
for channel in all_channels
|
||||
if channel["name"] in channels_to_connect
|
||||
]
|
||||
return all_channels
|
||||
if not channels_to_connect:
|
||||
return all_channels
|
||||
|
||||
# validate that all channels in `channels_to_connect` are valid
|
||||
# fail loudly in the case of an invalid channel so that the user
|
||||
# knows that one of the channels they've specified is typo'd or private
|
||||
all_channel_names = {channel["name"] for channel in all_channels}
|
||||
for channel in channels_to_connect:
|
||||
if channel not in all_channel_names:
|
||||
raise ValueError(
|
||||
f"Channel '{channel}' not found in workspace. "
|
||||
f"Available channels: {all_channel_names}"
|
||||
)
|
||||
|
||||
return [
|
||||
channel for channel in all_channels if channel["name"] in channels_to_connect
|
||||
]
|
||||
|
||||
|
||||
def get_all_docs(
|
||||
|
@ -1,6 +1,7 @@
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@ -29,11 +30,25 @@ def get_message_link(
|
||||
)
|
||||
|
||||
|
||||
def make_slack_api_call_logged(
|
||||
call: Callable[..., SlackResponse],
|
||||
) -> Callable[..., SlackResponse]:
|
||||
@wraps(call)
|
||||
def logged_call(**kwargs: Any) -> SlackResponse:
|
||||
logger.debug(f"Making call to Slack API '{call.__name__}' with args '{kwargs}'")
|
||||
result = call(**kwargs)
|
||||
logger.debug(f"Call to Slack API '{call.__name__}' returned '{result}'")
|
||||
return result
|
||||
|
||||
return logged_call
|
||||
|
||||
|
||||
def make_slack_api_call_paginated(
|
||||
call: Callable[..., SlackResponse],
|
||||
) -> Callable[..., list[dict[str, Any]]]:
|
||||
"""Wraps calls to slack API so that they automatically handle pagination"""
|
||||
|
||||
@wraps(call)
|
||||
def paginated_call(**kwargs: Any) -> list[dict[str, Any]]:
|
||||
results: list[dict[str, Any]] = []
|
||||
cursor: str | None = None
|
||||
@ -53,17 +68,17 @@ def make_slack_api_rate_limited(
|
||||
) -> Callable[..., SlackResponse]:
|
||||
"""Wraps calls to slack API so that they automatically handle rate limiting"""
|
||||
|
||||
@wraps(call)
|
||||
def rate_limited_call(**kwargs: Any) -> SlackResponse:
|
||||
for _ in range(max_retries):
|
||||
try:
|
||||
# Make the API call
|
||||
response = call(**kwargs)
|
||||
|
||||
# Check for errors in the response
|
||||
if response.get("ok"):
|
||||
return response
|
||||
else:
|
||||
raise SlackApiError("", response)
|
||||
# Check for errors in the response, will raise `SlackApiError`
|
||||
# if anything went wrong
|
||||
response.validate()
|
||||
return response
|
||||
|
||||
except SlackApiError as e:
|
||||
if e.response["error"] == "ratelimited":
|
||||
|
Loading…
x
Reference in New Issue
Block a user