From ccd372cc4a8dc2ea628fe6c76ae2658f0c351eff Mon Sep 17 00:00:00 2001 From: rkuo-danswer Date: Mon, 31 Mar 2025 14:00:26 -0700 Subject: [PATCH] Bugfix/slack rate limiting (#4386) * use slack's built in rate limit handler for the bot * WIP * fix the slack rate limit handler * change default to 8 * cleanup * try catch int conversion just in case * linearize this logic better * code review comments --------- Co-authored-by: Richard Kuo (Onyx) --- .../background/celery/tasks/vespa/tasks.py | 3 +- backend/onyx/configs/app_configs.py | 2 +- backend/onyx/connectors/slack/connector.py | 50 +++++- .../connectors/slack/onyx_retry_handler.py | 159 ++++++++++++++++++ backend/onyx/connectors/slack/utils.py | 104 ++++++------ backend/onyx/redis/redis_pool.py | 1 + 6 files changed, 260 insertions(+), 59 deletions(-) create mode 100644 backend/onyx/connectors/slack/onyx_retry_handler.py diff --git a/backend/onyx/background/celery/tasks/vespa/tasks.py b/backend/onyx/background/celery/tasks/vespa/tasks.py index f319d271eb2a..72c0f6481683 100644 --- a/backend/onyx/background/celery/tasks/vespa/tasks.py +++ b/backend/onyx/background/celery/tasks/vespa/tasks.py @@ -80,7 +80,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None: """Runs periodically to check if any document needs syncing. Generates sets of tasks for Celery if syncing is needed.""" - # Useful for debugging timing issues with reacquisitions. TODO: remove once more generalized logging is in place + # Useful for debugging timing issues with reacquisitions. + # TODO: remove once more generalized logging is in place task_logger.info("check_for_vespa_sync_task started") time_start = time.monotonic() diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 2a73fe60fec0..bb975eef5bbb 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -437,7 +437,7 @@ LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID") LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET") # Slack specific configs -SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 2) +SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 8) DASK_JOB_CLIENT_ENABLED = ( os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true" diff --git a/backend/onyx/connectors/slack/connector.py b/backend/onyx/connectors/slack/connector.py index b38f216b377a..b72f5b7ac6f8 100644 --- a/backend/onyx/connectors/slack/connector.py +++ b/backend/onyx/connectors/slack/connector.py @@ -14,6 +14,8 @@ from typing import cast from pydantic import BaseModel from slack_sdk import WebClient from slack_sdk.errors import SlackApiError +from slack_sdk.http_retry import ConnectionErrorRetryHandler +from slack_sdk.http_retry import RetryHandler from typing_extensions import override from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS @@ -26,6 +28,8 @@ from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.exceptions import UnexpectedValidationError from onyx.connectors.interfaces import CheckpointConnector from onyx.connectors.interfaces import CheckpointOutput +from onyx.connectors.interfaces import CredentialsConnector +from onyx.connectors.interfaces import CredentialsProviderInterface from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnector @@ -38,15 +42,16 @@ from onyx.connectors.models import DocumentFailure from onyx.connectors.models import EntityFailure from onyx.connectors.models import SlimDocument from onyx.connectors.models import TextSection +from onyx.connectors.slack.onyx_retry_handler import OnyxRedisSlackRetryHandler from onyx.connectors.slack.utils import expert_info_from_slack_id from onyx.connectors.slack.utils import get_message_link from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries from onyx.connectors.slack.utils import make_slack_api_call_w_retries from onyx.connectors.slack.utils import SlackTextCleaner from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface +from onyx.redis.redis_pool import get_redis_client from onyx.utils.logger import setup_logger - logger = setup_logger() _SLACK_LIMIT = 900 @@ -493,9 +498,13 @@ def _process_message( ) -class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]): +class SlackConnector( + SlimConnector, CredentialsConnector, CheckpointConnector[SlackCheckpoint] +): FAST_TIMEOUT = 1 + MAX_RETRIES = 7 # arbitrarily selected + def __init__( self, channels: list[str] | None = None, @@ -514,16 +523,49 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]): # just used for efficiency self.text_cleaner: SlackTextCleaner | None = None self.user_cache: dict[str, BasicExpertInfo | None] = {} + self.credentials_provider: CredentialsProviderInterface | None = None + self.credential_prefix: str | None = None + self.delay_lock: str | None = None # the redis key for the shared lock + self.delay_key: str | None = None # the redis key for the shared delay def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + raise NotImplementedError("Use set_credentials_provider with this connector.") + + def set_credentials_provider( + self, credentials_provider: CredentialsProviderInterface + ) -> None: + credentials = credentials_provider.get_credentials() + tenant_id = credentials_provider.get_tenant_id() + self.redis = get_redis_client(tenant_id=tenant_id) + + self.credential_prefix = ( + f"connector:slack:credential_{credentials_provider.get_provider_key()}" + ) + self.delay_lock = f"{self.credential_prefix}:delay_lock" + self.delay_key = f"{self.credential_prefix}:delay" + + # NOTE: slack has a built in RateLimitErrorRetryHandler, but it isn't designed + # for concurrent workers. We've extended it with OnyxRedisSlackRetryHandler. + connection_error_retry_handler = ConnectionErrorRetryHandler() + onyx_rate_limit_error_retry_handler = OnyxRedisSlackRetryHandler( + max_retry_count=self.MAX_RETRIES, + delay_lock=self.delay_lock, + delay_key=self.delay_key, + r=self.redis, + ) + custom_retry_handlers: list[RetryHandler] = [ + connection_error_retry_handler, + onyx_rate_limit_error_retry_handler, + ] + bot_token = credentials["slack_bot_token"] - self.client = WebClient(token=bot_token) + self.client = WebClient(token=bot_token, retry_handlers=custom_retry_handlers) # use for requests that must return quickly (e.g. realtime flows where user is waiting) self.fast_client = WebClient( token=bot_token, timeout=SlackConnector.FAST_TIMEOUT ) self.text_cleaner = SlackTextCleaner(client=self.client) - return None + self.credentials_provider = credentials_provider def retrieve_all_slim_documents( self, diff --git a/backend/onyx/connectors/slack/onyx_retry_handler.py b/backend/onyx/connectors/slack/onyx_retry_handler.py new file mode 100644 index 000000000000..7031db379f65 --- /dev/null +++ b/backend/onyx/connectors/slack/onyx_retry_handler.py @@ -0,0 +1,159 @@ +import math +import random +import time +from typing import cast +from typing import Optional + +from redis import Redis +from redis.lock import Lock as RedisLock +from slack_sdk.http_retry.handler import RetryHandler +from slack_sdk.http_retry.request import HttpRequest +from slack_sdk.http_retry.response import HttpResponse +from slack_sdk.http_retry.state import RetryState + +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +class OnyxRedisSlackRetryHandler(RetryHandler): + """ + This class uses Redis to share a rate limit among multiple threads. + + Threads that encounter a rate limit will observe the shared delay, increment the + shared delay with the retry value, and use the new shared value as a wait interval. + + This has the effect of serializing calls when a rate limit is hit, which is what + needs to happens if the server punishes us with additional limiting when we make + a call too early. We believe this is what Slack is doing based on empirical + observation, meaning we see indefinite hangs if we're too aggressive. + + Another way to do this is just to do exponential backoff. Might be easier? + + Adapted from slack's RateLimitErrorRetryHandler. + """ + + LOCK_TTL = 60 # used to serialize access to the retry TTL + LOCK_BLOCKING_TIMEOUT = 60 # how long to wait for the lock + + """RetryHandler that does retries for rate limited errors.""" + + def __init__( + self, + max_retry_count: int, + delay_lock: str, + delay_key: str, + r: Redis, + ): + """ + delay_lock: the redis key to use with RedisLock (to synchronize access to delay_key) + delay_key: the redis key containing a shared TTL + """ + super().__init__(max_retry_count=max_retry_count) + self._redis: Redis = r + self._delay_lock = delay_lock + self._delay_key = delay_key + + def _can_retry( + self, + *, + state: RetryState, + request: HttpRequest, + response: Optional[HttpResponse] = None, + error: Optional[Exception] = None, + ) -> bool: + return response is not None and response.status_code == 429 + + def prepare_for_next_attempt( + self, + *, + state: RetryState, + request: HttpRequest, + response: Optional[HttpResponse] = None, + error: Optional[Exception] = None, + ) -> None: + """It seems this function is responsible for the wait to retry ... aka we + actually sleep in this function.""" + retry_after_value: list[str] | None = None + retry_after_header_name: Optional[str] = None + duration_s: float = 1.0 # seconds + + if response is None: + # NOTE(rkuo): this logic comes from RateLimitErrorRetryHandler. + # This reads oddly, as if the caller itself could raise the exception. + # We don't have the luxury of changing this. + if error: + raise error + + return + + state.next_attempt_requested = True # this signals the caller to retry + + # calculate wait duration based on retry-after + some jitter + for k in response.headers.keys(): + if k.lower() == "retry-after": + retry_after_header_name = k + break + + try: + if retry_after_header_name is None: + # This situation usually does not arise. Just in case. + raise ValueError( + "OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header name is None" + ) + + retry_after_value = response.headers.get(retry_after_header_name) + if not retry_after_value: + raise ValueError( + "OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header value is None" + ) + + retry_after_value_int = int( + retry_after_value[0] + ) # will raise ValueError if somehow we can't convert to int + jitter = retry_after_value_int * 0.25 * random.random() + duration_s = math.ceil(retry_after_value_int + jitter) + except ValueError: + duration_s += random.random() + + # lock and extend the ttl + lock: RedisLock = self._redis.lock( + self._delay_lock, + timeout=OnyxRedisSlackRetryHandler.LOCK_TTL, + thread_local=False, + ) + + acquired = lock.acquire( + blocking_timeout=OnyxRedisSlackRetryHandler.LOCK_BLOCKING_TIMEOUT / 2 + ) + + ttl_ms: int | None = None + + try: + if acquired: + # if we can get the lock, then read and extend the ttl + ttl_ms = cast(int, self._redis.pttl(self._delay_key)) + if ttl_ms < 0: # negative values are error status codes ... see docs + ttl_ms = 0 + ttl_ms_new = ttl_ms + int(duration_s * 1000.0) + self._redis.set(self._delay_key, "1", px=ttl_ms_new) + else: + # if we can't get the lock, just go ahead. + # TODO: if we know our actual parallelism, multiplying by that + # would be a pretty good idea + ttl_ms_new = int(duration_s * 1000.0) + finally: + if acquired: + lock.release() + + logger.warning( + f"OnyxRedisSlackRetryHandler.prepare_for_next_attempt wait: " + f"retry-after={retry_after_value} " + f"shared_delay_ms={ttl_ms} new_shared_delay_ms={ttl_ms_new}" + ) + + # TODO: would be good to take an event var and sleep in short increments to + # allow for a clean exit / exception + time.sleep(ttl_ms_new / 1000.0) + + state.increment_current_attempt() diff --git a/backend/onyx/connectors/slack/utils.py b/backend/onyx/connectors/slack/utils.py index 757036a9fa21..d60498279c74 100644 --- a/backend/onyx/connectors/slack/utils.py +++ b/backend/onyx/connectors/slack/utils.py @@ -1,5 +1,4 @@ import re -import time from collections.abc import Callable from collections.abc import Generator from functools import lru_cache @@ -64,71 +63,72 @@ def _make_slack_api_call_paginated( return paginated_call -def make_slack_api_rate_limited( - call: Callable[..., SlackResponse], max_retries: int = 7 -) -> Callable[..., SlackResponse]: - """Wraps calls to slack API so that they automatically handle rate limiting""" +# NOTE(rkuo): we may not need this any more if the integrated retry handlers work as +# expected. Do we want to keep this around? - @wraps(call) - def rate_limited_call(**kwargs: Any) -> SlackResponse: - last_exception = None +# def make_slack_api_rate_limited( +# call: Callable[..., SlackResponse], max_retries: int = 7 +# ) -> Callable[..., SlackResponse]: +# """Wraps calls to slack API so that they automatically handle rate limiting""" - for _ in range(max_retries): - try: - # Make the API call - response = call(**kwargs) +# @wraps(call) +# def rate_limited_call(**kwargs: Any) -> SlackResponse: +# last_exception = None - # Check for errors in the response, will raise `SlackApiError` - # if anything went wrong - response.validate() - return response +# for _ in range(max_retries): +# try: +# # Make the API call +# response = call(**kwargs) - except SlackApiError as e: - last_exception = e - try: - error = e.response["error"] - except KeyError: - error = "unknown error" +# # Check for errors in the response, will raise `SlackApiError` +# # if anything went wrong +# response.validate() +# return response - if error == "ratelimited": - # Handle rate limiting: get the 'Retry-After' header value and sleep for that duration - retry_after = int(e.response.headers.get("Retry-After", 1)) - logger.info( - f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}" - ) - time.sleep(retry_after) - elif error in ["already_reacted", "no_reaction", "internal_error"]: - # Log internal_error and return the response instead of failing - logger.warning( - f"Slack call encountered '{error}', skipping and continuing..." - ) - return e.response - else: - # Raise the error for non-transient errors - raise +# except SlackApiError as e: +# last_exception = e +# try: +# error = e.response["error"] +# except KeyError: +# error = "unknown error" - # If the code reaches this point, all retries have been exhausted - msg = f"Max retries ({max_retries}) exceeded" - if last_exception: - raise Exception(msg) from last_exception - else: - raise Exception(msg) +# if error == "ratelimited": +# # Handle rate limiting: get the 'Retry-After' header value and sleep for that duration +# retry_after = int(e.response.headers.get("Retry-After", 1)) +# logger.info( +# f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}" +# ) +# time.sleep(retry_after) +# elif error in ["already_reacted", "no_reaction", "internal_error"]: +# # Log internal_error and return the response instead of failing +# logger.warning( +# f"Slack call encountered '{error}', skipping and continuing..." +# ) +# return e.response +# else: +# # Raise the error for non-transient errors +# raise - return rate_limited_call +# # If the code reaches this point, all retries have been exhausted +# msg = f"Max retries ({max_retries}) exceeded" +# if last_exception: +# raise Exception(msg) from last_exception +# else: +# raise Exception(msg) + +# return rate_limited_call def make_slack_api_call_w_retries( call: Callable[..., SlackResponse], **kwargs: Any ) -> SlackResponse: - return basic_retry_wrapper(make_slack_api_rate_limited(call))(**kwargs) + return basic_retry_wrapper(call)(**kwargs) def make_paginated_slack_api_call_w_retries( call: Callable[..., SlackResponse], **kwargs: Any ) -> Generator[dict[str, Any], None, None]: - return _make_slack_api_call_paginated( - basic_retry_wrapper(make_slack_api_rate_limited(call)) - )(**kwargs) + return _make_slack_api_call_paginated(basic_retry_wrapper(call))(**kwargs) def expert_info_from_slack_id( @@ -142,7 +142,7 @@ def expert_info_from_slack_id( if user_id in user_cache: return user_cache[user_id] - response = make_slack_api_rate_limited(client.users_info)(user=user_id) + response = client.users_info(user=user_id) if not response["ok"]: user_cache[user_id] = None @@ -175,9 +175,7 @@ class SlackTextCleaner: def _get_slack_name(self, user_id: str) -> str: if user_id not in self._id_to_name_map: try: - response = make_slack_api_rate_limited(self._client.users_info)( - user=user_id - ) + response = self._client.users_info(user=user_id) # prefer display name if set, since that is what is shown in Slack self._id_to_name_map[user_id] = ( response["user"]["profile"]["display_name"] diff --git a/backend/onyx/redis/redis_pool.py b/backend/onyx/redis/redis_pool.py index e56f13d1ef04..584def16a5d9 100644 --- a/backend/onyx/redis/redis_pool.py +++ b/backend/onyx/redis/redis_pool.py @@ -125,6 +125,7 @@ class TenantRedis(redis.Redis): "hset", "hdel", "ttl", + "pttl", ] # Regular methods that need simple prefixing if item == "scan_iter" or item == "sscan_iter":