Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-19 12:03:54 +02:00)
Bugfix/slack rate limiting (#4386)
* use slack's built in rate limit handler for the bot
* WIP
* fix the slack rate limit handler
* change default to 8
* cleanup
* try catch int conversion just in case
* linearize this logic better
* code review comments

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
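The bot-side change from the first bullet ("use slack's built in rate limit handler for the bot") is not part of the excerpt below. For context, this is roughly what attaching slack_sdk's stock handler looks like; the token and retry count are placeholders:

from slack_sdk import WebClient
from slack_sdk.http_retry.builtin_handlers import RateLimitErrorRetryHandler

# The stock handler sleeps in-process for the Retry-After duration and retries.
# That is fine for a single bot process, but it cannot coordinate several
# workers sharing one bot token, which is the gap the Redis-backed handler
# added for the connector in this commit is meant to close.
client = WebClient(token="xoxb-placeholder")
client.retry_handlers.append(RateLimitErrorRetryHandler(max_retry_count=5))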
@@ -80,7 +80,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
     """Runs periodically to check if any document needs syncing.
     Generates sets of tasks for Celery if syncing is needed."""

-    # Useful for debugging timing issues with reacquisitions. TODO: remove once more generalized logging is in place
+    # Useful for debugging timing issues with reacquisitions.
+    # TODO: remove once more generalized logging is in place
     task_logger.info("check_for_vespa_sync_task started")

     time_start = time.monotonic()
@@ -437,7 +437,7 @@ LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
 LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")

 # Slack specific configs
-SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 2)
+SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 8)

 DASK_JOB_CLIENT_ENABLED = (
     os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
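The `or 8` fallback also covers an env var that is set but empty, which a plain getenv default would not; a quick illustration of the parsing behavior (values are examples only):

import os

os.environ["SLACK_NUM_THREADS"] = ""
int(os.getenv("SLACK_NUM_THREADS") or 8)    # -> 8, empty string falls through to the default
# int(os.getenv("SLACK_NUM_THREADS", "8"))  # -> ValueError, "" is returned and int("") fails
os.environ["SLACK_NUM_THREADS"] = "12"
int(os.getenv("SLACK_NUM_THREADS") or 8)    # -> 12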
@@ -14,6 +14,8 @@ from typing import cast
 from pydantic import BaseModel
 from slack_sdk import WebClient
 from slack_sdk.errors import SlackApiError
+from slack_sdk.http_retry import ConnectionErrorRetryHandler
+from slack_sdk.http_retry import RetryHandler
 from typing_extensions import override

 from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
@@ -26,6 +28,8 @@ from onyx.connectors.exceptions import InsufficientPermissionsError
 from onyx.connectors.exceptions import UnexpectedValidationError
 from onyx.connectors.interfaces import CheckpointConnector
 from onyx.connectors.interfaces import CheckpointOutput
+from onyx.connectors.interfaces import CredentialsConnector
+from onyx.connectors.interfaces import CredentialsProviderInterface
 from onyx.connectors.interfaces import GenerateSlimDocumentOutput
 from onyx.connectors.interfaces import SecondsSinceUnixEpoch
 from onyx.connectors.interfaces import SlimConnector
@@ -38,15 +42,16 @@ from onyx.connectors.models import DocumentFailure
 from onyx.connectors.models import EntityFailure
 from onyx.connectors.models import SlimDocument
 from onyx.connectors.models import TextSection
+from onyx.connectors.slack.onyx_retry_handler import OnyxRedisSlackRetryHandler
 from onyx.connectors.slack.utils import expert_info_from_slack_id
 from onyx.connectors.slack.utils import get_message_link
 from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
 from onyx.connectors.slack.utils import make_slack_api_call_w_retries
 from onyx.connectors.slack.utils import SlackTextCleaner
 from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
+from onyx.redis.redis_pool import get_redis_client
 from onyx.utils.logger import setup_logger


 logger = setup_logger()

 _SLACK_LIMIT = 900
@@ -493,9 +498,13 @@ def _process_message(
     )


-class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
+class SlackConnector(
+    SlimConnector, CredentialsConnector, CheckpointConnector[SlackCheckpoint]
+):
     FAST_TIMEOUT = 1

+    MAX_RETRIES = 7  # arbitrarily selected
+
     def __init__(
         self,
         channels: list[str] | None = None,
@@ -514,16 +523,49 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
         # just used for efficiency
         self.text_cleaner: SlackTextCleaner | None = None
         self.user_cache: dict[str, BasicExpertInfo | None] = {}
+        self.credentials_provider: CredentialsProviderInterface | None = None
+        self.credential_prefix: str | None = None
+        self.delay_lock: str | None = None  # the redis key for the shared lock
+        self.delay_key: str | None = None  # the redis key for the shared delay

     def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
+        raise NotImplementedError("Use set_credentials_provider with this connector.")
+
+    def set_credentials_provider(
+        self, credentials_provider: CredentialsProviderInterface
+    ) -> None:
+        credentials = credentials_provider.get_credentials()
+        tenant_id = credentials_provider.get_tenant_id()
+        self.redis = get_redis_client(tenant_id=tenant_id)
+
+        self.credential_prefix = (
+            f"connector:slack:credential_{credentials_provider.get_provider_key()}"
+        )
+        self.delay_lock = f"{self.credential_prefix}:delay_lock"
+        self.delay_key = f"{self.credential_prefix}:delay"
+
+        # NOTE: slack has a built in RateLimitErrorRetryHandler, but it isn't designed
+        # for concurrent workers. We've extended it with OnyxRedisSlackRetryHandler.
+        connection_error_retry_handler = ConnectionErrorRetryHandler()
+        onyx_rate_limit_error_retry_handler = OnyxRedisSlackRetryHandler(
+            max_retry_count=self.MAX_RETRIES,
+            delay_lock=self.delay_lock,
+            delay_key=self.delay_key,
+            r=self.redis,
+        )
+        custom_retry_handlers: list[RetryHandler] = [
+            connection_error_retry_handler,
+            onyx_rate_limit_error_retry_handler,
+        ]
+
         bot_token = credentials["slack_bot_token"]
-        self.client = WebClient(token=bot_token)
+        self.client = WebClient(token=bot_token, retry_handlers=custom_retry_handlers)
         # use for requests that must return quickly (e.g. realtime flows where user is waiting)
         self.fast_client = WebClient(
             token=bot_token, timeout=SlackConnector.FAST_TIMEOUT
         )
         self.text_cleaner = SlackTextCleaner(client=self.client)
-        return None
+        self.credentials_provider = credentials_provider

     def retrieve_all_slim_documents(
         self,
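A minimal sketch of how this new entry point might be driven. The provider class below is hypothetical and only mirrors the three methods the hunk actually calls (get_credentials, get_tenant_id, get_provider_key); it is not the real CredentialsProviderInterface:

class StaticSlackCredentialsProvider:
    def __init__(self, tenant_id: str, bot_token: str) -> None:
        self._tenant_id = tenant_id
        self._bot_token = bot_token

    def get_tenant_id(self) -> str:
        return self._tenant_id

    def get_provider_key(self) -> str:
        # anything stable per credential; this value namespaces the shared
        # "connector:slack:credential_<key>:delay" keys in Redis, so workers
        # indexing with the same credential observe one shared delay
        return "example-credential"

    def get_credentials(self) -> dict[str, str]:
        return {"slack_bot_token": self._bot_token}


connector = SlackConnector(channels=["general"])
connector.set_credentials_provider(
    StaticSlackCredentialsProvider(tenant_id="tenant-1", bot_token="xoxb-placeholder")
)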
backend/onyx/connectors/slack/onyx_retry_handler.py (new file, 159 lines)
@@ -0,0 +1,159 @@
import math
import random
import time
from typing import cast
from typing import Optional

from redis import Redis
from redis.lock import Lock as RedisLock
from slack_sdk.http_retry.handler import RetryHandler
from slack_sdk.http_retry.request import HttpRequest
from slack_sdk.http_retry.response import HttpResponse
from slack_sdk.http_retry.state import RetryState

from onyx.utils.logger import setup_logger

logger = setup_logger()


class OnyxRedisSlackRetryHandler(RetryHandler):
    """
    This class uses Redis to share a rate limit among multiple threads.

    Threads that encounter a rate limit will observe the shared delay, increment the
    shared delay with the retry value, and use the new shared value as a wait interval.

    This has the effect of serializing calls when a rate limit is hit, which is what
    needs to happen if the server punishes us with additional limiting when we make
    a call too early. We believe this is what Slack is doing based on empirical
    observation, meaning we see indefinite hangs if we're too aggressive.

    Another way to do this is just to do exponential backoff. Might be easier?

    Adapted from slack's RateLimitErrorRetryHandler.
    """

    LOCK_TTL = 60  # used to serialize access to the retry TTL
    LOCK_BLOCKING_TIMEOUT = 60  # how long to wait for the lock

    """RetryHandler that does retries for rate limited errors."""

    def __init__(
        self,
        max_retry_count: int,
        delay_lock: str,
        delay_key: str,
        r: Redis,
    ):
        """
        delay_lock: the redis key to use with RedisLock (to synchronize access to delay_key)
        delay_key: the redis key containing a shared TTL
        """
        super().__init__(max_retry_count=max_retry_count)
        self._redis: Redis = r
        self._delay_lock = delay_lock
        self._delay_key = delay_key

    def _can_retry(
        self,
        *,
        state: RetryState,
        request: HttpRequest,
        response: Optional[HttpResponse] = None,
        error: Optional[Exception] = None,
    ) -> bool:
        return response is not None and response.status_code == 429

    def prepare_for_next_attempt(
        self,
        *,
        state: RetryState,
        request: HttpRequest,
        response: Optional[HttpResponse] = None,
        error: Optional[Exception] = None,
    ) -> None:
        """It seems this function is responsible for the wait to retry ... aka we
        actually sleep in this function."""
        retry_after_value: list[str] | None = None
        retry_after_header_name: Optional[str] = None
        duration_s: float = 1.0  # seconds

        if response is None:
            # NOTE(rkuo): this logic comes from RateLimitErrorRetryHandler.
            # This reads oddly, as if the caller itself could raise the exception.
            # We don't have the luxury of changing this.
            if error:
                raise error

            return

        state.next_attempt_requested = True  # this signals the caller to retry

        # calculate wait duration based on retry-after + some jitter
        for k in response.headers.keys():
            if k.lower() == "retry-after":
                retry_after_header_name = k
                break

        try:
            if retry_after_header_name is None:
                # This situation usually does not arise. Just in case.
                raise ValueError(
                    "OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header name is None"
                )

            retry_after_value = response.headers.get(retry_after_header_name)
            if not retry_after_value:
                raise ValueError(
                    "OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header value is None"
                )

            retry_after_value_int = int(
                retry_after_value[0]
            )  # will raise ValueError if somehow we can't convert to int
            jitter = retry_after_value_int * 0.25 * random.random()
            duration_s = math.ceil(retry_after_value_int + jitter)
        except ValueError:
            duration_s += random.random()

        # lock and extend the ttl
        lock: RedisLock = self._redis.lock(
            self._delay_lock,
            timeout=OnyxRedisSlackRetryHandler.LOCK_TTL,
            thread_local=False,
        )

        acquired = lock.acquire(
            blocking_timeout=OnyxRedisSlackRetryHandler.LOCK_BLOCKING_TIMEOUT / 2
        )

        ttl_ms: int | None = None

        try:
            if acquired:
                # if we can get the lock, then read and extend the ttl
                ttl_ms = cast(int, self._redis.pttl(self._delay_key))
                if ttl_ms < 0:  # negative values are error status codes ... see docs
                    ttl_ms = 0
                ttl_ms_new = ttl_ms + int(duration_s * 1000.0)
                self._redis.set(self._delay_key, "1", px=ttl_ms_new)
            else:
                # if we can't get the lock, just go ahead.
                # TODO: if we know our actual parallelism, multiplying by that
                # would be a pretty good idea
                ttl_ms_new = int(duration_s * 1000.0)
        finally:
            if acquired:
                lock.release()

        logger.warning(
            f"OnyxRedisSlackRetryHandler.prepare_for_next_attempt wait: "
            f"retry-after={retry_after_value} "
            f"shared_delay_ms={ttl_ms} new_shared_delay_ms={ttl_ms_new}"
        )

        # TODO: would be good to take an event var and sleep in short increments to
        # allow for a clean exit / exception
        time.sleep(ttl_ms_new / 1000.0)

        state.increment_current_attempt()
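Taken together, the handler can be attached to any WebClient that should share the same Redis-backed delay; each 429 extends the TTL on the shared delay key, so the sleep grows across all workers using that credential. A minimal sketch, with placeholder key names, token, and a local Redis assumed for illustration:

from redis import Redis
from slack_sdk import WebClient

from onyx.connectors.slack.onyx_retry_handler import OnyxRedisSlackRetryHandler

r = Redis(host="localhost", port=6379)  # assumed local Redis, illustration only
handler = OnyxRedisSlackRetryHandler(
    max_retry_count=7,
    delay_lock="connector:slack:credential_example:delay_lock",  # placeholder key
    delay_key="connector:slack:credential_example:delay",  # placeholder key
    r=r,
)

# every worker that uses this bot token builds its client the same way, so a
# rate limit seen by one worker lengthens the shared delay observed by all of them
client = WebClient(token="xoxb-placeholder", retry_handlers=[handler])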
@@ -1,5 +1,4 @@
 import re
-import time
 from collections.abc import Callable
 from collections.abc import Generator
 from functools import lru_cache
@@ -64,71 +63,72 @@ def _make_slack_api_call_paginated(
     return paginated_call


-def make_slack_api_rate_limited(
-    call: Callable[..., SlackResponse], max_retries: int = 7
-) -> Callable[..., SlackResponse]:
-    """Wraps calls to slack API so that they automatically handle rate limiting"""
-
-    @wraps(call)
-    def rate_limited_call(**kwargs: Any) -> SlackResponse:
-        last_exception = None
-
-        for _ in range(max_retries):
-            try:
-                # Make the API call
-                response = call(**kwargs)
-
-                # Check for errors in the response, will raise `SlackApiError`
-                # if anything went wrong
-                response.validate()
-                return response
-
-            except SlackApiError as e:
-                last_exception = e
-                try:
-                    error = e.response["error"]
-                except KeyError:
-                    error = "unknown error"
-
-                if error == "ratelimited":
-                    # Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
-                    retry_after = int(e.response.headers.get("Retry-After", 1))
-                    logger.info(
-                        f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
-                    )
-                    time.sleep(retry_after)
-                elif error in ["already_reacted", "no_reaction", "internal_error"]:
-                    # Log internal_error and return the response instead of failing
-                    logger.warning(
-                        f"Slack call encountered '{error}', skipping and continuing..."
-                    )
-                    return e.response
-                else:
-                    # Raise the error for non-transient errors
-                    raise
-
-        # If the code reaches this point, all retries have been exhausted
-        msg = f"Max retries ({max_retries}) exceeded"
-        if last_exception:
-            raise Exception(msg) from last_exception
-        else:
-            raise Exception(msg)
-
-    return rate_limited_call
+# NOTE(rkuo): we may not need this any more if the integrated retry handlers work as
+# expected. Do we want to keep this around?
+
+# def make_slack_api_rate_limited(
+#     call: Callable[..., SlackResponse], max_retries: int = 7
+# ) -> Callable[..., SlackResponse]:
+#     """Wraps calls to slack API so that they automatically handle rate limiting"""
+
+#     @wraps(call)
+#     def rate_limited_call(**kwargs: Any) -> SlackResponse:
+#         last_exception = None
+
+#         for _ in range(max_retries):
+#             try:
+#                 # Make the API call
+#                 response = call(**kwargs)
+
+#                 # Check for errors in the response, will raise `SlackApiError`
+#                 # if anything went wrong
+#                 response.validate()
+#                 return response
+
+#             except SlackApiError as e:
+#                 last_exception = e
+#                 try:
+#                     error = e.response["error"]
+#                 except KeyError:
+#                     error = "unknown error"
+
+#                 if error == "ratelimited":
+#                     # Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
+#                     retry_after = int(e.response.headers.get("Retry-After", 1))
+#                     logger.info(
+#                         f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
+#                     )
+#                     time.sleep(retry_after)
+#                 elif error in ["already_reacted", "no_reaction", "internal_error"]:
+#                     # Log internal_error and return the response instead of failing
+#                     logger.warning(
+#                         f"Slack call encountered '{error}', skipping and continuing..."
+#                     )
+#                     return e.response
+#                 else:
+#                     # Raise the error for non-transient errors
+#                     raise
+
+#         # If the code reaches this point, all retries have been exhausted
+#         msg = f"Max retries ({max_retries}) exceeded"
+#         if last_exception:
+#             raise Exception(msg) from last_exception
+#         else:
+#             raise Exception(msg)
+
+#     return rate_limited_call


 def make_slack_api_call_w_retries(
     call: Callable[..., SlackResponse], **kwargs: Any
 ) -> SlackResponse:
-    return basic_retry_wrapper(make_slack_api_rate_limited(call))(**kwargs)
+    return basic_retry_wrapper(call)(**kwargs)


 def make_paginated_slack_api_call_w_retries(
     call: Callable[..., SlackResponse], **kwargs: Any
 ) -> Generator[dict[str, Any], None, None]:
-    return _make_slack_api_call_paginated(
-        basic_retry_wrapper(make_slack_api_rate_limited(call))
-    )(**kwargs)
+    return _make_slack_api_call_paginated(basic_retry_wrapper(call))(**kwargs)


 def expert_info_from_slack_id(
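Call sites are unchanged by this simplification: the helpers still take a bound WebClient method plus its keyword arguments, with rate-limit waits now handled inside the SDK client's retry handlers. A hedged example; the channel ID and limits are placeholders:

# hypothetical call site, values are placeholders
response = make_slack_api_call_w_retries(
    client.conversations_history,
    channel="C0123456789",
    limit=100,
)

for page in make_paginated_slack_api_call_w_retries(
    client.conversations_list, types="public_channel"
):
    ...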
@@ -142,7 +142,7 @@ def expert_info_from_slack_id(
|
||||
if user_id in user_cache:
|
||||
return user_cache[user_id]
|
||||
|
||||
response = make_slack_api_rate_limited(client.users_info)(user=user_id)
|
||||
response = client.users_info(user=user_id)
|
||||
|
||||
if not response["ok"]:
|
||||
user_cache[user_id] = None
|
||||
@@ -175,9 +175,7 @@ class SlackTextCleaner:
     def _get_slack_name(self, user_id: str) -> str:
         if user_id not in self._id_to_name_map:
             try:
-                response = make_slack_api_rate_limited(self._client.users_info)(
-                    user=user_id
-                )
+                response = self._client.users_info(user=user_id)
                 # prefer display name if set, since that is what is shown in Slack
                 self._id_to_name_map[user_id] = (
                     response["user"]["profile"]["display_name"]
@@ -125,6 +125,7 @@ class TenantRedis(redis.Redis):
             "hset",
             "hdel",
             "ttl",
+            "pttl",
         ]  # Regular methods that need simple prefixing

         if item == "scan_iter" or item == "sscan_iter":
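Adding "pttl" here matters because the retry handler above reads the shared delay key with pttl; like set and get, that call has to land on the tenant-prefixed key. The sketch below illustrates the general prefixing idea only and is not Onyx's actual TenantRedis implementation:

import functools

import redis


class PrefixingRedis(redis.Redis):
    """Simplified illustration: methods in PREFIXED_METHODS get their first
    argument (the key) rewritten to "<tenant_id>:<key>", so pttl/set/get from
    different tenants never collide."""

    PREFIXED_METHODS = {"get", "set", "delete", "ttl", "pttl", "lock"}

    def __init__(self, tenant_id: str, *args, **kwargs) -> None:
        self._tenant_id = tenant_id
        super().__init__(*args, **kwargs)

    def __getattribute__(self, item):
        attr = super().__getattribute__(item)
        if item in PrefixingRedis.PREFIXED_METHODS:
            tenant_id = super().__getattribute__("_tenant_id")

            @functools.wraps(attr)
            def prefixed(key, *args, **kwargs):
                # rewrite only the key argument, pass everything else through
                return attr(f"{tenant_id}:{key}", *args, **kwargs)

            return prefixed
        return attr


# usage sketch: r.pttl("connector:slack:credential_x:delay") actually queries
# "tenant-1:connector:slack:credential_x:delay"
r = PrefixingRedis(tenant_id="tenant-1", host="localhost", port=6379)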