Bugfix/slack rate limiting (#4386)

* use slack's built in rate limit handler for the bot

* WIP

* fix the slack rate limit handler

* change default to 8

* cleanup

* try catch int conversion just in case

* linearize this logic better

* code review comments

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
This commit is contained in:
rkuo-danswer
2025-03-31 14:00:26 -07:00
committed by GitHub
parent ea30f1de1e
commit ccd372cc4a
6 changed files with 260 additions and 59 deletions

View File

@@ -80,7 +80,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
# Useful for debugging timing issues with reacquisitions. TODO: remove once more generalized logging is in place
# Useful for debugging timing issues with reacquisitions.
# TODO: remove once more generalized logging is in place
task_logger.info("check_for_vespa_sync_task started")
time_start = time.monotonic()

View File

@@ -437,7 +437,7 @@ LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
# Slack specific configs
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 2)
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 8)
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"

View File

@@ -14,6 +14,8 @@ from typing import cast
from pydantic import BaseModel
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.http_retry import ConnectionErrorRetryHandler
from slack_sdk.http_retry import RetryHandler
from typing_extensions import override
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
@@ -26,6 +28,8 @@ from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
@@ -38,15 +42,16 @@ from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.slack.onyx_retry_handler import OnyxRedisSlackRetryHandler
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import get_message_link
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.utils import make_slack_api_call_w_retries
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
_SLACK_LIMIT = 900
@@ -493,9 +498,13 @@ def _process_message(
)
class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
class SlackConnector(
SlimConnector, CredentialsConnector, CheckpointConnector[SlackCheckpoint]
):
FAST_TIMEOUT = 1
MAX_RETRIES = 7 # arbitrarily selected
def __init__(
self,
channels: list[str] | None = None,
@@ -514,16 +523,49 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
# just used for efficiency
self.text_cleaner: SlackTextCleaner | None = None
self.user_cache: dict[str, BasicExpertInfo | None] = {}
self.credentials_provider: CredentialsProviderInterface | None = None
self.credential_prefix: str | None = None
self.delay_lock: str | None = None # the redis key for the shared lock
self.delay_key: str | None = None # the redis key for the shared delay
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
raise NotImplementedError("Use set_credentials_provider with this connector.")
def set_credentials_provider(
self, credentials_provider: CredentialsProviderInterface
) -> None:
credentials = credentials_provider.get_credentials()
tenant_id = credentials_provider.get_tenant_id()
self.redis = get_redis_client(tenant_id=tenant_id)
self.credential_prefix = (
f"connector:slack:credential_{credentials_provider.get_provider_key()}"
)
self.delay_lock = f"{self.credential_prefix}:delay_lock"
self.delay_key = f"{self.credential_prefix}:delay"
# NOTE: slack has a built in RateLimitErrorRetryHandler, but it isn't designed
# for concurrent workers. We've extended it with OnyxRedisSlackRetryHandler.
connection_error_retry_handler = ConnectionErrorRetryHandler()
onyx_rate_limit_error_retry_handler = OnyxRedisSlackRetryHandler(
max_retry_count=self.MAX_RETRIES,
delay_lock=self.delay_lock,
delay_key=self.delay_key,
r=self.redis,
)
custom_retry_handlers: list[RetryHandler] = [
connection_error_retry_handler,
onyx_rate_limit_error_retry_handler,
]
bot_token = credentials["slack_bot_token"]
self.client = WebClient(token=bot_token)
self.client = WebClient(token=bot_token, retry_handlers=custom_retry_handlers)
# use for requests that must return quickly (e.g. realtime flows where user is waiting)
self.fast_client = WebClient(
token=bot_token, timeout=SlackConnector.FAST_TIMEOUT
)
self.text_cleaner = SlackTextCleaner(client=self.client)
return None
self.credentials_provider = credentials_provider
def retrieve_all_slim_documents(
self,

View File

@@ -0,0 +1,159 @@
import math
import random
import time
from typing import cast
from typing import Optional
from redis import Redis
from redis.lock import Lock as RedisLock
from slack_sdk.http_retry.handler import RetryHandler
from slack_sdk.http_retry.request import HttpRequest
from slack_sdk.http_retry.response import HttpResponse
from slack_sdk.http_retry.state import RetryState
from onyx.utils.logger import setup_logger
logger = setup_logger()
class OnyxRedisSlackRetryHandler(RetryHandler):
    """RetryHandler that does retries for rate limited (HTTP 429) errors.

    This class uses Redis to share a rate limit among multiple threads.

    Threads that encounter a rate limit will observe the shared delay, increment the
    shared delay with the retry value, and use the new shared value as a wait interval.

    This has the effect of serializing calls when a rate limit is hit, which is what
    needs to happen if the server punishes us with additional limiting when we make
    a call too early. We believe this is what Slack is doing based on empirical
    observation, meaning we see indefinite hangs if we're too aggressive.

    Another way to do this is just to do exponential backoff. Might be easier?

    Adapted from slack's RateLimitErrorRetryHandler.
    """

    LOCK_TTL = 60  # used to serialize access to the retry TTL
    LOCK_BLOCKING_TIMEOUT = 60  # how long to wait for the lock

    def __init__(
        self,
        max_retry_count: int,
        delay_lock: str,
        delay_key: str,
        r: Redis,
    ):
        """
        max_retry_count: forwarded unchanged to the base RetryHandler
        delay_lock: the redis key to use with RedisLock (to synchronize access to delay_key)
        delay_key: the redis key containing a shared TTL
        r: the redis client used for both the lock and the delay key
        """
        super().__init__(max_retry_count=max_retry_count)
        self._redis: Redis = r
        self._delay_lock = delay_lock
        self._delay_key = delay_key

    def _can_retry(
        self,
        *,
        state: RetryState,
        request: HttpRequest,
        response: Optional[HttpResponse] = None,
        error: Optional[Exception] = None,
    ) -> bool:
        """Return True only when a response was received and its status is 429."""
        return response is not None and response.status_code == 429

    def prepare_for_next_attempt(
        self,
        *,
        state: RetryState,
        request: HttpRequest,
        response: Optional[HttpResponse] = None,
        error: Optional[Exception] = None,
    ) -> None:
        """It seems this function is responsible for the wait to retry ... aka we
        actually sleep in this function.

        The wait is derived from the response's retry-after header plus up to 25%
        jitter. If the header is missing, empty, not an integer, or not positive,
        a fallback of one second plus up to one second of jitter is used instead.

        When the shared lock can be acquired, the wait is added on top of the
        remaining TTL of the shared delay key and the key is rewritten with the
        combined TTL. Otherwise only the locally computed wait is used.

        Raises:
            the passed-in ``error`` when there is no response and an error was given.
        """
        retry_after_value: list[str] | None = None
        retry_after_header_name: Optional[str] = None
        duration_s: float = 1.0  # seconds

        if response is None:
            # NOTE(rkuo): this logic comes from RateLimitErrorRetryHandler.
            # This reads oddly, as if the caller itself could raise the exception.
            # We don't have the luxury of changing this.
            if error:
                raise error

            return

        state.next_attempt_requested = True  # this signals the caller to retry

        # calculate wait duration based on retry-after + some jitter
        # (header names are matched case-insensitively)
        for k in response.headers.keys():
            if k.lower() == "retry-after":
                retry_after_header_name = k
                break

        try:
            if retry_after_header_name is None:
                # This situation usually does not arise. Just in case.
                raise ValueError(
                    "OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header name is None"
                )

            retry_after_value = response.headers.get(retry_after_header_name)
            if not retry_after_value:
                raise ValueError(
                    "OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header value is None"
                )

            retry_after_value_int = int(
                retry_after_value[0]
            )  # will raise ValueError if somehow we can't convert to int

            if retry_after_value_int <= 0:
                # A zero/negative wait would turn into `SET ... PX 0` (rejected by
                # redis) or a negative time.sleep() (ValueError) further down, so
                # treat it like any other unusable header and take the fallback.
                raise ValueError(
                    "OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header value is not positive"
                )

            jitter = retry_after_value_int * 0.25 * random.random()
            duration_s = math.ceil(retry_after_value_int + jitter)
        except ValueError:
            # fall back to the 1 second default plus up to 1 second of jitter
            duration_s += random.random()

        # lock and extend the ttl
        lock: RedisLock = self._redis.lock(
            self._delay_lock,
            timeout=OnyxRedisSlackRetryHandler.LOCK_TTL,
            thread_local=False,
        )

        acquired = lock.acquire(
            blocking_timeout=OnyxRedisSlackRetryHandler.LOCK_BLOCKING_TIMEOUT / 2
        )

        ttl_ms: int | None = None

        try:
            if acquired:
                # if we can get the lock, then read and extend the ttl
                ttl_ms = cast(int, self._redis.pttl(self._delay_key))
                if ttl_ms < 0:  # negative values are error status codes ... see docs
                    ttl_ms = 0
                ttl_ms_new = ttl_ms + int(duration_s * 1000.0)
                self._redis.set(self._delay_key, "1", px=ttl_ms_new)
            else:
                # if we can't get the lock, just go ahead.
                # TODO: if we know our actual parallelism, multiplying by that
                # would be a pretty good idea
                ttl_ms_new = int(duration_s * 1000.0)
        finally:
            if acquired:
                lock.release()

        logger.warning(
            f"OnyxRedisSlackRetryHandler.prepare_for_next_attempt wait: "
            f"retry-after={retry_after_value} "
            f"shared_delay_ms={ttl_ms} new_shared_delay_ms={ttl_ms_new}"
        )

        # TODO: would be good to take an event var and sleep in short increments to
        # allow for a clean exit / exception
        time.sleep(ttl_ms_new / 1000.0)

        state.increment_current_attempt()

View File

@@ -1,5 +1,4 @@
import re
import time
from collections.abc import Callable
from collections.abc import Generator
from functools import lru_cache
@@ -64,71 +63,72 @@ def _make_slack_api_call_paginated(
return paginated_call
def make_slack_api_rate_limited(
call: Callable[..., SlackResponse], max_retries: int = 7
) -> Callable[..., SlackResponse]:
"""Wraps calls to slack API so that they automatically handle rate limiting"""
# NOTE(rkuo): we may not need this any more if the integrated retry handlers work as
# expected. Do we want to keep this around?
@wraps(call)
def rate_limited_call(**kwargs: Any) -> SlackResponse:
last_exception = None
# def make_slack_api_rate_limited(
# call: Callable[..., SlackResponse], max_retries: int = 7
# ) -> Callable[..., SlackResponse]:
# """Wraps calls to slack API so that they automatically handle rate limiting"""
for _ in range(max_retries):
try:
# Make the API call
response = call(**kwargs)
# @wraps(call)
# def rate_limited_call(**kwargs: Any) -> SlackResponse:
# last_exception = None
# Check for errors in the response, will raise `SlackApiError`
# if anything went wrong
response.validate()
return response
# for _ in range(max_retries):
# try:
# # Make the API call
# response = call(**kwargs)
except SlackApiError as e:
last_exception = e
try:
error = e.response["error"]
except KeyError:
error = "unknown error"
# # Check for errors in the response, will raise `SlackApiError`
# # if anything went wrong
# response.validate()
# return response
if error == "ratelimited":
# Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
retry_after = int(e.response.headers.get("Retry-After", 1))
logger.info(
f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
)
time.sleep(retry_after)
elif error in ["already_reacted", "no_reaction", "internal_error"]:
# Log internal_error and return the response instead of failing
logger.warning(
f"Slack call encountered '{error}', skipping and continuing..."
)
return e.response
else:
# Raise the error for non-transient errors
raise
# except SlackApiError as e:
# last_exception = e
# try:
# error = e.response["error"]
# except KeyError:
# error = "unknown error"
# If the code reaches this point, all retries have been exhausted
msg = f"Max retries ({max_retries}) exceeded"
if last_exception:
raise Exception(msg) from last_exception
else:
raise Exception(msg)
# if error == "ratelimited":
# # Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
# retry_after = int(e.response.headers.get("Retry-After", 1))
# logger.info(
# f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
# )
# time.sleep(retry_after)
# elif error in ["already_reacted", "no_reaction", "internal_error"]:
# # Log internal_error and return the response instead of failing
# logger.warning(
# f"Slack call encountered '{error}', skipping and continuing..."
# )
# return e.response
# else:
# # Raise the error for non-transient errors
# raise
return rate_limited_call
# # If the code reaches this point, all retries have been exhausted
# msg = f"Max retries ({max_retries}) exceeded"
# if last_exception:
# raise Exception(msg) from last_exception
# else:
# raise Exception(msg)
# return rate_limited_call
def make_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(make_slack_api_rate_limited(call))(**kwargs)
return basic_retry_wrapper(call)(**kwargs)
def make_paginated_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return _make_slack_api_call_paginated(
basic_retry_wrapper(make_slack_api_rate_limited(call))
)(**kwargs)
return _make_slack_api_call_paginated(basic_retry_wrapper(call))(**kwargs)
def expert_info_from_slack_id(
@@ -142,7 +142,7 @@ def expert_info_from_slack_id(
if user_id in user_cache:
return user_cache[user_id]
response = make_slack_api_rate_limited(client.users_info)(user=user_id)
response = client.users_info(user=user_id)
if not response["ok"]:
user_cache[user_id] = None
@@ -175,9 +175,7 @@ class SlackTextCleaner:
def _get_slack_name(self, user_id: str) -> str:
if user_id not in self._id_to_name_map:
try:
response = make_slack_api_rate_limited(self._client.users_info)(
user=user_id
)
response = self._client.users_info(user=user_id)
# prefer display name if set, since that is what is shown in Slack
self._id_to_name_map[user_id] = (
response["user"]["profile"]["display_name"]

View File

@@ -125,6 +125,7 @@ class TenantRedis(redis.Redis):
"hset",
"hdel",
"ttl",
"pttl",
] # Regular methods that need simple prefixing
if item == "scan_iter" or item == "sscan_iter":