Add support to limit the number of Slack questions per minute (#908)

Co-authored-by: Matthieu Boret <matthieu.boret@fr.clara.net>
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
This commit is contained in:
mattboret 2024-01-16 06:26:35 +01:00 committed by GitHub
parent d17426749d
commit 53add2c801
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 93 additions and 0 deletions

View File

@ -50,3 +50,8 @@ ENABLE_DANSWERBOT_REFLEXION = (
)
# Currently not support chain of thought, probably will add back later
DANSWER_BOT_DISABLE_COT = True
# Maximum Questions Per Minute
DANSWER_BOT_MAX_QPM = int(os.environ.get("DANSWER_BOT_MAX_QPM", "100"))
# Maximum time to wait when a question is queued
DANSWER_BOT_MAX_WAIT_TIME = int(os.environ.get("DANSWER_BOT_MAX_WAIT_TIME", "180"))

View File

@ -1,5 +1,10 @@
import functools
import logging
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import Optional
from typing import TypeVar
from retry import retry
from slack_sdk import WebClient
@ -23,6 +28,7 @@ from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import ChannelIdAdapter
from danswer.danswerbot.slack.utils import fetch_userids_from_emails
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import SlackBotConfig
@ -38,6 +44,29 @@ from danswer.utils.telemetry import RecordType
logger_base = setup_logger()
srl = SlackRateLimiter()
RT = TypeVar("RT") # return type
def rate_limits(
client: WebClient, channel: str, thread_ts: Optional[str]
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
def decorator(func: Callable[..., RT]) -> Callable[..., RT]:
@functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> RT:
if not srl.is_available():
func_randid, position = srl.init_waiter()
srl.notify(client, channel, position, thread_ts)
while not srl.is_available():
srl.waiter(func_randid)
srl.acquire_slot()
return func(*args, **kwargs)
return wrapper
return decorator
def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
if details.is_bot_msg and details.sender:
@ -177,6 +206,7 @@ def handle_message(
backoff=2,
logger=logger,
)
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse:
action = "slack_message"
if is_bot_msg:

View File

@ -2,9 +2,11 @@ import logging
import random
import re
import string
import time
from collections.abc import MutableMapping
from typing import Any
from typing import cast
from typing import Optional
from retry import retry
from slack_sdk import WebClient
@ -14,6 +16,8 @@ from slack_sdk.models.metadata import Metadata
from danswer.configs.constants import ID_SEPARATOR
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_MAX_QPM
from danswer.configs.danswerbot_configs import DANSWER_BOT_MAX_WAIT_TIME
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import SlackTextCleaner
@ -347,3 +351,57 @@ def read_slack_thread(
)
return thread_messages
class SlackRateLimiter:
def __init__(self) -> None:
self.max_qpm = DANSWER_BOT_MAX_QPM
self.max_wait_time = DANSWER_BOT_MAX_WAIT_TIME
self.active_question = 0
self.last_reset_time = time.time()
self.waiting_questions: list[int] = []
def refill(self) -> None:
# If elapsed time is greater than the period, reset the active question count
if (time.time() - self.last_reset_time) > 60:
self.active_question = 0
self.last_reset_time = time.time()
def notify(
self, client: WebClient, channel: str, position: int, thread_ts: Optional[str]
) -> None:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=None,
text=f"Your question has been queued. You are in position {position}... please wait a moment :loading:",
thread_ts=thread_ts,
)
def is_available(self) -> bool:
self.refill()
return self.active_question < self.max_qpm
def acquire_slot(self) -> None:
self.active_question += 1
def init_waiter(self) -> tuple[int, int]:
func_randid = random.getrandbits(128)
self.waiting_questions.append(func_randid)
position = self.waiting_questions.index(func_randid) + 1
return func_randid, position
def waiter(self, func_randid: int) -> None:
wait_time = 0
while (
self.active_question >= self.max_qpm
or self.waiting_questions[0] != func_randid
):
if wait_time > self.max_wait_time:
raise TimeoutError
time.sleep(2)
wait_time += 2
self.refill()
del self.waiting_questions[0]