From ef4d5dcec32e924ad8914a84feddaaf75e31c97c Mon Sep 17 00:00:00 2001 From: rkuo-danswer Date: Thu, 29 May 2025 12:49:32 -0700 Subject: [PATCH] new slack rate limiting approach (#4779) * fix slack rate limit retry handler for groups * trying to mitigate memory usage during csv download * Revert "trying to mitigate memory usage during csv download" This reverts commit 48262eacf69ef1970a0e3e78f8e599708c1f5114. * integrated approach to rate limiting * code review * try no redis setting * add pytest-dotenv * add more debugging * added comments * add more stats --------- Co-authored-by: Richard Kuo (Onyx) --- .../external_permissions/slack/doc_sync.py | 6 +- .../external_permissions/slack/group_sync.py | 6 +- .../onyx/external_permissions/slack/utils.py | 4 +- backend/onyx/auth/api_key.py | 7 +- backend/onyx/connectors/slack/connector.py | 135 +++++++++++++----- .../connectors/slack/onyx_retry_handler.py | 79 ++++------ .../connectors/slack/onyx_slack_web_client.py | 116 +++++++++++++++ backend/onyx/connectors/slack/utils.py | 36 +++-- backend/pytest.ini | 5 + backend/requirements/dev.txt | 1 + .../connectors/slack/test_slack_connector.py | 1 + .../slack/slack_api_utils.py | 27 ++-- 12 files changed, 304 insertions(+), 119 deletions(-) create mode 100644 backend/onyx/connectors/slack/onyx_slack_web_client.py diff --git a/backend/ee/onyx/external_permissions/slack/doc_sync.py b/backend/ee/onyx/external_permissions/slack/doc_sync.py index c7b02b24a9..64f259a803 100644 --- a/backend/ee/onyx/external_permissions/slack/doc_sync.py +++ b/backend/ee/onyx/external_permissions/slack/doc_sync.py @@ -8,7 +8,7 @@ from onyx.access.models import DocExternalAccess from onyx.access.models import ExternalAccess from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider from onyx.connectors.slack.connector import get_channels -from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries +from onyx.connectors.slack.connector import 
make_paginated_slack_api_call from onyx.connectors.slack.connector import SlackConnector from onyx.db.models import ConnectorCredentialPair from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface @@ -64,7 +64,7 @@ def _fetch_channel_permissions( for channel_id in private_channel_ids: # Collect all member ids for the channel pagination calls member_ids = [] - for result in make_paginated_slack_api_call_w_retries( + for result in make_paginated_slack_api_call( slack_client.conversations_members, channel=channel_id, ): @@ -92,7 +92,7 @@ def _fetch_channel_permissions( external_user_emails=member_emails, # No group<->document mapping for slack external_user_group_ids=set(), - # No way to determine if slack is invite only without enterprise liscense + # No way to determine if slack is invite only without enterprise license is_public=False, ) diff --git a/backend/ee/onyx/external_permissions/slack/group_sync.py b/backend/ee/onyx/external_permissions/slack/group_sync.py index 9ca060128a..6ec23749af 100644 --- a/backend/ee/onyx/external_permissions/slack/group_sync.py +++ b/backend/ee/onyx/external_permissions/slack/group_sync.py @@ -10,8 +10,8 @@ from slack_sdk import WebClient from ee.onyx.db.external_perm import ExternalUserGroup from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider -from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries from onyx.connectors.slack.connector import SlackConnector +from onyx.connectors.slack.utils import make_paginated_slack_api_call from onyx.db.models import ConnectorCredentialPair from onyx.redis.redis_pool import get_redis_client from onyx.utils.logger import setup_logger @@ -23,7 +23,7 @@ def _get_slack_group_ids( slack_client: WebClient, ) -> list[str]: group_ids = [] - for result in make_paginated_slack_api_call_w_retries(slack_client.usergroups_list): + for result in 
make_paginated_slack_api_call(slack_client.usergroups_list): for group in result.get("usergroups", []): group_ids.append(group.get("id")) return group_ids @@ -35,7 +35,7 @@ def _get_slack_group_members_email( user_id_to_email_map: dict[str, str], ) -> list[str]: group_member_emails = [] - for result in make_paginated_slack_api_call_w_retries( + for result in make_paginated_slack_api_call( slack_client.usergroups_users_list, usergroup=group_name ): for member_id in result.get("users", []): diff --git a/backend/ee/onyx/external_permissions/slack/utils.py b/backend/ee/onyx/external_permissions/slack/utils.py index 85ef284679..38aaa3785e 100644 --- a/backend/ee/onyx/external_permissions/slack/utils.py +++ b/backend/ee/onyx/external_permissions/slack/utils.py @@ -1,13 +1,13 @@ from slack_sdk import WebClient -from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries +from onyx.connectors.slack.utils import make_paginated_slack_api_call def fetch_user_id_to_email_map( slack_client: WebClient, ) -> dict[str, str]: user_id_to_email_map = {} - for user_info in make_paginated_slack_api_call_w_retries( + for user_info in make_paginated_slack_api_call( slack_client.users_list, ): for user in user_info.get("members", []): diff --git a/backend/onyx/auth/api_key.py b/backend/onyx/auth/api_key.py index e6c8c0c584..e5a62f7728 100644 --- a/backend/onyx/auth/api_key.py +++ b/backend/onyx/auth/api_key.py @@ -76,10 +76,11 @@ def hash_api_key(api_key: str) -> str: # and overlaps are impossible if api_key.startswith(_API_KEY_PREFIX): return hashlib.sha256(api_key.encode("utf-8")).hexdigest() - elif api_key.startswith(_DEPRECATED_API_KEY_PREFIX): + + if api_key.startswith(_DEPRECATED_API_KEY_PREFIX): return _deprecated_hash_api_key(api_key) - else: - raise ValueError(f"Invalid API key prefix: {api_key[:3]}") + + raise ValueError(f"Invalid API key prefix: {api_key[:3]}") def build_displayable_api_key(api_key: str) -> str: diff --git 
a/backend/onyx/connectors/slack/connector.py b/backend/onyx/connectors/slack/connector.py index ce84ad12c9..2d88f7134b 100644 --- a/backend/onyx/connectors/slack/connector.py +++ b/backend/onyx/connectors/slack/connector.py @@ -9,8 +9,11 @@ from concurrent.futures import Future from concurrent.futures import ThreadPoolExecutor from datetime import datetime from datetime import timezone +from http.client import IncompleteRead +from http.client import RemoteDisconnected from typing import Any from typing import cast +from urllib.error import URLError from pydantic import BaseModel from redis import Redis @@ -18,6 +21,9 @@ from slack_sdk import WebClient from slack_sdk.errors import SlackApiError from slack_sdk.http_retry import ConnectionErrorRetryHandler from slack_sdk.http_retry import RetryHandler +from slack_sdk.http_retry.builtin_interval_calculators import ( + FixedValueRetryIntervalCalculator, +) from typing_extensions import override from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS @@ -45,10 +51,10 @@ from onyx.connectors.models import EntityFailure from onyx.connectors.models import SlimDocument from onyx.connectors.models import TextSection from onyx.connectors.slack.onyx_retry_handler import OnyxRedisSlackRetryHandler +from onyx.connectors.slack.onyx_slack_web_client import OnyxSlackWebClient from onyx.connectors.slack.utils import expert_info_from_slack_id from onyx.connectors.slack.utils import get_message_link -from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries -from onyx.connectors.slack.utils import make_slack_api_call_w_retries +from onyx.connectors.slack.utils import make_paginated_slack_api_call from onyx.connectors.slack.utils import SlackTextCleaner from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.redis.redis_pool import get_redis_client @@ -78,7 +84,7 @@ def _collect_paginated_channels( channel_types: list[str], ) -> list[ChannelType]: channels: list[dict[str, Any]] 
= [] - for result in make_paginated_slack_api_call_w_retries( + for result in make_paginated_slack_api_call( client.conversations_list, exclude_archived=exclude_archived, # also get private channels the bot is added to @@ -135,14 +141,13 @@ def get_channel_messages( """Get all messages in a channel""" # join so that the bot can access messages if not channel["is_member"]: - make_slack_api_call_w_retries( - client.conversations_join, + client.conversations_join( channel=channel["id"], is_private=channel["is_private"], ) logger.info(f"Successfully joined '{channel['name']}'") - for result in make_paginated_slack_api_call_w_retries( + for result in make_paginated_slack_api_call( client.conversations_history, channel=channel["id"], oldest=oldest, @@ -159,7 +164,7 @@ def get_channel_messages( def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType: """Get all messages in a thread""" threads: list[MessageType] = [] - for result in make_paginated_slack_api_call_w_retries( + for result in make_paginated_slack_api_call( client.conversations_replies, channel=channel_id, ts=thread_id ): threads.extend(result["messages"]) @@ -317,8 +322,7 @@ def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType: Raises: SlackApiError: If the channel cannot be fetched """ - response = make_slack_api_call_w_retries( - client.conversations_info, + response = client.conversations_info( channel=channel_id, ) return cast(ChannelType, response["channel"]) @@ -335,8 +339,7 @@ def _get_messages( # have to be in the channel in order to read messages if not channel["is_member"]: try: - make_slack_api_call_w_retries( - client.conversations_join, + client.conversations_join( channel=channel["id"], is_private=channel["is_private"], ) @@ -349,8 +352,7 @@ def _get_messages( raise logger.info(f"Successfully joined '{channel['name']}'") - response = make_slack_api_call_w_retries( - client.conversations_history, + response = client.conversations_history( 
channel=channel["id"], oldest=oldest, latest=latest, @@ -527,6 +529,7 @@ class SlackConnector( channel_regex_enabled: bool = False, batch_size: int = INDEX_BATCH_SIZE, num_threads: int = SLACK_NUM_THREADS, + use_redis: bool = True, ) -> None: self.channels = channels self.channel_regex_enabled = channel_regex_enabled @@ -539,6 +542,7 @@ class SlackConnector( self.user_cache: dict[str, BasicExpertInfo | None] = {} self.credentials_provider: CredentialsProviderInterface | None = None self.credential_prefix: str | None = None + self.use_redis: bool = use_redis # self.delay_lock: str | None = None # the redis key for the shared lock # self.delay_key: str | None = None # the redis key for the shared delay @@ -563,10 +567,19 @@ class SlackConnector( # NOTE: slack has a built in RateLimitErrorRetryHandler, but it isn't designed # for concurrent workers. We've extended it with OnyxRedisSlackRetryHandler. - connection_error_retry_handler = ConnectionErrorRetryHandler() + connection_error_retry_handler = ConnectionErrorRetryHandler( + max_retry_count=max_retry_count, + interval_calculator=FixedValueRetryIntervalCalculator(), + error_types=[ + URLError, + ConnectionResetError, + RemoteDisconnected, + IncompleteRead, + ], + ) + onyx_rate_limit_error_retry_handler = OnyxRedisSlackRetryHandler( max_retry_count=max_retry_count, - delay_lock=delay_lock, delay_key=delay_key, r=r, ) @@ -575,7 +588,13 @@ class SlackConnector( onyx_rate_limit_error_retry_handler, ] - client = WebClient(token=token, retry_handlers=custom_retry_handlers) + client = OnyxSlackWebClient( + delay_lock=delay_lock, + delay_key=delay_key, + r=r, + token=token, + retry_handlers=custom_retry_handlers, + ) return client @property @@ -599,16 +618,32 @@ class SlackConnector( if not tenant_id: raise ValueError("tenant_id cannot be None!") - self.redis = get_redis_client(tenant_id=tenant_id) - - self.credential_prefix = SlackConnector.make_credential_prefix( - credentials_provider.get_provider_key() - ) - bot_token = 
credentials["slack_bot_token"] - self.client = SlackConnector.make_slack_web_client( - self.credential_prefix, bot_token, self.MAX_RETRIES, self.redis - ) + + if self.use_redis: + self.redis = get_redis_client(tenant_id=tenant_id) + self.credential_prefix = SlackConnector.make_credential_prefix( + credentials_provider.get_provider_key() + ) + + self.client = SlackConnector.make_slack_web_client( + self.credential_prefix, bot_token, self.MAX_RETRIES, self.redis + ) + else: + connection_error_retry_handler = ConnectionErrorRetryHandler( + max_retry_count=self.MAX_RETRIES, + interval_calculator=FixedValueRetryIntervalCalculator(), + error_types=[ + URLError, + ConnectionResetError, + RemoteDisconnected, + IncompleteRead, + ], + ) + + self.client = WebClient( + token=bot_token, retry_handlers=[connection_error_retry_handler] + ) # use for requests that must return quickly (e.g. realtime flows where user is waiting) self.fast_client = WebClient( @@ -651,6 +686,8 @@ class SlackConnector( Step 2.4: If there are no more messages in the channel, switch the current channel to the next channel. 
""" + num_channels_remaining = 0 + if self.client is None or self.text_cleaner is None: raise ConnectorMissingCredentialError("Slack") @@ -664,7 +701,9 @@ class SlackConnector( raw_channels, self.channels, self.channel_regex_enabled ) logger.info( - f"Channels: all={len(raw_channels)} post_filtering={len(filtered_channels)}" + f"Channels - initial checkpoint: " + f"all={len(raw_channels)} " + f"post_filtering={len(filtered_channels)}" ) checkpoint.channel_ids = [c["id"] for c in filtered_channels] @@ -677,6 +716,17 @@ class SlackConnector( return checkpoint final_channel_ids = checkpoint.channel_ids + for channel_id in final_channel_ids: + if channel_id not in checkpoint.channel_completion_map: + num_channels_remaining += 1 + + logger.info( + f"Channels - current status: " + f"processed={len(final_channel_ids) - num_channels_remaining} " + f"remaining={num_channels_remaining=} " + f"total={len(final_channel_ids)}" + ) + channel = checkpoint.current_channel if channel is None: raise ValueError("current_channel key not set in checkpoint") @@ -688,18 +738,32 @@ class SlackConnector( oldest = str(start) if start else None latest = checkpoint.channel_completion_map.get(channel_id, str(end)) seen_thread_ts = set(checkpoint.seen_thread_ts) + + logger.debug( + f"Getting messages for channel {channel} within range {oldest} - {latest}" + ) + try: - logger.debug( - f"Getting messages for channel {channel} within range {oldest} - {latest}" - ) message_batch, has_more_in_channel = _get_messages( channel, self.client, oldest, latest ) + + logger.info( + f"Retrieved messages: " + f"{len(message_batch)=} " + f"{channel=} " + f"{oldest=} " + f"{latest=}" + ) + new_latest = message_batch[-1]["ts"] if message_batch else latest num_threads_start = len(seen_thread_ts) # Process messages in parallel using ThreadPoolExecutor with ThreadPoolExecutor(max_workers=self.num_threads) as executor: + # NOTE(rkuo): this seems to be assuming the slack sdk is thread safe. 
+ # That's a very bold assumption! Likely not correct. + futures: list[Future[ProcessedSlackMessage]] = [] for message in message_batch: # Capture the current context so that the thread gets the current tenant ID @@ -736,7 +800,12 @@ class SlackConnector( yield failure num_threads_processed = len(seen_thread_ts) - num_threads_start - logger.info(f"Processed {num_threads_processed} threads.") + logger.info( + f"Message processing stats: " + f"batch_len={len(message_batch)} " + f"batch_yielded={num_threads_processed} " + f"total_threads_seen={len(seen_thread_ts)}" + ) checkpoint.seen_thread_ts = list(seen_thread_ts) checkpoint.channel_completion_map[channel["id"]] = new_latest @@ -751,6 +820,7 @@ class SlackConnector( ), None, ) + if new_channel_id: new_channel = _get_channel_by_id(self.client, new_channel_id) checkpoint.current_channel = new_channel @@ -758,8 +828,6 @@ class SlackConnector( checkpoint.current_channel = None checkpoint.has_more = checkpoint.current_channel is not None - return checkpoint - except Exception as e: logger.exception(f"Error processing channel {channel['name']}") yield ConnectorFailure( @@ -773,7 +841,8 @@ class SlackConnector( failure_message=str(e), exception=e, ) - return checkpoint + + return checkpoint def validate_connector_settings(self) -> None: """ diff --git a/backend/onyx/connectors/slack/onyx_retry_handler.py b/backend/onyx/connectors/slack/onyx_retry_handler.py index 7031db379f..9a84e2ae68 100644 --- a/backend/onyx/connectors/slack/onyx_retry_handler.py +++ b/backend/onyx/connectors/slack/onyx_retry_handler.py @@ -1,11 +1,8 @@ -import math import random -import time from typing import cast from typing import Optional from redis import Redis -from redis.lock import Lock as RedisLock from slack_sdk.http_retry.handler import RetryHandler from slack_sdk.http_retry.request import HttpRequest from slack_sdk.http_retry.response import HttpResponse @@ -20,28 +17,23 @@ class OnyxRedisSlackRetryHandler(RetryHandler): """ This class 
uses Redis to share a rate limit among multiple threads. - Threads that encounter a rate limit will observe the shared delay, increment the - shared delay with the retry value, and use the new shared value as a wait interval. + As currently implemented, this code is already surrounded by a lock in Redis + via an override of _perform_urllib_http_request in OnyxSlackWebClient. - This has the effect of serializing calls when a rate limit is hit, which is what - needs to happens if the server punishes us with additional limiting when we make - a call too early. We believe this is what Slack is doing based on empirical - observation, meaning we see indefinite hangs if we're too aggressive. + This just sets the desired retry delay with TTL in redis. In conjunction with + a custom subclass of the client, the value is read and obeyed prior to an API call + and also serialized. Another way to do this is just to do exponential backoff. Might be easier? Adapted from slack's RateLimitErrorRetryHandler. """ - LOCK_TTL = 60 # used to serialize access to the retry TTL - LOCK_BLOCKING_TIMEOUT = 60 # how long to wait for the lock - """RetryHandler that does retries for rate limited errors.""" def __init__( self, max_retry_count: int, - delay_lock: str, delay_key: str, r: Redis, ): @@ -51,7 +43,6 @@ class OnyxRedisSlackRetryHandler(RetryHandler): """ super().__init__(max_retry_count=max_retry_count) self._redis: Redis = r - self._delay_lock = delay_lock self._delay_key = delay_key def _can_retry( @@ -72,8 +63,18 @@ class OnyxRedisSlackRetryHandler(RetryHandler): response: Optional[HttpResponse] = None, error: Optional[Exception] = None, ) -> None: - """It seems this function is responsible for the wait to retry ... aka we - actually sleep in this function.""" + """As initially designed by the SDK authors, this function is responsible for + the wait to retry ... aka we actually sleep in this function. 
+ + This doesn't work well with multiple clients because every thread is unaware + of the current retry value until it actually calls the endpoint. + + We're combining this with an actual subclass of the slack web client so + that the delay is used BEFORE calling an API endpoint. The subclassed client + has already taken the lock in redis when this method is called. + """ + ttl_ms: int | None = None + retry_after_value: list[str] | None = None retry_after_header_name: Optional[str] = None duration_s: float = 1.0 # seconds @@ -112,48 +113,22 @@ class OnyxRedisSlackRetryHandler(RetryHandler): retry_after_value[0] ) # will raise ValueError if somehow we can't convert to int jitter = retry_after_value_int * 0.25 * random.random() - duration_s = math.ceil(retry_after_value_int + jitter) + duration_s = retry_after_value_int + jitter except ValueError: duration_s += random.random() - # lock and extend the ttl - lock: RedisLock = self._redis.lock( - self._delay_lock, - timeout=OnyxRedisSlackRetryHandler.LOCK_TTL, - thread_local=False, - ) - - acquired = lock.acquire( - blocking_timeout=OnyxRedisSlackRetryHandler.LOCK_BLOCKING_TIMEOUT / 2 - ) - - ttl_ms: int | None = None - - try: - if acquired: - # if we can get the lock, then read and extend the ttl - ttl_ms = cast(int, self._redis.pttl(self._delay_key)) - if ttl_ms < 0: # negative values are error status codes ... see docs - ttl_ms = 0 - ttl_ms_new = ttl_ms + int(duration_s * 1000.0) - self._redis.set(self._delay_key, "1", px=ttl_ms_new) - else: - # if we can't get the lock, just go ahead. - # TODO: if we know our actual parallelism, multiplying by that - # would be a pretty good idea - ttl_ms_new = int(duration_s * 1000.0) - finally: - if acquired: - lock.release() + # Read and extend the ttl + ttl_ms = cast(int, self._redis.pttl(self._delay_key)) + if ttl_ms < 0: # negative values are error status codes ... 
see docs + ttl_ms = 0 + ttl_ms_new = ttl_ms + int(duration_s * 1000.0) + self._redis.set(self._delay_key, "1", px=ttl_ms_new) logger.warning( - f"OnyxRedisSlackRetryHandler.prepare_for_next_attempt wait: " + f"OnyxRedisSlackRetryHandler.prepare_for_next_attempt setting delay: " + f"current_attempt={state.current_attempt} " f"retry-after={retry_after_value} " - f"shared_delay_ms={ttl_ms} new_shared_delay_ms={ttl_ms_new}" + f"{ttl_ms_new=}" ) - # TODO: would be good to take an event var and sleep in short increments to - # allow for a clean exit / exception - time.sleep(ttl_ms_new / 1000.0) - state.increment_current_attempt() diff --git a/backend/onyx/connectors/slack/onyx_slack_web_client.py b/backend/onyx/connectors/slack/onyx_slack_web_client.py new file mode 100644 index 0000000000..ec3b8bbd4a --- /dev/null +++ b/backend/onyx/connectors/slack/onyx_slack_web_client.py @@ -0,0 +1,116 @@ +import threading +import time +from typing import Any +from typing import cast +from typing import Dict +from urllib.request import Request + +from redis import Redis +from redis.lock import Lock as RedisLock +from slack_sdk import WebClient + +from onyx.connectors.slack.utils import ONYX_SLACK_LOCK_BLOCKING_TIMEOUT +from onyx.connectors.slack.utils import ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT +from onyx.connectors.slack.utils import ONYX_SLACK_LOCK_TTL +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +class OnyxSlackWebClient(WebClient): + """Use in combination with the Onyx Retry Handler. + + This client wrapper enforces a proper retry delay through redis BEFORE the api call + so that multiple clients can synchronize and rate limit properly. + + The retry handler writes the correct delay value to redis so that it can be used + by this wrapper.
+ + """ + + def __init__( + self, delay_lock: str, delay_key: str, r: Redis, *args: Any, **kwargs: Any + ) -> None: + super().__init__(*args, **kwargs) + self._delay_key = delay_key + self._delay_lock = delay_lock + self._redis: Redis = r + self.num_requests: int = 0 + self._lock = threading.Lock() + + def _perform_urllib_http_request( + self, *, url: str, args: Dict[str, Dict[str, Any]] + ) -> Dict[str, Any]: + """By locking around the base class method, we ensure that both the delay from + Redis and parsing/writing of retry values to Redis are handled properly in + one place""" + # lock and extend the ttl + lock: RedisLock = self._redis.lock( + self._delay_lock, + timeout=ONYX_SLACK_LOCK_TTL, + ) + + # try to acquire the lock + start = time.monotonic() + while True: + acquired = lock.acquire(blocking_timeout=ONYX_SLACK_LOCK_BLOCKING_TIMEOUT) + if acquired: + break + + # if we couldn't acquire the lock but it exists, there's at least some activity + # so keep trying... + if self._redis.exists(self._delay_lock): + continue + + if time.monotonic() - start > ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT: + raise RuntimeError( + f"OnyxSlackWebClient._perform_urllib_http_request - " + f"timed out waiting for lock: {ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT=}" + ) + + try: + result = super()._perform_urllib_http_request(url=url, args=args) + finally: + if lock.owned(): + lock.release() + else: + logger.warning( + "OnyxSlackWebClient._perform_urllib_http_request lock not owned on release" + ) + + time.monotonic() - start + # logger.info( + # f"OnyxSlackWebClient._perform_urllib_http_request: Releasing lock: {elapsed=}" + # ) + + return result + + def _perform_urllib_http_request_internal( + self, + url: str, + req: Request, + ) -> Dict[str, Any]: + """Overrides the internal method which is mostly the direct call to + urllib/urlopen ... 
so this is a good place to perform our delay.""" + + # read and execute the delay + delay_ms = cast(int, self._redis.pttl(self._delay_key)) + if delay_ms < 0: # negative values are error status codes ... see docs + delay_ms = 0 + + if delay_ms > 0: + logger.warning( + f"OnyxSlackWebClient._perform_urllib_http_request_internal delay: " + f"{delay_ms=} " + f"{self.num_requests=}" + ) + + time.sleep(delay_ms / 1000.0) + + result = super()._perform_urllib_http_request_internal(url, req) + + with self._lock: + self.num_requests += 1 + + # the delay key should have naturally expired by this point + return result diff --git a/backend/onyx/connectors/slack/utils.py b/backend/onyx/connectors/slack/utils.py index 2a9145da84..f266d0fa86 100644 --- a/backend/onyx/connectors/slack/utils.py +++ b/backend/onyx/connectors/slack/utils.py @@ -21,6 +21,11 @@ basic_retry_wrapper = retry_builder(tries=7) # number of messages we request per page when fetching paginated slack messages _SLACK_LIMIT = 900 +# used to serialize access to the retry TTL +ONYX_SLACK_LOCK_TTL = 1800 # how long the lock is allowed to idle before it expires +ONYX_SLACK_LOCK_BLOCKING_TIMEOUT = 60 # how long to wait for the lock per wait attempt +ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT = 3600 # how long to wait for the lock in total + @lru_cache() def get_base_url(token: str) -> str: @@ -44,6 +49,18 @@ def get_message_link( return link +def make_slack_api_call( + call: Callable[..., SlackResponse], **kwargs: Any +) -> SlackResponse: + return call(**kwargs) + + +def make_paginated_slack_api_call( + call: Callable[..., SlackResponse], **kwargs: Any +) -> Generator[dict[str, Any], None, None]: + return _make_slack_api_call_paginated(call)(**kwargs) + + def _make_slack_api_call_paginated( call: Callable[..., SlackResponse], ) -> Callable[..., Generator[dict[str, Any], None, None]]: @@ -119,17 +136,18 @@ def _make_slack_api_call_paginated( # return rate_limited_call - -def make_slack_api_call_w_retries( - call: 
Callable[..., SlackResponse], **kwargs: Any -) -> SlackResponse: - return basic_retry_wrapper(call)(**kwargs) +# temporarily disabling due to using a different retry approach +# might be permanent if everything works out +# def make_slack_api_call_w_retries( +# call: Callable[..., SlackResponse], **kwargs: Any +# ) -> SlackResponse: +# return basic_retry_wrapper(call)(**kwargs) -def make_paginated_slack_api_call_w_retries( - call: Callable[..., SlackResponse], **kwargs: Any -) -> Generator[dict[str, Any], None, None]: - return _make_slack_api_call_paginated(basic_retry_wrapper(call))(**kwargs) +# def make_paginated_slack_api_call_w_retries( +# call: Callable[..., SlackResponse], **kwargs: Any +# ) -> Generator[dict[str, Any], None, None]: +# return _make_slack_api_call_paginated(basic_retry_wrapper(call))(**kwargs) def expert_info_from_slack_id( diff --git a/backend/pytest.ini b/backend/pytest.ini index 2b1d38e3ed..f02df059c3 100644 --- a/backend/pytest.ini +++ b/backend/pytest.ini @@ -8,3 +8,8 @@ filterwarnings = ignore::DeprecationWarning ignore::cryptography.utils.CryptographyDeprecationWarning ignore::PendingDeprecationWarning:ddtrace.internal.module +# .test.env is gitignored. +# After installing pytest-dotenv, +# you can use it to test credentials locally. 
+env_files = + .test.env diff --git a/backend/requirements/dev.txt b/backend/requirements/dev.txt index 3fe44dd904..c738b029a2 100644 --- a/backend/requirements/dev.txt +++ b/backend/requirements/dev.txt @@ -12,6 +12,7 @@ pandas==2.2.3 posthog==3.7.4 pre-commit==3.2.2 pytest-asyncio==0.22.0 +pytest-dotenv==0.5.2 pytest-xdist==3.6.1 pytest==8.3.5 reorder-python-imports-black==3.14.0 diff --git a/backend/tests/daily/connectors/slack/test_slack_connector.py b/backend/tests/daily/connectors/slack/test_slack_connector.py index cd4c858ab7..d8ce2099dc 100644 --- a/backend/tests/daily/connectors/slack/test_slack_connector.py +++ b/backend/tests/daily/connectors/slack/test_slack_connector.py @@ -31,6 +31,7 @@ def slack_connector( connector = SlackConnector( channels=[channel] if channel else None, channel_regex_enabled=False, + use_redis=False, ) connector.client = mock_slack_client connector.set_credentials_provider(credentials_provider=slack_credentials_provider) diff --git a/backend/tests/integration/connector_job_tests/slack/slack_api_utils.py b/backend/tests/integration/connector_job_tests/slack/slack_api_utils.py index 6da99ba28a..7e4f8cfd0b 100644 --- a/backend/tests/integration/connector_job_tests/slack/slack_api_utils.py +++ b/backend/tests/integration/connector_job_tests/slack/slack_api_utils.py @@ -17,8 +17,7 @@ from slack_sdk.errors import SlackApiError from onyx.connectors.slack.connector import default_msg_filter from onyx.connectors.slack.connector import get_channel_messages -from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries -from onyx.connectors.slack.utils import make_slack_api_call_w_retries +from onyx.connectors.slack.utils import make_paginated_slack_api_call def _get_slack_channel_id(channel: dict[str, Any]) -> str: @@ -40,7 +39,7 @@ def _get_non_general_channels( channel_types.append("public_channel") conversations: list[dict[str, Any]] = [] - for result in make_paginated_slack_api_call_w_retries( + for result in 
make_paginated_slack_api_call( slack_client.conversations_list, exclude_archived=False, types=channel_types, @@ -64,7 +63,7 @@ def _clear_slack_conversation_members( ) -> None: channel_id = _get_slack_channel_id(channel) member_ids: list[str] = [] - for result in make_paginated_slack_api_call_w_retries( + for result in make_paginated_slack_api_call( slack_client.conversations_members, channel=channel_id, ): @@ -140,15 +139,13 @@ def _build_slack_channel_from_name( if channel: # If channel is provided, we rename it channel_id = _get_slack_channel_id(channel) - channel_response = make_slack_api_call_w_retries( - slack_client.conversations_rename, + channel_response = slack_client.conversations_rename( channel=channel_id, name=channel_name, ) else: # Otherwise, we create a new channel - channel_response = make_slack_api_call_w_retries( - slack_client.conversations_create, + channel_response = slack_client.conversations_create( name=channel_name, is_private=is_private, ) @@ -219,10 +216,13 @@ class SlackManager: @staticmethod def build_slack_user_email_id_map(slack_client: WebClient) -> dict[str, str]: - users_results = make_slack_api_call_w_retries( + users: list[dict[str, Any]] = [] + + for users_results in make_paginated_slack_api_call( slack_client.users_list, - ) - users: list[dict[str, Any]] = users_results.get("members", []) + ): + users.extend(users_results.get("members", [])) + user_email_id_map = {} for user in users: if not (email := user.get("profile", {}).get("email")): @@ -253,8 +253,7 @@ class SlackManager: slack_client: WebClient, channel: dict[str, Any], message: str ) -> None: channel_id = _get_slack_channel_id(channel) - make_slack_api_call_w_retries( - slack_client.chat_postMessage, + slack_client.chat_postMessage( channel=channel_id, text=message, ) @@ -274,7 +273,7 @@ class SlackManager: ) -> None: channel_types = ["private_channel", "public_channel"] channels: list[dict[str, Any]] = [] - for result in make_paginated_slack_api_call_w_retries( + 
for result in make_paginated_slack_api_call( slack_client.conversations_list, exclude_archived=False, types=channel_types,