Merge pull request #3008 from danswer-ai/horizontal_slack

Add Functional Horizontal scaling for Slack
hagen-danswer 2024-11-06 14:31:13 -08:00 committed by GitHub
commit faeb9f09f0
8 changed files with 341 additions and 93 deletions

View File

@@ -225,6 +225,9 @@ class DanswerRedisLocks:
PRUNING_LOCK_PREFIX = "da_lock:pruning"
INDEXING_METADATA_PREFIX = "da_metadata:indexing"
SLACK_BOT_LOCK = "da_lock:slack_bot"
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
class DanswerCeleryPriority(int, Enum):
HIGHEST = 0

View File

@@ -1,3 +1,5 @@
import os
from sqlalchemy.orm import Session
from danswer.db.models import SlackBotConfig
@@ -48,3 +50,16 @@ def validate_channel_names(
)
return cleaned_channel_names
# Scaling configurations for multi-tenant Slack bot handling
TENANT_LOCK_EXPIRATION = 1800 # How long a pod can hold exclusive access to a tenant before other pods can acquire it
TENANT_HEARTBEAT_INTERVAL = (
60 # How often pods send heartbeats to indicate they are still processing a tenant
)
TENANT_HEARTBEAT_EXPIRATION = 180 # How long before a tenant's heartbeat expires, allowing other pods to take over
TENANT_ACQUISITION_INTERVAL = (
60 # How often pods attempt to acquire unprocessed tenants
)
MAX_TENANTS_PER_POD = int(os.getenv("MAX_TENANTS_PER_POD", 50))
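Taken together, these constants describe a lease protocol: a pod claims a tenant by setting a Redis lock with nx=True and a TENANT_LOCK_EXPIRATION TTL, then refreshes a heartbeat key every TENANT_HEARTBEAT_INTERVAL seconds so other pods can see the claim is still live. A minimal sketch of that interaction, reusing the constants above (the redis.Redis() client and the key layout here are illustrative, not the exact ones the listener uses):

import time
import redis

r = redis.Redis()  # illustrative; the listener uses a tenant-scoped client
POD_ID = "pod-a"
TENANT_ID = "tenant-123"

# Claim the tenant: SET NX only succeeds if no other pod holds the lock.
claimed = r.set(
    f"da_lock:slack_bot:{TENANT_ID}",  # hypothetical key layout for this sketch
    POD_ID,
    nx=True,
    ex=TENANT_LOCK_EXPIRATION,
)

if claimed:
    # Keep a heartbeat alive so other pods can observe the claim is still active.
    while True:
        r.set(
            f"da_heartbeat:slack_bot:{POD_ID}",
            int(time.time()),
            ex=TENANT_HEARTBEAT_EXPIRATION,
        )
        time.sleep(TENANT_HEARTBEAT_INTERVAL)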

View File

@@ -1,18 +1,34 @@
import asyncio
import os
import signal
import sys
import threading
import time
from threading import Event
from types import FrameType
from typing import Any
from typing import cast
from typing import Dict
from typing import Set
from prometheus_client import Gauge
from prometheus_client import start_http_server
from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.config import MAX_TENANTS_PER_POD
from danswer.danswerbot.slack.config import TENANT_ACQUISITION_INTERVAL
from danswer.danswerbot.slack.config import TENANT_HEARTBEAT_EXPIRATION
from danswer.danswerbot.slack.config import TENANT_HEARTBEAT_INTERVAL
from danswer.danswerbot.slack.config import TENANT_LOCK_EXPIRATION
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
@@ -46,6 +62,7 @@ from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import TenantSocketModeClient
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import get_session_with_tenant
from danswer.db.search_settings import get_current_search_settings
@@ -53,6 +70,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.one_shot_answer.models import ThreadMessage
from danswer.redis.redis_pool import get_redis_client
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
@@ -60,10 +78,14 @@ from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
# Prometheus metric for HPA
active_tenants_gauge = Gauge(
"active_tenants", "Number of active tenants handled by this pod"
)
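This gauge is what the Kubernetes HPA added in this PR scales on; it is served from each pod by the start_http_server(8000) call in SlackbotHandler.__init__. A quick, illustrative way to confirm the metric is being exported from a running pod (assuming port 8000 is reachable, e.g. via kubectl port-forward):

from urllib.request import urlopen

# Fetch the Prometheus text exposition and print the active_tenants sample.
body = urlopen("http://localhost:8000/metrics").read().decode()
for line in body.splitlines():
    if line.startswith("active_tenants"):
        print(line)  # e.g. "active_tenants 3.0"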
# In rare cases, some users have been experiencing a massive amount of trivial messages coming through
# to the Slack Bot with trivial messages. Adding this to avoid exploding LLM costs while we track down
# the cause.
@@ -77,10 +99,205 @@ _SLACK_GREETINGS_TO_IGNORE = {
":wave:",
}
# this is always (currently) the user id of Slack's official slackbot
# This is always (currently) the user id of Slack's official slackbot
_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
class SlackbotHandler:
def __init__(self) -> None:
logger.info("Initializing SlackbotHandler")
self.tenant_ids: Set[str | None] = set()
self.socket_clients: Dict[str | None, TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[str | None, SlackBotTokens] = {}
self.running = True
self.pod_id = self.get_pod_id()
self._shutdown_event = Event()
logger.info(f"Pod ID: {self.pod_id}")
# Set up signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, self.shutdown)
signal.signal(signal.SIGINT, self.shutdown)
logger.info("Signal handlers registered")
# Start the Prometheus metrics server
logger.info("Starting Prometheus metrics server")
start_http_server(8000)
logger.info("Prometheus metrics server started")
# Start background threads
logger.info("Starting background threads")
self.acquire_thread = threading.Thread(
target=self.acquire_tenants_loop, daemon=True
)
self.heartbeat_thread = threading.Thread(
target=self.heartbeat_loop, daemon=True
)
self.acquire_thread.start()
self.heartbeat_thread.start()
logger.info("Background threads started")
def get_pod_id(self) -> str:
pod_id = os.environ.get("HOSTNAME", "unknown_pod")
logger.info(f"Retrieved pod ID: {pod_id}")
return pod_id
def acquire_tenants_loop(self) -> None:
while not self._shutdown_event.is_set():
try:
self.acquire_tenants()
active_tenants_gauge.set(len(self.tenant_ids))
logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
except Exception as e:
logger.exception(f"Error in Slack acquisition: {e}")
self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)
def heartbeat_loop(self) -> None:
while not self._shutdown_event.is_set():
try:
self.send_heartbeats()
logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants")
except Exception as e:
logger.exception(f"Error in heartbeat loop: {e}")
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
def acquire_tenants(self) -> None:
tenant_ids = get_all_tenant_ids()
logger.debug(f"Found {len(tenant_ids)} total tenants in Postgres")
for tenant_id in tenant_ids:
if tenant_id in self.tenant_ids:
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
continue
if len(self.tenant_ids) >= MAX_TENANTS_PER_POD:
logger.info(
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}) Not acquiring any more tenants"
)
break
redis_client = get_redis_client(tenant_id=tenant_id)
pod_id = self.pod_id
acquired = redis_client.set(
DanswerRedisLocks.SLACK_BOT_LOCK,
pod_id,
nx=True,
ex=TENANT_LOCK_EXPIRATION,
)
if not acquired:
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
continue
logger.debug(f"Acquired lock for tenant {tenant_id}")
with get_session_with_tenant(tenant_id) as db_session:
try:
logger.debug(
f"Setting tenant ID context variable for tenant {tenant_id}"
)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
slack_bot_tokens = fetch_tokens()
logger.debug(f"Fetched Slack bot tokens for tenant {tenant_id}")
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
logger.debug(
f"Reset tenant ID context variable for tenant {tenant_id}"
)
if not slack_bot_tokens:
logger.debug(f"No Slack bot token found for tenant {tenant_id}")
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
continue
if (
tenant_id not in self.slack_bot_tokens
or slack_bot_tokens != self.slack_bot_tokens[tenant_id]
):
if tenant_id in self.slack_bot_tokens:
logger.info(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
self.start_socket_client(tenant_id, slack_bot_tokens)
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
def send_heartbeats(self) -> None:
current_time = int(time.time())
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants")
for tenant_id in self.tenant_ids:
redis_client = get_redis_client(tenant_id=tenant_id)
heartbeat_key = (
f"{DanswerRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}"
)
redis_client.set(
heartbeat_key, current_time, ex=TENANT_HEARTBEAT_EXPIRATION
)
def start_socket_client(
self, tenant_id: str | None, slack_bot_tokens: SlackBotTokens
) -> None:
logger.info(f"Starting socket client for tenant {tenant_id}")
socket_client = _get_socket_client(slack_bot_tokens, tenant_id)
# Append the event handler
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.info(f"Connecting socket client for tenant {tenant_id}")
socket_client.connect()
self.socket_clients[tenant_id] = socket_client
self.tenant_ids.add(tenant_id)
logger.info(f"Started SocketModeClient for tenant {tenant_id}")
def stop_socket_clients(self) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for tenant_id, client in self.socket_clients.items():
asyncio.run(client.close())
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
if not self.running:
return
logger.info("Shutting down gracefully")
self.running = False
self._shutdown_event.set()
# Stop all socket clients
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
self.stop_socket_clients()
# Wait for background threads to finish (with timeout)
logger.info("Waiting for background threads to finish...")
self.acquire_thread.join(timeout=5)
self.heartbeat_thread.join(timeout=5)
logger.info("Shutdown complete")
sys.exit(0)
def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
"""True to keep going, False to ignore this Slack request"""
if req.type == "events_api":
@@ -172,7 +389,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
message_subtype = event.get("subtype")
if message_subtype not in [None, "file_share"]:
channel_specific_logger.info(
f"Ignoring message with subtype '{message_subtype}' since is is a special message type"
f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
)
return False
@@ -247,7 +464,7 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
)
query_event_id, _, _ = decompose_action_id(feedback_id)
logger.notice(f"Successfully handled QA feedback for event: {query_event_id}")
logger.info(f"Successfully handled QA feedback for event: {query_event_id}")
def build_request_details(
@@ -269,14 +486,14 @@ def build_request_details(
msg = remove_danswer_bot_tag(msg, client=client.web_client)
if DANSWER_BOT_REPHRASE_MESSAGE:
logger.notice(f"Rephrasing Slack message. Original message: {msg}")
logger.info(f"Rephrasing Slack message. Original message: {msg}")
try:
msg = rephrase_slack_message(msg)
logger.notice(f"Rephrased message: {msg}")
logger.info(f"Rephrased message: {msg}")
except Exception as e:
logger.error(f"Error while trying to rephrase the Slack message: {e}")
else:
logger.notice(f"Received Slack message: {msg}")
logger.info(f"Received Slack message: {msg}")
if tagged:
logger.debug("User tagged DanswerBot")
@@ -477,94 +694,21 @@ def _get_socket_client(
)
def _initialize_socket_client(socket_client: TenantSocketModeClient) -> None:
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.notice(f"Listening for messages from Slack {socket_client.tenant_id }...")
socket_client.connect()
# Follow the guide (https://docs.danswer.dev/slack_bot_setup) to set up
# the slack bot in your workspace, and then add the bot to any channels you want to
# try and answer questions for. Running this file will setup Danswer to listen to all
# messages in those channels and attempt to answer them. As of now, it will only respond
# to messages sent directly in the channel - it will not respond to messages sent within a
# thread.
#
# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
# without issue.
if __name__ == "__main__":
slack_bot_tokens: dict[str | None, SlackBotTokens] = {}
socket_clients: dict[str | None, TenantSocketModeClient] = {}
# Initialize the tenant handler which will manage tenant connections
logger.info("Starting SlackbotHandler")
tenant_handler = SlackbotHandler()
set_is_ee_based_on_env_variable()
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data()
while True:
try:
tenant_ids = get_all_tenant_ids() # Function to retrieve all tenant IDs
try:
# Keep the main thread alive
while tenant_handler.running:
time.sleep(1)
for tenant_id in tenant_ids:
with get_session_with_tenant(tenant_id) as db_session:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
latest_slack_bot_tokens = fetch_tokens()
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
if (
tenant_id not in slack_bot_tokens
or latest_slack_bot_tokens != slack_bot_tokens[tenant_id]
):
if tenant_id in slack_bot_tokens:
logger.notice(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
# Initial setup for this tenant
search_settings = get_current_search_settings(
db_session
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
slack_bot_tokens[tenant_id] = latest_slack_bot_tokens
# potentially may cause a message to be dropped, but it is complicated
# to avoid + (1) if the user is changing tokens, they are likely okay with some
# "migration downtime" and (2) if a single message is lost it is okay
# as this should be a very rare occurrence
if tenant_id in socket_clients:
socket_clients[tenant_id].close()
socket_client = _get_socket_client(
latest_slack_bot_tokens, tenant_id
)
# Initialize socket client for this tenant. Each tenant has its own
# socket client, allowing for multiple concurrent connections (one
# per tenant) with the tenant ID wrapped in the socket model client.
# Each `connect` stores websocket connection in a separate thread.
_initialize_socket_client(socket_client)
socket_clients[tenant_id] = socket_client
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if tenant_id in socket_clients:
socket_clients[tenant_id].disconnect()
del socket_clients[tenant_id]
del slack_bot_tokens[tenant_id]
# Wait before checking for updates
Event().wait(timeout=60)
except Exception:
logger.exception("An error occurred outside of main event loop")
time.sleep(60)
except Exception:
logger.exception("Fatal error in main thread")
tenant_handler.shutdown(None, None)

View File

@@ -331,11 +331,13 @@ def get_session_with_tenant(
Generate a database session bound to a connection with the appropriate tenant schema set.
This preserves the tenant ID across the session and reverts to the previous tenant ID
after the session is closed.
If tenant ID is set, we save the previous tenant ID from the context var to set
after the session is closed. The value `None` evaluates to the default schema.
"""
engine = get_sqlalchemy_engine()
# Store the previous tenant ID
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
if tenant_id is None:
tenant_id = previous_tenant_id
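The save/restore behavior the docstring describes boils down to a context-variable pattern; a simplified, illustrative version (the real helper also binds the SQLAlchemy session to the tenant schema, which is omitted here):

from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator

CURRENT_TENANT_ID: ContextVar[str] = ContextVar("current_tenant_id", default="public")

@contextmanager
def session_with_tenant(tenant_id: str | None) -> Iterator[None]:
    # Fall back to the previous (or default) schema when no tenant is given.
    previous = CURRENT_TENANT_ID.get() or "public"
    CURRENT_TENANT_ID.set(tenant_id if tenant_id is not None else previous)
    try:
        yield  # a real implementation yields a Session bound to the tenant schema
    finally:
        # Revert so callers outside the block see the tenant they started with.
        CURRENT_TENANT_ID.set(previous)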

View File

@@ -66,7 +66,7 @@ def get_all_empty_chat_message_entries(
return
yield message_skeletons
initial_id = message_skeletons[-1].message_id
initial_id = message_skeletons[-1].chat_session_id
def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]:

View File

@@ -81,3 +81,4 @@ stripe==10.12.0
urllib3==2.2.3
mistune==0.8.4
sentry-sdk==2.14.0
prometheus_client==0.21.0

View File

@@ -0,0 +1,80 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: slack-bot-deployment
labels:
app: slack-bot
spec:
replicas: 1
selector:
matchLabels:
app: slack-bot
template:
metadata:
labels:
app: slack-bot
spec:
containers:
- name: slack-bot
image: danswer/danswer-backend:latest
imagePullPolicy: IfNotPresent
command: ["python", "danswer/danswerbot/slack/listener.py"]
ports:
- containerPort: 8000
resources:
requests:
cpu: "100m"
memory: "200Mi"
limits:
cpu: "500m"
memory: "500Mi"
readinessProbe:
httpGet:
path: /metrics
port: 8000
initialDelaySeconds: 10
periodSeconds: 10
livenessProbe:
httpGet:
path: /metrics
port: 8000
initialDelaySeconds: 15
periodSeconds: 20
---
apiVersion: v1
kind: Service
metadata:
name: slack-bot-service
labels:
app: slack-bot
spec:
selector:
app: slack-bot
ports:
# Port exposed for Prometheus metrics
- protocol: TCP
port: 8000
targetPort: 8000
type: ClusterIP
---
apiVersion: autoscaling/v2beta2
kind: HorizontalPodAutoscaler
metadata:
name: slack-bot-hpa
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: slack-bot-deployment
minReplicas: 1
maxReplicas: 10
metrics:
- type: Pods
pods:
metric:
name: active_tenants
target:
type: AverageValue
averageValue: "40"
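For context, this HPA uses a Pods-type custom metric, which only works if active_tenants is exposed through the custom metrics API (e.g. via prometheus-adapter, which this manifest does not install). The standard HPA calculation on that metric looks roughly like this worked example:

import math

# Standard HPA formula: desired = ceil(current_replicas * metric_average / target)
current_replicas = 2
active_tenants_per_pod = [50, 46]  # gauge values scraped from each pod
target_average = 40                # averageValue from the HPA spec above

metric_average = sum(active_tenants_per_pod) / len(active_tenants_per_pod)  # 48.0
desired = math.ceil(current_replicas * metric_average / target_average)     # ceil(2.4) = 3
print(desired)  # scale from 2 to 3 replicas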

View File

@@ -5,6 +5,7 @@ import { generateRandomIconShape, createSVG } from "@/lib/assistantIconUtils";
import { CCPairBasicInfo, DocumentSet, User } from "@/lib/types";
import { Separator } from "@/components/ui/separator";
import { Button } from "@/components/ui/button";
import { Textarea } from "@/components/ui/textarea";
import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
import {
ArrayHelpers,
@@ -1102,7 +1103,9 @@ export function AssistantEditor({
w-full
py-2
px-3
min-h-12
mr-4
line-clamp-
`}
as="textarea"
autoComplete="off"