mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-07 13:10:24 +02:00
Improve slack connector logging
This commit is contained in:
@ -19,6 +19,7 @@ from danswer.connectors.models import ConnectorMissingCredentialError
|
|||||||
from danswer.connectors.models import Document
|
from danswer.connectors.models import Document
|
||||||
from danswer.connectors.models import Section
|
from danswer.connectors.models import Section
|
||||||
from danswer.connectors.slack.utils import get_message_link
|
from danswer.connectors.slack.utils import get_message_link
|
||||||
|
from danswer.connectors.slack.utils import make_slack_api_call_logged
|
||||||
from danswer.connectors.slack.utils import make_slack_api_call_paginated
|
from danswer.connectors.slack.utils import make_slack_api_call_paginated
|
||||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||||
from danswer.connectors.slack.utils import UserIdReplacer
|
from danswer.connectors.slack.utils import UserIdReplacer
|
||||||
@ -33,17 +34,25 @@ MessageType = dict[str, Any]
|
|||||||
ThreadType = list[MessageType]
|
ThreadType = list[MessageType]
|
||||||
|
|
||||||
|
|
||||||
def _make_slack_api_call(
|
def _make_paginated_slack_api_call(
|
||||||
call: Callable[..., SlackResponse], **kwargs: Any
|
call: Callable[..., SlackResponse], **kwargs: Any
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
return make_slack_api_call_paginated(make_slack_api_rate_limited(call))(**kwargs)
|
return make_slack_api_call_paginated(
|
||||||
|
make_slack_api_rate_limited(make_slack_api_call_logged(call))
|
||||||
|
)(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_slack_api_call(
|
||||||
|
call: Callable[..., SlackResponse], **kwargs: Any
|
||||||
|
) -> SlackResponse:
|
||||||
|
return make_slack_api_rate_limited(make_slack_api_call_logged(call))(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
|
def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
|
||||||
"""Get information about a channel. Needed to convert channel ID to channel name"""
|
"""Get information about a channel. Needed to convert channel ID to channel name"""
|
||||||
return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][
|
return _make_paginated_slack_api_call(
|
||||||
"channel"
|
client.conversations_info, channel=channel_id
|
||||||
]
|
)[0]["channel"]
|
||||||
|
|
||||||
|
|
||||||
def get_channels(
|
def get_channels(
|
||||||
@ -52,7 +61,7 @@ def get_channels(
|
|||||||
) -> list[ChannelType]:
|
) -> list[ChannelType]:
|
||||||
"""Get all channels in the workspace"""
|
"""Get all channels in the workspace"""
|
||||||
channels: list[dict[str, Any]] = []
|
channels: list[dict[str, Any]] = []
|
||||||
for result in _make_slack_api_call(
|
for result in _make_paginated_slack_api_call(
|
||||||
client.conversations_list, exclude_archived=exclude_archived
|
client.conversations_list, exclude_archived=exclude_archived
|
||||||
):
|
):
|
||||||
channels.extend(result["channels"])
|
channels.extend(result["channels"])
|
||||||
@ -68,11 +77,14 @@ def get_channel_messages(
|
|||||||
"""Get all messages in a channel"""
|
"""Get all messages in a channel"""
|
||||||
# join so that the bot can access messages
|
# join so that the bot can access messages
|
||||||
if not channel["is_member"]:
|
if not channel["is_member"]:
|
||||||
client.conversations_join(
|
_make_slack_api_call(
|
||||||
channel=channel["id"], is_private=channel["is_private"]
|
client.conversations_join,
|
||||||
|
channel=channel["id"],
|
||||||
|
is_private=channel["is_private"],
|
||||||
)
|
)
|
||||||
|
logger.info(f"Successfully joined '{channel['name']}'")
|
||||||
|
|
||||||
for result in _make_slack_api_call(
|
for result in _make_paginated_slack_api_call(
|
||||||
client.conversations_history,
|
client.conversations_history,
|
||||||
channel=channel["id"],
|
channel=channel["id"],
|
||||||
oldest=oldest,
|
oldest=oldest,
|
||||||
@ -84,7 +96,7 @@ def get_channel_messages(
|
|||||||
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
|
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
|
||||||
"""Get all messages in a thread"""
|
"""Get all messages in a thread"""
|
||||||
threads: list[MessageType] = []
|
threads: list[MessageType] = []
|
||||||
for result in _make_slack_api_call(
|
for result in _make_paginated_slack_api_call(
|
||||||
client.conversations_replies, channel=channel_id, ts=thread_id
|
client.conversations_replies, channel=channel_id, ts=thread_id
|
||||||
):
|
):
|
||||||
threads.extend(result["messages"])
|
threads.extend(result["messages"])
|
||||||
@ -139,14 +151,24 @@ def _default_msg_filter(message: MessageType) -> bool:
|
|||||||
def _filter_channels(
|
def _filter_channels(
|
||||||
all_channels: list[dict[str, Any]], channels_to_connect: list[str] | None
|
all_channels: list[dict[str, Any]], channels_to_connect: list[str] | None
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
if channels_to_connect:
|
if not channels_to_connect:
|
||||||
return [
|
|
||||||
channel
|
|
||||||
for channel in all_channels
|
|
||||||
if channel["name"] in channels_to_connect
|
|
||||||
]
|
|
||||||
return all_channels
|
return all_channels
|
||||||
|
|
||||||
|
# validate that all channels in `channels_to_connect` are valid
|
||||||
|
# fail loudly in the case of an invalid channel so that the user
|
||||||
|
# knows that one of the channels they've specified is typo'd or private
|
||||||
|
all_channel_names = {channel["name"] for channel in all_channels}
|
||||||
|
for channel in channels_to_connect:
|
||||||
|
if channel not in all_channel_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"Channel '{channel}' not found in workspace. "
|
||||||
|
f"Available channels: {all_channel_names}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
channel for channel in all_channels if channel["name"] in channels_to_connect
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_all_docs(
|
def get_all_docs(
|
||||||
client: WebClient,
|
client: WebClient,
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from functools import wraps
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
@ -29,11 +30,25 @@ def get_message_link(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_slack_api_call_logged(
|
||||||
|
call: Callable[..., SlackResponse],
|
||||||
|
) -> Callable[..., SlackResponse]:
|
||||||
|
@wraps(call)
|
||||||
|
def logged_call(**kwargs: Any) -> SlackResponse:
|
||||||
|
logger.debug(f"Making call to Slack API '{call.__name__}' with args '{kwargs}'")
|
||||||
|
result = call(**kwargs)
|
||||||
|
logger.debug(f"Call to Slack API '{call.__name__}' returned '{result}'")
|
||||||
|
return result
|
||||||
|
|
||||||
|
return logged_call
|
||||||
|
|
||||||
|
|
||||||
def make_slack_api_call_paginated(
|
def make_slack_api_call_paginated(
|
||||||
call: Callable[..., SlackResponse],
|
call: Callable[..., SlackResponse],
|
||||||
) -> Callable[..., list[dict[str, Any]]]:
|
) -> Callable[..., list[dict[str, Any]]]:
|
||||||
"""Wraps calls to slack API so that they automatically handle pagination"""
|
"""Wraps calls to slack API so that they automatically handle pagination"""
|
||||||
|
|
||||||
|
@wraps(call)
|
||||||
def paginated_call(**kwargs: Any) -> list[dict[str, Any]]:
|
def paginated_call(**kwargs: Any) -> list[dict[str, Any]]:
|
||||||
results: list[dict[str, Any]] = []
|
results: list[dict[str, Any]] = []
|
||||||
cursor: str | None = None
|
cursor: str | None = None
|
||||||
@ -53,17 +68,17 @@ def make_slack_api_rate_limited(
|
|||||||
) -> Callable[..., SlackResponse]:
|
) -> Callable[..., SlackResponse]:
|
||||||
"""Wraps calls to slack API so that they automatically handle rate limiting"""
|
"""Wraps calls to slack API so that they automatically handle rate limiting"""
|
||||||
|
|
||||||
|
@wraps(call)
|
||||||
def rate_limited_call(**kwargs: Any) -> SlackResponse:
|
def rate_limited_call(**kwargs: Any) -> SlackResponse:
|
||||||
for _ in range(max_retries):
|
for _ in range(max_retries):
|
||||||
try:
|
try:
|
||||||
# Make the API call
|
# Make the API call
|
||||||
response = call(**kwargs)
|
response = call(**kwargs)
|
||||||
|
|
||||||
# Check for errors in the response
|
# Check for errors in the response, will raise `SlackApiError`
|
||||||
if response.get("ok"):
|
# if anything went wrong
|
||||||
|
response.validate()
|
||||||
return response
|
return response
|
||||||
else:
|
|
||||||
raise SlackApiError("", response)
|
|
||||||
|
|
||||||
except SlackApiError as e:
|
except SlackApiError as e:
|
||||||
if e.response["error"] == "ratelimited":
|
if e.response["error"] == "ratelimited":
|
||||||
|
Reference in New Issue
Block a user