mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-26 17:51:54 +01:00
Add support to limit the number of Slack questions per minute (#908)
Co-authored-by: Matthieu Boret <matthieu.boret@fr.clara.net> Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
This commit is contained in:
parent
d17426749d
commit
53add2c801
@ -50,3 +50,8 @@ ENABLE_DANSWERBOT_REFLEXION = (
|
||||
)
|
||||
# Currently not support chain of thought, probably will add back later
|
||||
DANSWER_BOT_DISABLE_COT = True
|
||||
|
||||
# Maximum Questions Per Minute
|
||||
DANSWER_BOT_MAX_QPM = int(os.environ.get("DANSWER_BOT_MAX_QPM", "100"))
|
||||
# Maximum time to wait when a question is queued
|
||||
DANSWER_BOT_MAX_WAIT_TIME = int(os.environ.get("DANSWER_BOT_MAX_WAIT_TIME", "180"))
|
||||
|
@ -1,5 +1,10 @@
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
@ -23,6 +28,7 @@ from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import ChannelIdAdapter
|
||||
from danswer.danswerbot.slack.utils import fetch_userids_from_emails
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import SlackRateLimiter
|
||||
from danswer.danswerbot.slack.utils import update_emote_react
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import SlackBotConfig
|
||||
@ -38,6 +44,29 @@ from danswer.utils.telemetry import RecordType
|
||||
|
||||
logger_base = setup_logger()
|
||||
|
||||
srl = SlackRateLimiter()
|
||||
|
||||
RT = TypeVar("RT") # return type
|
||||
|
||||
|
||||
def rate_limits(
|
||||
client: WebClient, channel: str, thread_ts: Optional[str]
|
||||
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
|
||||
def decorator(func: Callable[..., RT]) -> Callable[..., RT]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> RT:
|
||||
if not srl.is_available():
|
||||
func_randid, position = srl.init_waiter()
|
||||
srl.notify(client, channel, position, thread_ts)
|
||||
while not srl.is_available():
|
||||
srl.waiter(func_randid)
|
||||
srl.acquire_slot()
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
|
||||
if details.is_bot_msg and details.sender:
|
||||
@ -177,6 +206,7 @@ def handle_message(
|
||||
backoff=2,
|
||||
logger=logger,
|
||||
)
|
||||
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
|
||||
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse:
|
||||
action = "slack_message"
|
||||
if is_bot_msg:
|
||||
|
@ -2,9 +2,11 @@ import logging
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import time
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
|
||||
from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
@ -14,6 +16,8 @@ from slack_sdk.models.metadata import Metadata
|
||||
|
||||
from danswer.configs.constants import ID_SEPARATOR
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_MAX_QPM
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_MAX_WAIT_TIME
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import SlackTextCleaner
|
||||
@ -347,3 +351,57 @@ def read_slack_thread(
|
||||
)
|
||||
|
||||
return thread_messages
|
||||
|
||||
|
||||
class SlackRateLimiter:
|
||||
def __init__(self) -> None:
|
||||
self.max_qpm = DANSWER_BOT_MAX_QPM
|
||||
self.max_wait_time = DANSWER_BOT_MAX_WAIT_TIME
|
||||
self.active_question = 0
|
||||
self.last_reset_time = time.time()
|
||||
self.waiting_questions: list[int] = []
|
||||
|
||||
def refill(self) -> None:
|
||||
# If elapsed time is greater than the period, reset the active question count
|
||||
if (time.time() - self.last_reset_time) > 60:
|
||||
self.active_question = 0
|
||||
self.last_reset_time = time.time()
|
||||
|
||||
def notify(
|
||||
self, client: WebClient, channel: str, position: int, thread_ts: Optional[str]
|
||||
) -> None:
|
||||
respond_in_thread(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=None,
|
||||
text=f"Your question has been queued. You are in position {position}... please wait a moment :loading:",
|
||||
thread_ts=thread_ts,
|
||||
)
|
||||
|
||||
def is_available(self) -> bool:
|
||||
self.refill()
|
||||
return self.active_question < self.max_qpm
|
||||
|
||||
def acquire_slot(self) -> None:
|
||||
self.active_question += 1
|
||||
|
||||
def init_waiter(self) -> tuple[int, int]:
|
||||
func_randid = random.getrandbits(128)
|
||||
self.waiting_questions.append(func_randid)
|
||||
position = self.waiting_questions.index(func_randid) + 1
|
||||
|
||||
return func_randid, position
|
||||
|
||||
def waiter(self, func_randid: int) -> None:
|
||||
wait_time = 0
|
||||
while (
|
||||
self.active_question >= self.max_qpm
|
||||
or self.waiting_questions[0] != func_randid
|
||||
):
|
||||
if wait_time > self.max_wait_time:
|
||||
raise TimeoutError
|
||||
time.sleep(2)
|
||||
wait_time += 2
|
||||
self.refill()
|
||||
|
||||
del self.waiting_questions[0]
|
||||
|
Loading…
x
Reference in New Issue
Block a user