From 5a24ed2947bd1408da681dc3e70efd5616d6db28 Mon Sep 17 00:00:00 2001
From: pablodanswer <pablo@danswer.ai>
Date: Thu, 31 Oct 2024 13:25:20 -0700
Subject: [PATCH] updated cleanup

---
 backend/danswer/configs/constants.py         |  3 ++
 backend/danswer/danswerbot/slack/config.py   | 11 ++++
 backend/danswer/danswerbot/slack/listener.py | 55 +++++++++++---------
 3 files changed, 45 insertions(+), 24 deletions(-)

diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py
index 36b9a8bf3..671fd13e2 100644
--- a/backend/danswer/configs/constants.py
+++ b/backend/danswer/configs/constants.py
@@ -225,6 +225,9 @@ class DanswerRedisLocks:
     PRUNING_LOCK_PREFIX = "da_lock:pruning"
     INDEXING_METADATA_PREFIX = "da_metadata:indexing"
 
+    SLACK_BOT_LOCK = "da_lock:slack_bot"
+    SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
+
 
 class DanswerCeleryPriority(int, Enum):
     HIGHEST = 0
diff --git a/backend/danswer/danswerbot/slack/config.py b/backend/danswer/danswerbot/slack/config.py
index 29e1bd0a8..c9ae8281a 100644
--- a/backend/danswer/danswerbot/slack/config.py
+++ b/backend/danswer/danswerbot/slack/config.py
@@ -48,3 +48,14 @@ def validate_channel_names(
                 )
 
     return cleaned_channel_names
+
+
+# Scaling configurations for multi-tenant Slack bot handling (values in seconds)
+TENANT_LOCK_EXPIRATION = 1800  # How long a pod can hold exclusive access to a tenant before other pods can acquire it
+TENANT_HEARTBEAT_INTERVAL = (
+    60  # How often pods send heartbeats to indicate they are still processing a tenant
+)
+TENANT_HEARTBEAT_EXPIRATION = 180  # How long before a tenant's heartbeat expires, allowing other pods to take over
+TENANT_ACQUISITION_INTERVAL = (
+    60  # How often pods attempt to acquire unprocessed tenants
+)
diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py
index 21abba5ae..3b4e18ee6 100644
--- a/backend/danswer/danswerbot/slack/listener.py
+++ b/backend/danswer/danswerbot/slack/listener.py
@@ -4,6 +4,7 @@ import sys
 import threading
 import time
 from threading import Event
+from types import FrameType
 from typing import Any
 from typing import cast
 from typing import Dict
@@ -15,12 +16,17 @@ from slack_sdk import WebClient
 from slack_sdk.socket_mode.request import SocketModeRequest
 from slack_sdk.socket_mode.response import SocketModeResponse
 
+from danswer.configs.constants import DanswerRedisLocks
 from danswer.configs.constants import MessageType
 from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
 from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
 from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
 from danswer.connectors.slack.utils import expert_info_from_slack_id
 from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
+from danswer.danswerbot.slack.config import TENANT_ACQUISITION_INTERVAL
+from danswer.danswerbot.slack.config import TENANT_HEARTBEAT_EXPIRATION
+from danswer.danswerbot.slack.config import TENANT_HEARTBEAT_INTERVAL
+from danswer.danswerbot.slack.config import TENANT_LOCK_EXPIRATION
 from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
 from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
 from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
@@ -78,12 +84,6 @@ active_tenants_gauge = Gauge(
     "active_tenants", "Number of active tenants handled by this pod"
 )
 
-# Configuration constants
-TENANT_LOCK_EXPIRATION = 300  # seconds
-TENANT_HEARTBEAT_INTERVAL = 60  # seconds
-TENANT_HEARTBEAT_EXPIRATION = 180  # seconds
-TENANT_ACQUISITION_INTERVAL = 60  # seconds
-
 # In rare cases, some users have been experiencing a massive amount of trivial messages coming through
 # to the Slack Bot with trivial messages. Adding this to avoid exploding LLM costs while we track down
 # the cause.
@@ -101,10 +101,9 @@ _SLACK_GREETINGS_TO_IGNORE = {
 _OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
 
 
-class TenantHandler:
-    def __init__(self):
-        logger.info("Initializing TenantHandler")
-        self.redis_client = get_redis_client(tenant_id=None)
+class SlackbotHandler:
+    def __init__(self) -> None:
+        logger.info("Initializing SlackbotHandler")
         self.tenant_ids: Set[str] = set()
         self.socket_clients: Dict[str, TenantSocketModeClient] = {}
         self.slack_bot_tokens: Dict[str, SlackBotTokens] = {}
@@ -135,7 +134,7 @@ class TenantHandler:
         logger.info(f"Retrieved pod ID: {pod_id}")
         return pod_id
 
-    def acquire_tenants_loop(self):
+    def acquire_tenants_loop(self) -> None:
         logger.info("Starting tenant acquisition loop")
         while self.running:
             try:
@@ -146,7 +145,7 @@ class TenantHandler:
                 logger.exception(f"Error in tenant acquisition: {e}")
             Event().wait(timeout=TENANT_ACQUISITION_INTERVAL)
 
-    def heartbeat_loop(self):
+    def heartbeat_loop(self) -> None:
         logger.info("Starting heartbeat loop")
         while self.running:
             try:
@@ -156,15 +155,18 @@ class TenantHandler:
                 logger.exception(f"Error in heartbeat loop: {e}")
             Event().wait(timeout=TENANT_HEARTBEAT_INTERVAL)
 
-    def acquire_tenants(self):
+    def acquire_tenants(self) -> None:
         tenant_ids = get_all_tenant_ids()
         logger.debug(f"Found {len(tenant_ids)} total tenants in Postgres")
 
         for tenant_id in tenant_ids:
-            lock_key = f"tenant_lock:{tenant_id}"
+            redis_client = get_redis_client(tenant_id=tenant_id)
             pod_id = self.pod_id
-            acquired = self.redis_client.set(
-                lock_key, pod_id, nx=True, ex=TENANT_LOCK_EXPIRATION
+            acquired = redis_client.set(
+                DanswerRedisLocks.SLACK_BOT_LOCK,
+                pod_id,
+                nx=True,
+                ex=TENANT_LOCK_EXPIRATION,
             )
             if not acquired:
                 continue  # Another pod holds the lock
@@ -219,16 +221,21 @@ class TenantHandler:
                 except Exception as e:
                     logger.exception(f"Error handling tenant {tenant_id}: {e}")
 
-    def send_heartbeats(self):
+    def send_heartbeats(self) -> None:
         current_time = int(time.time())
         logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants")
         for tenant_id in self.tenant_ids:
-            heartbeat_key = f"tenant_heartbeat:{tenant_id}:{self.pod_id}"
-            self.redis_client.set(
+            redis_client = get_redis_client(tenant_id=tenant_id)
+            heartbeat_key = (
+                f"{DanswerRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}"
+            )
+            redis_client.set(
                 heartbeat_key, current_time, ex=TENANT_HEARTBEAT_EXPIRATION
             )
 
-    def start_socket_client(self, tenant_id: str, slack_bot_tokens: SlackBotTokens):
+    def start_socket_client(
+        self, tenant_id: str, slack_bot_tokens: SlackBotTokens
+    ) -> None:
         logger.info(f"Starting socket client for tenant {tenant_id}")
         socket_client = _get_socket_client(slack_bot_tokens, tenant_id)
 
@@ -242,13 +249,13 @@ class TenantHandler:
         self.tenant_ids.add(tenant_id)
         logger.info(f"Started SocketModeClient for tenant {tenant_id}")
 
-    def stop_socket_clients(self):
+    def stop_socket_clients(self) -> None:
         logger.info(f"Stopping {len(self.socket_clients)} socket clients")
         for tenant_id, client in self.socket_clients.items():
             asyncio.run(client.close())
             logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
 
-    def shutdown(self, signum, frame):
+    def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
         logger.info("Shutting down gracefully")
         self.running = False
         self.stop_socket_clients()
@@ -654,8 +661,8 @@ def _get_socket_client(
 
 if __name__ == "__main__":
     # Initialize the tenant handler which will manage tenant connections
-    logger.info("Starting TenantHandler")
-    tenant_handler = TenantHandler()
+    logger.info("Starting SlackbotHandler")
+    tenant_handler = SlackbotHandler()
 
     set_is_ee_based_on_env_variable()