Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-07-03 19:20:53 +02:00)

new slack rate limiting approach (#4779)

* fix slack rate limit retry handler for groups
* trying to mitigate memory usage during csv download
* Revert "trying to mitigate memory usage during csv download"
  This reverts commit 48262eacf6.
* integrated approach to rate limiting
* code review
* try no redis setting
* add pytest-dotenv
* add more debugging
* added comments
* add more stats
---------
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
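At a high level, the integrated approach has two cooperating pieces: a retry handler that, on a rate-limit response, records the requested delay in a shared Redis key as a TTL, and a subclassed WebClient that reads that TTL and sleeps before every request. The sketch below is only a simplified illustration of that contract, not code from this commit; the Redis connection, key names, and helper names are placeholders (the real keys are derived from the connector's credential prefix).

import time

import redis

# Placeholder wiring; the connector derives these from the credential prefix.
r = redis.Redis()
DELAY_KEY = "connector:slack:delay"
DELAY_LOCK = "connector:slack:delay_lock"


def record_rate_limit(retry_after_s: float) -> None:
    """Roughly what the retry handler does: extend the shared delay by Retry-After."""
    with r.lock(DELAY_LOCK, timeout=60):
        ttl_ms = r.pttl(DELAY_KEY)
        if ttl_ms < 0:  # -2 (missing) / -1 (no expiry) both mean "no pending delay"
            ttl_ms = 0
        # every worker that gets limited pushes the shared delay further out
        r.set(DELAY_KEY, "1", px=ttl_ms + int(retry_after_s * 1000))


def wait_for_shared_delay() -> None:
    """Roughly what the subclassed web client does before each API call."""
    delay_ms = r.pttl(DELAY_KEY)
    if delay_ms > 0:
        time.sleep(delay_ms / 1000.0)
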
@@ -8,7 +8,7 @@ from onyx.access.models import DocExternalAccess
 from onyx.access.models import ExternalAccess
 from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
 from onyx.connectors.slack.connector import get_channels
-from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
+from onyx.connectors.slack.connector import make_paginated_slack_api_call
 from onyx.connectors.slack.connector import SlackConnector
 from onyx.db.models import ConnectorCredentialPair
 from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -64,7 +64,7 @@ def _fetch_channel_permissions(
     for channel_id in private_channel_ids:
         # Collect all member ids for the channel pagination calls
         member_ids = []
-        for result in make_paginated_slack_api_call_w_retries(
+        for result in make_paginated_slack_api_call(
            slack_client.conversations_members,
            channel=channel_id,
        ):
@@ -92,7 +92,7 @@ def _fetch_channel_permissions(
            external_user_emails=member_emails,
            # No group<->document mapping for slack
            external_user_group_ids=set(),
-           # No way to determine if slack is invite only without enterprise liscense
+           # No way to determine if slack is invite only without enterprise license
            is_public=False,
        )
@@ -10,8 +10,8 @@ from slack_sdk import WebClient
 from ee.onyx.db.external_perm import ExternalUserGroup
 from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
 from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
-from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
 from onyx.connectors.slack.connector import SlackConnector
+from onyx.connectors.slack.utils import make_paginated_slack_api_call
 from onyx.db.models import ConnectorCredentialPair
 from onyx.redis.redis_pool import get_redis_client
 from onyx.utils.logger import setup_logger
@@ -23,7 +23,7 @@ def _get_slack_group_ids(
     slack_client: WebClient,
 ) -> list[str]:
     group_ids = []
-    for result in make_paginated_slack_api_call_w_retries(slack_client.usergroups_list):
+    for result in make_paginated_slack_api_call(slack_client.usergroups_list):
         for group in result.get("usergroups", []):
             group_ids.append(group.get("id"))
     return group_ids
@@ -35,7 +35,7 @@ def _get_slack_group_members_email(
     user_id_to_email_map: dict[str, str],
 ) -> list[str]:
     group_member_emails = []
-    for result in make_paginated_slack_api_call_w_retries(
+    for result in make_paginated_slack_api_call(
         slack_client.usergroups_users_list, usergroup=group_name
     ):
         for member_id in result.get("users", []):
@@ -1,13 +1,13 @@
 from slack_sdk import WebClient

-from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
+from onyx.connectors.slack.utils import make_paginated_slack_api_call


 def fetch_user_id_to_email_map(
     slack_client: WebClient,
 ) -> dict[str, str]:
     user_id_to_email_map = {}
-    for user_info in make_paginated_slack_api_call_w_retries(
+    for user_info in make_paginated_slack_api_call(
         slack_client.users_list,
     ):
         for user in user_info.get("members", []):
@@ -76,10 +76,11 @@ def hash_api_key(api_key: str) -> str:
     # and overlaps are impossible
     if api_key.startswith(_API_KEY_PREFIX):
         return hashlib.sha256(api_key.encode("utf-8")).hexdigest()
-    elif api_key.startswith(_DEPRECATED_API_KEY_PREFIX):
+
+    if api_key.startswith(_DEPRECATED_API_KEY_PREFIX):
         return _deprecated_hash_api_key(api_key)
-    else:
-        raise ValueError(f"Invalid API key prefix: {api_key[:3]}")
+
+    raise ValueError(f"Invalid API key prefix: {api_key[:3]}")


 def build_displayable_api_key(api_key: str) -> str:
@@ -9,8 +9,11 @@ from concurrent.futures import Future
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from datetime import timezone
+from http.client import IncompleteRead
+from http.client import RemoteDisconnected
 from typing import Any
 from typing import cast
+from urllib.error import URLError

 from pydantic import BaseModel
 from redis import Redis
@@ -18,6 +21,9 @@ from slack_sdk import WebClient
 from slack_sdk.errors import SlackApiError
 from slack_sdk.http_retry import ConnectionErrorRetryHandler
 from slack_sdk.http_retry import RetryHandler
+from slack_sdk.http_retry.builtin_interval_calculators import (
+    FixedValueRetryIntervalCalculator,
+)
 from typing_extensions import override

 from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
@@ -45,10 +51,10 @@ from onyx.connectors.models import EntityFailure
 from onyx.connectors.models import SlimDocument
 from onyx.connectors.models import TextSection
+from onyx.connectors.slack.onyx_retry_handler import OnyxRedisSlackRetryHandler
+from onyx.connectors.slack.onyx_slack_web_client import OnyxSlackWebClient
 from onyx.connectors.slack.utils import expert_info_from_slack_id
 from onyx.connectors.slack.utils import get_message_link
-from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
-from onyx.connectors.slack.utils import make_slack_api_call_w_retries
+from onyx.connectors.slack.utils import make_paginated_slack_api_call
 from onyx.connectors.slack.utils import SlackTextCleaner
 from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.redis.redis_pool import get_redis_client
@@ -78,7 +84,7 @@ def _collect_paginated_channels(
     channel_types: list[str],
 ) -> list[ChannelType]:
     channels: list[dict[str, Any]] = []
-    for result in make_paginated_slack_api_call_w_retries(
+    for result in make_paginated_slack_api_call(
         client.conversations_list,
         exclude_archived=exclude_archived,
         # also get private channels the bot is added to
@@ -135,14 +141,13 @@ def get_channel_messages(
     """Get all messages in a channel"""
     # join so that the bot can access messages
     if not channel["is_member"]:
-        make_slack_api_call_w_retries(
-            client.conversations_join,
+        client.conversations_join(
            channel=channel["id"],
            is_private=channel["is_private"],
        )
        logger.info(f"Successfully joined '{channel['name']}'")

-    for result in make_paginated_slack_api_call_w_retries(
+    for result in make_paginated_slack_api_call(
         client.conversations_history,
         channel=channel["id"],
         oldest=oldest,
@@ -159,7 +164,7 @@ def get_channel_messages(
 def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
     """Get all messages in a thread"""
     threads: list[MessageType] = []
-    for result in make_paginated_slack_api_call_w_retries(
+    for result in make_paginated_slack_api_call(
         client.conversations_replies, channel=channel_id, ts=thread_id
     ):
         threads.extend(result["messages"])
@@ -317,8 +322,7 @@ def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType:
     Raises:
         SlackApiError: If the channel cannot be fetched
     """
-    response = make_slack_api_call_w_retries(
-        client.conversations_info,
+    response = client.conversations_info(
         channel=channel_id,
     )
     return cast(ChannelType, response["channel"])
@@ -335,8 +339,7 @@ def _get_messages(
     # have to be in the channel in order to read messages
     if not channel["is_member"]:
         try:
-            make_slack_api_call_w_retries(
-                client.conversations_join,
+            client.conversations_join(
                channel=channel["id"],
                is_private=channel["is_private"],
            )
@@ -349,8 +352,7 @@ def _get_messages(
             raise
         logger.info(f"Successfully joined '{channel['name']}'")

-    response = make_slack_api_call_w_retries(
-        client.conversations_history,
+    response = client.conversations_history(
         channel=channel["id"],
         oldest=oldest,
         latest=latest,
@@ -527,6 +529,7 @@ class SlackConnector(
         channel_regex_enabled: bool = False,
         batch_size: int = INDEX_BATCH_SIZE,
         num_threads: int = SLACK_NUM_THREADS,
+        use_redis: bool = True,
     ) -> None:
         self.channels = channels
         self.channel_regex_enabled = channel_regex_enabled
@@ -539,6 +542,7 @@ class SlackConnector(
         self.user_cache: dict[str, BasicExpertInfo | None] = {}
         self.credentials_provider: CredentialsProviderInterface | None = None
         self.credential_prefix: str | None = None
+        self.use_redis: bool = use_redis
         # self.delay_lock: str | None = None  # the redis key for the shared lock
         # self.delay_key: str | None = None  # the redis key for the shared delay
@@ -563,10 +567,19 @@ class SlackConnector(

         # NOTE: slack has a built in RateLimitErrorRetryHandler, but it isn't designed
         # for concurrent workers. We've extended it with OnyxRedisSlackRetryHandler.
-        connection_error_retry_handler = ConnectionErrorRetryHandler()
+        connection_error_retry_handler = ConnectionErrorRetryHandler(
+            max_retry_count=max_retry_count,
+            interval_calculator=FixedValueRetryIntervalCalculator(),
+            error_types=[
+                URLError,
+                ConnectionResetError,
+                RemoteDisconnected,
+                IncompleteRead,
+            ],
+        )

         onyx_rate_limit_error_retry_handler = OnyxRedisSlackRetryHandler(
             max_retry_count=max_retry_count,
             delay_lock=delay_lock,
             delay_key=delay_key,
             r=r,
         )
@@ -575,7 +588,13 @@ class SlackConnector(
             onyx_rate_limit_error_retry_handler,
         ]

-        client = WebClient(token=token, retry_handlers=custom_retry_handlers)
+        client = OnyxSlackWebClient(
+            delay_lock=delay_lock,
+            delay_key=delay_key,
+            r=r,
+            token=token,
+            retry_handlers=custom_retry_handlers,
+        )
         return client

     @property
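The handlers above plug into the slack_sdk retry hook: a WebClient accepts a list of RetryHandler objects and consults them when a request fails. As the NOTE in this hunk says, the SDK's stock rate-limit handler only coordinates retries within one process, which is why this commit swaps in OnyxRedisSlackRetryHandler. For comparison, a minimal, hedged example of the stock SDK handlers (the token env var name and retry counts are placeholders, not values from this commit):

import os

from slack_sdk import WebClient
from slack_sdk.http_retry import ConnectionErrorRetryHandler
from slack_sdk.http_retry.builtin_handlers import RateLimitErrorRetryHandler

# Stock SDK behavior: each registered handler may retry the request it matches.
client = WebClient(
    token=os.environ["SLACK_BOT_TOKEN"],  # placeholder env var name
    retry_handlers=[
        ConnectionErrorRetryHandler(max_retry_count=2),
        RateLimitErrorRetryHandler(max_retry_count=2),
    ],
)

# Handlers kick in transparently; callers just make normal API calls.
response = client.conversations_list(limit=100)
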
@@ -599,16 +618,32 @@ class SlackConnector(
         if not tenant_id:
             raise ValueError("tenant_id cannot be None!")

-        self.redis = get_redis_client(tenant_id=tenant_id)
-
-        self.credential_prefix = SlackConnector.make_credential_prefix(
-            credentials_provider.get_provider_key()
-        )
-
         bot_token = credentials["slack_bot_token"]
-        self.client = SlackConnector.make_slack_web_client(
-            self.credential_prefix, bot_token, self.MAX_RETRIES, self.redis
-        )
+
+        if self.use_redis:
+            self.redis = get_redis_client(tenant_id=tenant_id)
+            self.credential_prefix = SlackConnector.make_credential_prefix(
+                credentials_provider.get_provider_key()
+            )
+
+            self.client = SlackConnector.make_slack_web_client(
+                self.credential_prefix, bot_token, self.MAX_RETRIES, self.redis
+            )
+        else:
+            connection_error_retry_handler = ConnectionErrorRetryHandler(
+                max_retry_count=self.MAX_RETRIES,
+                interval_calculator=FixedValueRetryIntervalCalculator(),
+                error_types=[
+                    URLError,
+                    ConnectionResetError,
+                    RemoteDisconnected,
+                    IncompleteRead,
+                ],
+            )
+
+            self.client = WebClient(
+                token=bot_token, retry_handlers=[connection_error_retry_handler]
+            )

         # use for requests that must return quickly (e.g. realtime flows where user is waiting)
         self.fast_client = WebClient(
@@ -651,6 +686,8 @@ class SlackConnector(
         Step 2.4: If there are no more messages in the channel, switch the current
         channel to the next channel.
         """
+        num_channels_remaining = 0
+
         if self.client is None or self.text_cleaner is None:
             raise ConnectorMissingCredentialError("Slack")

@@ -664,7 +701,9 @@ class SlackConnector(
                 raw_channels, self.channels, self.channel_regex_enabled
             )
             logger.info(
-                f"Channels: all={len(raw_channels)} post_filtering={len(filtered_channels)}"
+                f"Channels - initial checkpoint: "
+                f"all={len(raw_channels)} "
+                f"post_filtering={len(filtered_channels)}"
             )

             checkpoint.channel_ids = [c["id"] for c in filtered_channels]
@@ -677,6 +716,17 @@ class SlackConnector(
             return checkpoint

         final_channel_ids = checkpoint.channel_ids
+        for channel_id in final_channel_ids:
+            if channel_id not in checkpoint.channel_completion_map:
+                num_channels_remaining += 1
+
+        logger.info(
+            f"Channels - current status: "
+            f"processed={len(final_channel_ids) - num_channels_remaining} "
+            f"remaining={num_channels_remaining=} "
+            f"total={len(final_channel_ids)}"
+        )
+
         channel = checkpoint.current_channel
         if channel is None:
             raise ValueError("current_channel key not set in checkpoint")
@@ -688,18 +738,32 @@ class SlackConnector(
         oldest = str(start) if start else None
         latest = checkpoint.channel_completion_map.get(channel_id, str(end))
         seen_thread_ts = set(checkpoint.seen_thread_ts)

-        logger.debug(
-            f"Getting messages for channel {channel} within range {oldest} - {latest}"
-        )
-
         try:
+            logger.debug(
+                f"Getting messages for channel {channel} within range {oldest} - {latest}"
+            )
             message_batch, has_more_in_channel = _get_messages(
                 channel, self.client, oldest, latest
             )
+
+            logger.info(
+                f"Retrieved messages: "
+                f"{len(message_batch)=} "
+                f"{channel=} "
+                f"{oldest=} "
+                f"{latest=}"
+            )
+
             new_latest = message_batch[-1]["ts"] if message_batch else latest

             num_threads_start = len(seen_thread_ts)
             # Process messages in parallel using ThreadPoolExecutor
             with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
+                # NOTE(rkuo): this seems to be assuming the slack sdk is thread safe.
+                # That's a very bold assumption! Likely not correct.
+
                 futures: list[Future[ProcessedSlackMessage]] = []
                 for message in message_batch:
                     # Capture the current context so that the thread gets the current tenant ID
@@ -736,7 +800,12 @@ class SlackConnector(
                     yield failure

            num_threads_processed = len(seen_thread_ts) - num_threads_start
-            logger.info(f"Processed {num_threads_processed} threads.")
+            logger.info(
+                f"Message processing stats: "
+                f"batch_len={len(message_batch)} "
+                f"batch_yielded={num_threads_processed} "
+                f"total_threads_seen={len(seen_thread_ts)}"
+            )

            checkpoint.seen_thread_ts = list(seen_thread_ts)
            checkpoint.channel_completion_map[channel["id"]] = new_latest
@@ -751,6 +820,7 @@ class SlackConnector(
                 ),
                 None,
             )
+
            if new_channel_id:
                new_channel = _get_channel_by_id(self.client, new_channel_id)
                checkpoint.current_channel = new_channel
@@ -758,8 +828,6 @@ class SlackConnector(
                 checkpoint.current_channel = None

            checkpoint.has_more = checkpoint.current_channel is not None
-            return checkpoint
-
        except Exception as e:
            logger.exception(f"Error processing channel {channel['name']}")
            yield ConnectorFailure(
@@ -773,7 +841,8 @@ class SlackConnector(
                 failure_message=str(e),
                 exception=e,
             )
+            return checkpoint

        return checkpoint

    def validate_connector_settings(self) -> None:
        """
@@ -1,11 +1,8 @@
-import math
 import random
-import time
 from typing import cast
 from typing import Optional

 from redis import Redis
-from redis.lock import Lock as RedisLock
 from slack_sdk.http_retry.handler import RetryHandler
 from slack_sdk.http_retry.request import HttpRequest
 from slack_sdk.http_retry.response import HttpResponse
@@ -20,28 +17,23 @@ class OnyxRedisSlackRetryHandler(RetryHandler):
    """
    This class uses Redis to share a rate limit among multiple threads.

-    Threads that encounter a rate limit will observe the shared delay, increment the
-    shared delay with the retry value, and use the new shared value as a wait interval.
+    As currently implemented, this code is already surrounded by a lock in Redis
+    via an override of _perform_urllib_http_request in OnyxSlackWebClient.

-    This has the effect of serializing calls when a rate limit is hit, which is what
-    needs to happens if the server punishes us with additional limiting when we make
-    a call too early. We believe this is what Slack is doing based on empirical
-    observation, meaning we see indefinite hangs if we're too aggressive.
+    This just sets the desired retry delay with TTL in redis. In conjunction with
+    a custom subclass of the client, the value is read and obeyed prior to an API call
+    and also serialized.

    Another way to do this is just to do exponential backoff. Might be easier?

    Adapted from slack's RateLimitErrorRetryHandler.
    """

-    LOCK_TTL = 60  # used to serialize access to the retry TTL
-    LOCK_BLOCKING_TIMEOUT = 60  # how long to wait for the lock
-
    """RetryHandler that does retries for rate limited errors."""

    def __init__(
        self,
        max_retry_count: int,
        delay_lock: str,
        delay_key: str,
        r: Redis,
    ):
@@ -51,7 +43,6 @@ class OnyxRedisSlackRetryHandler(RetryHandler):
        """
        super().__init__(max_retry_count=max_retry_count)
        self._redis: Redis = r
-        self._delay_lock = delay_lock
        self._delay_key = delay_key

    def _can_retry(
@@ -72,8 +63,18 @@ class OnyxRedisSlackRetryHandler(RetryHandler):
        response: Optional[HttpResponse] = None,
        error: Optional[Exception] = None,
    ) -> None:
-        """It seems this function is responsible for the wait to retry ... aka we
-        actually sleep in this function."""
+        """As initially designed by the SDK authors, this function is responsible for
+        the wait to retry ... aka we actually sleep in this function.
+
+        This doesn't work well with multiple clients because every thread is unaware
+        of the current retry value until it actually calls the endpoint.
+
+        We're combining this with an actual subclass of the slack web client so
+        that the delay is used BEFORE calling an API endpoint. The subclassed client
+        has already taken the lock in redis when this method is called.
+        """
+        ttl_ms: int | None = None
+
        retry_after_value: list[str] | None = None
        retry_after_header_name: Optional[str] = None
        duration_s: float = 1.0  # seconds
@@ -112,48 +113,22 @@ class OnyxRedisSlackRetryHandler(RetryHandler):
                    retry_after_value[0]
                )  # will raise ValueError if somehow we can't convert to int
                jitter = retry_after_value_int * 0.25 * random.random()
-                duration_s = math.ceil(retry_after_value_int + jitter)
+                duration_s = retry_after_value_int + jitter
            except ValueError:
                duration_s += random.random()

-        # lock and extend the ttl
-        lock: RedisLock = self._redis.lock(
-            self._delay_lock,
-            timeout=OnyxRedisSlackRetryHandler.LOCK_TTL,
-            thread_local=False,
-        )
-
-        acquired = lock.acquire(
-            blocking_timeout=OnyxRedisSlackRetryHandler.LOCK_BLOCKING_TIMEOUT / 2
-        )
-
-        ttl_ms: int | None = None
-
-        try:
-            if acquired:
-                # if we can get the lock, then read and extend the ttl
-                ttl_ms = cast(int, self._redis.pttl(self._delay_key))
-                if ttl_ms < 0:  # negative values are error status codes ... see docs
-                    ttl_ms = 0
-                ttl_ms_new = ttl_ms + int(duration_s * 1000.0)
-                self._redis.set(self._delay_key, "1", px=ttl_ms_new)
-            else:
-                # if we can't get the lock, just go ahead.
-                # TODO: if we know our actual parallelism, multiplying by that
-                # would be a pretty good idea
-                ttl_ms_new = int(duration_s * 1000.0)
-        finally:
-            if acquired:
-                lock.release()
+        # Read and extend the ttl
+        ttl_ms = cast(int, self._redis.pttl(self._delay_key))
+        if ttl_ms < 0:  # negative values are error status codes ... see docs
+            ttl_ms = 0
+        ttl_ms_new = ttl_ms + int(duration_s * 1000.0)
+        self._redis.set(self._delay_key, "1", px=ttl_ms_new)

        logger.warning(
-            f"OnyxRedisSlackRetryHandler.prepare_for_next_attempt wait: "
+            f"OnyxRedisSlackRetryHandler.prepare_for_next_attempt setting delay: "
            f"current_attempt={state.current_attempt} "
            f"retry-after={retry_after_value} "
-            f"shared_delay_ms={ttl_ms} new_shared_delay_ms={ttl_ms_new}"
+            f"{ttl_ms_new=}"
        )

-        # TODO: would be good to take an event var and sleep in short increments to
-        # allow for a clean exit / exception
-        time.sleep(ttl_ms_new / 1000.0)
-
        state.increment_current_attempt()
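To make the delay arithmetic in the hunk above concrete: if Slack answers with Retry-After: 30, the handler adds up to 25% jitter and extends whatever shared delay is already pending, so two workers that are limited back to back produce one combined delay rather than two competing ones. A standalone, hedged sketch of that calculation (a plain integer stands in for the Redis TTL here; the real handler stores it as the expiry of a key):

import random


def extend_delay_ms(pending_delay_ms: int, retry_after_header: str) -> int:
    """Mirror of the handler's math: new shared delay = old delay + Retry-After + jitter."""
    try:
        retry_after_s = int(retry_after_header)
        jitter_s = retry_after_s * 0.25 * random.random()
        duration_s = retry_after_s + jitter_s
    except ValueError:
        duration_s = 1.0 + random.random()  # fallback when the header is unparsable
    return pending_delay_ms + int(duration_s * 1000.0)


# e.g. Retry-After: 30 on top of an existing 5s delay -> roughly 35 to 42.5 seconds total
print(extend_delay_ms(5000, "30"))
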
backend/onyx/connectors/slack/onyx_slack_web_client.py (new file, +116 lines)
@@ -0,0 +1,116 @@
+import threading
+import time
+from typing import Any
+from typing import cast
+from typing import Dict
+from urllib.request import Request
+
+from redis import Redis
+from redis.lock import Lock as RedisLock
+from slack_sdk import WebClient
+
+from onyx.connectors.slack.utils import ONYX_SLACK_LOCK_BLOCKING_TIMEOUT
+from onyx.connectors.slack.utils import ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT
+from onyx.connectors.slack.utils import ONYX_SLACK_LOCK_TTL
+from onyx.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+class OnyxSlackWebClient(WebClient):
+    """Use in combination with the Onyx Retry Handler.
+
+    This client wrapper enforces a proper retry delay through redis BEFORE the api call
+    so that multiple clients can synchronize and rate limit properly.
+
+    The retry handler writes the correct delay value to redis so that it can be used
+    by this wrapper.
+    """
+
+    def __init__(
+        self, delay_lock: str, delay_key: str, r: Redis, *args: Any, **kwargs: Any
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self._delay_key = delay_key
+        self._delay_lock = delay_lock
+        self._redis: Redis = r
+        self.num_requests: int = 0
+        self._lock = threading.Lock()
+
+    def _perform_urllib_http_request(
+        self, *, url: str, args: Dict[str, Dict[str, Any]]
+    ) -> Dict[str, Any]:
+        """By locking around the base class method, we ensure that both the delay from
+        Redis and parsing/writing of retry values to Redis are handled properly in
+        one place"""
+        # lock and extend the ttl
+        lock: RedisLock = self._redis.lock(
+            self._delay_lock,
+            timeout=ONYX_SLACK_LOCK_TTL,
+        )
+
+        # try to acquire the lock
+        start = time.monotonic()
+        while True:
+            acquired = lock.acquire(blocking_timeout=ONYX_SLACK_LOCK_BLOCKING_TIMEOUT)
+            if acquired:
+                break
+
+            # if we couldn't acquire the lock but it exists, there's at least some activity
+            # so keep trying...
+            if self._redis.exists(self._delay_lock):
+                continue
+
+            if time.monotonic() - start > ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT:
+                raise RuntimeError(
+                    f"OnyxSlackWebClient._perform_urllib_http_request - "
+                    f"timed out waiting for lock: {ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT=}"
+                )
+
+        try:
+            result = super()._perform_urllib_http_request(url=url, args=args)
+        finally:
+            if lock.owned():
+                lock.release()
+            else:
+                logger.warning(
+                    "OnyxSlackWebClient._perform_urllib_http_request lock not owned on release"
+                )
+
+        time.monotonic() - start
+        # logger.info(
+        #     f"OnyxSlackWebClient._perform_urllib_http_request: Releasing lock: {elapsed=}"
+        # )
+
+        return result
+
+    def _perform_urllib_http_request_internal(
+        self,
+        url: str,
+        req: Request,
+    ) -> Dict[str, Any]:
+        """Overrides the internal method which is mostly the direct call to
+        urllib/urlopen ... so this is a good place to perform our delay."""
+
+        # read and execute the delay
+        delay_ms = cast(int, self._redis.pttl(self._delay_key))
+        if delay_ms < 0:  # negative values are error status codes ... see docs
+            delay_ms = 0
+
+        if delay_ms > 0:
+            logger.warning(
+                f"OnyxSlackWebClient._perform_urllib_http_request_internal delay: "
+                f"{delay_ms=} "
+                f"{self.num_requests=}"
+            )
+
+            time.sleep(delay_ms / 1000.0)
+
+        result = super()._perform_urllib_http_request_internal(url, req)
+
+        with self._lock:
+            self.num_requests += 1
+
+        # the delay key should have naturally expired by this point
+        return result
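Putting the two new classes together mirrors what make_slack_web_client does in the connector hunk earlier in this diff: the handler and the client subclass are constructed with the same Redis connection and the same delay/lock key names, so the value written by one is the value obeyed by the other. A hedged wiring sketch follows; the Redis connection, key names, retry count, and token are placeholders (the real ones come from get_redis_client, the credential prefix, and SlackConnector.MAX_RETRIES).

from redis import Redis
from slack_sdk.http_retry import ConnectionErrorRetryHandler

from onyx.connectors.slack.onyx_retry_handler import OnyxRedisSlackRetryHandler
from onyx.connectors.slack.onyx_slack_web_client import OnyxSlackWebClient

r = Redis()  # placeholder; the connector uses get_redis_client(tenant_id=...)
delay_lock = "connector:slack:example:delay_lock"  # placeholder key names
delay_key = "connector:slack:example:delay"

rate_limit_handler = OnyxRedisSlackRetryHandler(
    max_retry_count=7,  # placeholder for SlackConnector.MAX_RETRIES
    delay_lock=delay_lock,
    delay_key=delay_key,
    r=r,
)

client = OnyxSlackWebClient(
    delay_lock=delay_lock,
    delay_key=delay_key,
    r=r,
    token="xoxb-placeholder",
    retry_handlers=[ConnectionErrorRetryHandler(), rate_limit_handler],
)
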
@@ -21,6 +21,11 @@ basic_retry_wrapper = retry_builder(tries=7)
 # number of messages we request per page when fetching paginated slack messages
 _SLACK_LIMIT = 900

+# used to serialize access to the retry TTL
+ONYX_SLACK_LOCK_TTL = 1800  # how long the lock is allowed to idle before it expires
+ONYX_SLACK_LOCK_BLOCKING_TIMEOUT = 60  # how long to wait for the lock per wait attempt
+ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT = 3600  # how long to wait for the lock in total
+

 @lru_cache()
 def get_base_url(token: str) -> str:
@@ -44,6 +49,18 @@ def get_message_link(
     return link


+def make_slack_api_call(
+    call: Callable[..., SlackResponse], **kwargs: Any
+) -> SlackResponse:
+    return call(**kwargs)
+
+
+def make_paginated_slack_api_call(
+    call: Callable[..., SlackResponse], **kwargs: Any
+) -> Generator[dict[str, Any], None, None]:
+    return _make_slack_api_call_paginated(call)(**kwargs)
+
+
 def _make_slack_api_call_paginated(
     call: Callable[..., SlackResponse],
 ) -> Callable[..., Generator[dict[str, Any], None, None]]:
@@ -119,17 +136,18 @@ def _make_slack_api_call_paginated(

    # return rate_limited_call


-def make_slack_api_call_w_retries(
-    call: Callable[..., SlackResponse], **kwargs: Any
-) -> SlackResponse:
-    return basic_retry_wrapper(call)(**kwargs)
+# temporarily disabling due to using a different retry approach
+# might be permanent if everything works out
+# def make_slack_api_call_w_retries(
+#     call: Callable[..., SlackResponse], **kwargs: Any
+# ) -> SlackResponse:
+#     return basic_retry_wrapper(call)(**kwargs)


-def make_paginated_slack_api_call_w_retries(
-    call: Callable[..., SlackResponse], **kwargs: Any
-) -> Generator[dict[str, Any], None, None]:
-    return _make_slack_api_call_paginated(basic_retry_wrapper(call))(**kwargs)
+# def make_paginated_slack_api_call_w_retries(
+#     call: Callable[..., SlackResponse], **kwargs: Any
+# ) -> Generator[dict[str, Any], None, None]:
+#     return _make_slack_api_call_paginated(basic_retry_wrapper(call))(**kwargs)


 def expert_info_from_slack_id(
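The paginated helper keeps the call-site pattern used throughout this diff: pass the bound API method plus its keyword arguments and iterate the pages it yields. A small, hedged usage sketch (the token is a placeholder; the response fields follow the shapes used elsewhere in this diff):

from slack_sdk import WebClient

from onyx.connectors.slack.utils import make_paginated_slack_api_call

client = WebClient(token="xoxb-placeholder")

# Each yielded item is one page of the conversations.list response.
channel_names: list[str] = []
for page in make_paginated_slack_api_call(
    client.conversations_list,
    exclude_archived=False,
    types=["public_channel", "private_channel"],
):
    for channel in page.get("channels", []):
        channel_names.append(channel["name"])
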
@@ -8,3 +8,8 @@ filterwarnings =
     ignore::DeprecationWarning
     ignore::cryptography.utils.CryptographyDeprecationWarning
     ignore::PendingDeprecationWarning:ddtrace.internal.module
+# .test.env is gitignored.
+# After installing pytest-dotenv,
+# you can use it to test credentials locally.
+env_files =
+    .test.env
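With pytest-dotenv installed, any variables placed in the gitignored .test.env file are loaded into the test process before collection, which is how credentials can be supplied for local runs. The entry below is illustrative only; the variable names are placeholders, not necessarily the ones the Slack tests read.

# .test.env (gitignored, loaded by pytest-dotenv) -- placeholder variable names
SLACK_BOT_TOKEN=xoxb-000000000000-placeholder
REDIS_HOST=localhost
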
@@ -12,6 +12,7 @@ pandas==2.2.3
 posthog==3.7.4
 pre-commit==3.2.2
 pytest-asyncio==0.22.0
+pytest-dotenv==0.5.2
 pytest-xdist==3.6.1
 pytest==8.3.5
 reorder-python-imports-black==3.14.0
@@ -31,6 +31,7 @@ def slack_connector(
     connector = SlackConnector(
         channels=[channel] if channel else None,
         channel_regex_enabled=False,
+        use_redis=False,
     )
     connector.client = mock_slack_client
     connector.set_credentials_provider(credentials_provider=slack_credentials_provider)
@@ -17,8 +17,7 @@ from slack_sdk.errors import SlackApiError

 from onyx.connectors.slack.connector import default_msg_filter
 from onyx.connectors.slack.connector import get_channel_messages
-from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
-from onyx.connectors.slack.utils import make_slack_api_call_w_retries
+from onyx.connectors.slack.utils import make_paginated_slack_api_call


 def _get_slack_channel_id(channel: dict[str, Any]) -> str:
@@ -40,7 +39,7 @@ def _get_non_general_channels(
         channel_types.append("public_channel")

     conversations: list[dict[str, Any]] = []
-    for result in make_paginated_slack_api_call_w_retries(
+    for result in make_paginated_slack_api_call(
         slack_client.conversations_list,
         exclude_archived=False,
         types=channel_types,
@@ -64,7 +63,7 @@ def _clear_slack_conversation_members(
 ) -> None:
     channel_id = _get_slack_channel_id(channel)
     member_ids: list[str] = []
-    for result in make_paginated_slack_api_call_w_retries(
+    for result in make_paginated_slack_api_call(
         slack_client.conversations_members,
         channel=channel_id,
     ):
@@ -140,15 +139,13 @@ def _build_slack_channel_from_name(
     if channel:
         # If channel is provided, we rename it
         channel_id = _get_slack_channel_id(channel)
-        channel_response = make_slack_api_call_w_retries(
-            slack_client.conversations_rename,
+        channel_response = slack_client.conversations_rename(
            channel=channel_id,
            name=channel_name,
        )
     else:
         # Otherwise, we create a new channel
-        channel_response = make_slack_api_call_w_retries(
-            slack_client.conversations_create,
+        channel_response = slack_client.conversations_create(
            name=channel_name,
            is_private=is_private,
        )
@@ -219,10 +216,13 @@ class SlackManager:

     @staticmethod
     def build_slack_user_email_id_map(slack_client: WebClient) -> dict[str, str]:
-        users_results = make_slack_api_call_w_retries(
+        users: list[dict[str, Any]] = []
+
+        for users_results in make_paginated_slack_api_call(
            slack_client.users_list,
-        )
-        users: list[dict[str, Any]] = users_results.get("members", [])
+        ):
+            users.extend(users_results.get("members", []))

        user_email_id_map = {}
        for user in users:
            if not (email := user.get("profile", {}).get("email")):
@@ -253,8 +253,7 @@ class SlackManager:
         slack_client: WebClient, channel: dict[str, Any], message: str
     ) -> None:
         channel_id = _get_slack_channel_id(channel)
-        make_slack_api_call_w_retries(
-            slack_client.chat_postMessage,
+        slack_client.chat_postMessage(
            channel=channel_id,
            text=message,
        )
@@ -274,7 +273,7 @@ class SlackManager:
     ) -> None:
         channel_types = ["private_channel", "public_channel"]
         channels: list[dict[str, Any]] = []
-        for result in make_paginated_slack_api_call_w_retries(
+        for result in make_paginated_slack_api_call(
            slack_client.conversations_list,
            exclude_archived=False,
            types=channel_types,