From 3eef4e3992cbf38270b5bb7e72b73faec36cdb4f Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Wed, 30 Oct 2024 21:58:24 -0700 Subject: [PATCH] functioning --- backend/danswer/danswerbot/slack/acquirer.py | 199 ++++++++++++++ backend/danswer/danswerbot/slack/listener.py | 275 +++++++++++++------ 2 files changed, 384 insertions(+), 90 deletions(-) create mode 100644 backend/danswer/danswerbot/slack/acquirer.py diff --git a/backend/danswer/danswerbot/slack/acquirer.py b/backend/danswer/danswerbot/slack/acquirer.py new file mode 100644 index 000000000000..9aa64a5d3e58 --- /dev/null +++ b/backend/danswer/danswerbot/slack/acquirer.py @@ -0,0 +1,199 @@ +import asyncio +import signal +import sys +import threading +import time +from threading import Event +from typing import Dict +from typing import Set + +from prometheus_client import Gauge +from prometheus_client import start_http_server +from slack_sdk.socket_mode.aiohttp import SocketModeClient +from slack_sdk.web.async_client import AsyncWebClient + +from danswer.danswerbot.slack.listener import _get_socket_client +from danswer.danswerbot.slack.tokens import fetch_tokens +from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR +from danswer.db.engine import get_all_tenant_ids +from danswer.db.engine import get_session_with_tenant +from danswer.db.search_settings import get_current_search_settings +from danswer.key_value_store.interface import KvKeyNotFoundError +from danswer.natural_language_processing.search_nlp_models import EmbeddingModel +from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder +from danswer.redis.redis_pool import get_redis_client +from danswer.utils.logger import setup_logger +from shared_configs.configs import MODEL_SERVER_HOST +from shared_configs.configs import MODEL_SERVER_PORT + +logger = setup_logger() + +# Prometheus metric for HPA +active_tenants_gauge = Gauge( + "active_tenants", "Number of active tenants handled by this pod" +) + +# Configuration 
# Configuration constants
TENANT_LOCK_EXPIRATION = 300  # seconds
TENANT_HEARTBEAT_INTERVAL = 60  # seconds
TENANT_HEARTBEAT_EXPIRATION = 180  # seconds
TENANT_ACQUISITION_INTERVAL = 60  # seconds


class TenantHandler:
    """Keep one Slack Socket Mode connection open per tenant that has bot tokens.

    Constructing the handler registers SIGTERM/SIGINT handlers, starts the
    Prometheus metrics server on port 8000 and launches two daemon threads:

    * the acquisition loop re-reads every tenant's Slack bot tokens and
      connects, reconnects or disconnects that tenant's socket client;
    * the heartbeat loop writes one expiring Redis key per connected tenant.

    Invariant maintained by this class: a tenant has an entry in
    ``slack_bot_tokens`` only while it also has a connected client in
    ``socket_clients`` and is a member of ``tenant_ids``.

    NOTE(review): ``TENANT_LOCK_EXPIRATION`` is unused and no lock is taken
    anywhere in this class, so every pod connects to every tenant.
    NOTE(review): this module defines an ``active_tenants`` gauge and also
    imports from ``listener``, which defines a gauge with the same name;
    prometheus_client presumably rejects the duplicate registration at
    import time - confirm.
    """

    def __init__(self) -> None:
        logger.info("Initializing TenantHandler")
        self.redis_client = get_redis_client(tenant_id=None)
        # Tenants that currently have a connected socket client
        self.tenant_ids: Set[str] = set()
        # tenant id -> client returned by _get_socket_client
        self.socket_clients: Dict[str, object] = {}
        # tenant id -> token object (as returned by fetch_tokens) the client
        # was connected with; the whole object is kept so any token change
        # triggers a reconnect
        self.slack_bot_tokens: Dict[str, object] = {}
        self.running = True
        # Set by shutdown() so the background loops wake up immediately
        # instead of sleeping out their full interval
        self._stop_event = Event()
        self.pod_id = self.get_pod_id()
        logger.info(f"Pod ID: {self.pod_id}")

        # Set up signal handlers for graceful shutdown
        signal.signal(signal.SIGTERM, self.shutdown)
        signal.signal(signal.SIGINT, self.shutdown)
        logger.info("Signal handlers registered")

        # Start the Prometheus metrics server
        logger.info("Starting Prometheus metrics server")
        start_http_server(8000)
        logger.info("Prometheus metrics server started")

        # Start background threads
        logger.info("Starting background threads")
        threading.Thread(target=self.acquire_tenants_loop, daemon=True).start()
        threading.Thread(target=self.heartbeat_loop, daemon=True).start()
        logger.info("Background threads started")

    def get_pod_id(self) -> str:
        """Return the ``HOSTNAME`` environment variable, or ``"unknown_pod"``."""
        import os

        pod_id = os.environ.get("HOSTNAME", "unknown_pod")
        logger.info(f"Retrieved pod ID: {pod_id}")
        return pod_id

    def acquire_tenants_loop(self) -> None:
        """Run ``acquire_tenants`` and refresh the gauge until shutdown."""
        logger.info("Starting tenant acquisition loop")
        while self.running:
            try:
                self.acquire_tenants()
                active_tenants_gauge.set(len(self.tenant_ids))
                logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
            except Exception as e:
                logger.exception(f"Error in tenant acquisition: {e}")
            self._stop_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)

    def heartbeat_loop(self) -> None:
        """Run ``send_heartbeats`` every heartbeat interval until shutdown."""
        logger.info("Starting heartbeat loop")
        while self.running:
            try:
                self.send_heartbeats()
                logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants")
            except Exception as e:
                logger.exception(f"Error in heartbeat loop: {e}")
            self._stop_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)

    def acquire_tenants(self) -> None:
        """Bring every tenant's socket client in line with its stored tokens.

        A failure for one tenant is logged and does not stop the pass over
        the remaining tenants.
        """
        tenant_ids = get_all_tenant_ids()
        logger.debug(f"Found {len(tenant_ids)} total tenants in Postgres")

        for tenant_id in tenant_ids:
            with get_session_with_tenant(tenant_id) as db_session:
                try:
                    try:
                        self._sync_tenant(tenant_id, db_session)
                    except KvKeyNotFoundError:
                        logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
                        self._drop_tenant(tenant_id)
                except Exception as e:
                    logger.exception(f"Error handling tenant {tenant_id}: {e}")

    def _sync_tenant(self, tenant_id, db_session) -> None:
        """Connect, reconnect or disconnect a single tenant."""
        context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
        try:
            latest_slack_bot_token = fetch_tokens()
        finally:
            # Always restore the context var, also when fetch_tokens raises
            CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)

        if not latest_slack_bot_token:
            logger.debug(f"No Slack bot token found for tenant {tenant_id}")
            self._drop_tenant(tenant_id)
            return

        if (
            tenant_id in self.slack_bot_tokens
            and latest_slack_bot_token == self.slack_bot_tokens[tenant_id]
        ):
            # Already connected with these tokens
            return

        if tenant_id in self.slack_bot_tokens:
            logger.notice(
                f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
            )
        else:
            # Initial setup for this tenant
            search_settings = get_current_search_settings(db_session)
            embedding_model = EmbeddingModel.from_db_model(
                search_settings=search_settings,
                server_host=MODEL_SERVER_HOST,
                server_port=MODEL_SERVER_PORT,
            )
            warm_up_bi_encoder(embedding_model=embedding_model)

        self._close_client(tenant_id)
        try:
            asyncio.run(self.start_socket_client(tenant_id, latest_slack_bot_token))
        except Exception:
            # Forget the tokens so the next pass retries the connection
            self.slack_bot_tokens.pop(tenant_id, None)
            raise
        # Only remember the tokens once the connection is actually up
        self.slack_bot_tokens[tenant_id] = latest_slack_bot_token

    def _close_client(self, tenant_id: str) -> None:
        """Close the tenant's client, if any, and stop counting it as active."""
        self.tenant_ids.discard(tenant_id)
        client = self.socket_clients.pop(tenant_id, None)
        if client is not None:
            # The client is driven synchronously (see start_socket_client);
            # close() is a plain call, not a coroutine for asyncio.run
            client.close()

    def _drop_tenant(self, tenant_id: str) -> None:
        """Close the tenant's client and forget its tokens."""
        try:
            self._close_client(tenant_id)
        finally:
            self.slack_bot_tokens.pop(tenant_id, None)

    def send_heartbeats(self) -> None:
        """Write ``tenant_heartbeat:<tenant>:<pod>`` to Redis for each active tenant."""
        current_time = int(time.time())
        logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants")
        # Iterate a snapshot: the acquisition thread mutates the set
        for tenant_id in tuple(self.tenant_ids):
            heartbeat_key = f"tenant_heartbeat:{tenant_id}:{self.pod_id}"
            self.redis_client.set(
                heartbeat_key, current_time, ex=TENANT_HEARTBEAT_EXPIRATION
            )

    async def start_socket_client(self, tenant_id: str, slack_bot_tokens=None) -> None:
        """Create, connect and register the socket client for ``tenant_id``.

        Args:
            tenant_id: tenant the client belongs to.
            slack_bot_tokens: token object to connect with; when omitted the
                entry stored in ``self.slack_bot_tokens`` is used.

        Kept as a coroutine for backward compatibility; the client built by
        ``_get_socket_client`` is used synchronously, as in ``listener``.
        """
        logger.info(f"Starting socket client for tenant {tenant_id}")
        if slack_bot_tokens is None:
            slack_bot_tokens = self.slack_bot_tokens[tenant_id]
        socket_client = _get_socket_client(slack_bot_tokens, tenant_id)

        # NOTE(review): this listener only logs; events are neither processed
        # nor acknowledged here.
        def handle_events(client, req) -> None:
            logger.debug(f"Received event for tenant {tenant_id}")

        socket_client.socket_mode_request_listeners.append(handle_events)

        logger.info(f"Connecting socket client for tenant {tenant_id}")
        socket_client.connect()
        self.socket_clients[tenant_id] = socket_client
        self.tenant_ids.add(tenant_id)
        logger.info(f"Started SocketModeClient for tenant {tenant_id}")

    def stop_socket_clients(self) -> None:
        """Close every registered socket client; one failure does not stop the rest."""
        logger.info(f"Stopping {len(self.socket_clients)} socket clients")
        # Snapshot: the acquisition thread may still be mutating the dict
        for tenant_id, client in list(self.socket_clients.items()):
            try:
                client.close()
            except Exception as e:
                logger.exception(
                    f"Error stopping SocketModeClient for tenant {tenant_id}: {e}"
                )
                continue
            logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")

    def shutdown(self, signum, frame) -> None:
        """Signal handler: stop the loops, close all clients and exit with status 0."""
        logger.info("Shutting down gracefully")
        self.running = False
        self._stop_event.set()
        self.stop_socket_clients()
        logger.info("Shutdown complete")
        sys.exit(0)


if __name__ == "__main__":
    logger.info("Starting TenantHandler")
    handler = TenantHandler()
    # Keep the main thread alive
    while handler.running:
        time.sleep(1)
+from typing import Dict +from typing import Set +from prometheus_client import Gauge +from prometheus_client import start_http_server from slack_sdk import WebClient from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.response import SocketModeResponse @@ -46,6 +54,7 @@ from danswer.danswerbot.slack.utils import remove_danswer_bot_tag from danswer.danswerbot.slack.utils import rephrase_slack_message from danswer.danswerbot.slack.utils import respond_in_thread from danswer.danswerbot.slack.utils import TenantSocketModeClient +from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR from danswer.db.engine import get_all_tenant_ids from danswer.db.engine import get_session_with_tenant from danswer.db.search_settings import get_current_search_settings @@ -53,6 +62,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError from danswer.natural_language_processing.search_nlp_models import EmbeddingModel from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder from danswer.one_shot_answer.models import ThreadMessage +from danswer.redis.redis_pool import get_redis_client from danswer.search.retrieval.search_runner import download_nltk_data from danswer.server.manage.models import SlackBotTokens from danswer.utils.logger import setup_logger @@ -64,6 +74,17 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() +# Prometheus metric for HPA +active_tenants_gauge = Gauge( + "active_tenants", "Number of active tenants handled by this pod" +) + +# Configuration constants +TENANT_LOCK_EXPIRATION = 300 # seconds +TENANT_HEARTBEAT_INTERVAL = 60 # seconds +TENANT_HEARTBEAT_EXPIRATION = 180 # seconds +TENANT_ACQUISITION_INTERVAL = 60 # seconds + # In rare cases, some users have been experiencing a massive amount of trivial messages coming through # to the Slack Bot with trivial messages. Adding this to avoid exploding LLM costs while we track down # the cause. 
# This is always (currently) the user id of Slack's official slackbot
_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"


class TenantHandler:
    """Keep one Slack Socket Mode connection open per tenant that has bot tokens.

    Constructing the handler registers SIGTERM/SIGINT handlers, starts the
    Prometheus metrics server on port 8000 and launches two daemon threads:

    * the acquisition loop re-reads every tenant's Slack bot tokens and
      connects, reconnects or disconnects that tenant's socket client;
    * the heartbeat loop writes one expiring Redis key per connected tenant.

    Invariant maintained by this class: a tenant has an entry in
    ``slack_bot_tokens`` only while it also has a connected client in
    ``socket_clients`` and is a member of ``tenant_ids``.

    NOTE(review): ``TENANT_LOCK_EXPIRATION`` is unused and no lock is taken
    anywhere in this class, so every pod connects to every tenant.
    """

    def __init__(self) -> None:
        logger.info("Initializing TenantHandler")
        self.redis_client = get_redis_client(tenant_id=None)
        # Tenants that currently have a connected socket client
        self.tenant_ids: Set[str] = set()
        # tenant id -> connected socket client
        self.socket_clients: Dict[str, TenantSocketModeClient] = {}
        # tenant id -> tokens the client was connected with
        self.slack_bot_tokens: Dict[str, SlackBotTokens] = {}
        self.running = True
        # Set by shutdown() so the background loops wake up immediately
        # instead of sleeping out their full interval
        self._stop_event = Event()
        self.pod_id = self.get_pod_id()
        logger.info(f"Pod ID: {self.pod_id}")

        # Set up signal handlers for graceful shutdown
        signal.signal(signal.SIGTERM, self.shutdown)
        signal.signal(signal.SIGINT, self.shutdown)
        logger.info("Signal handlers registered")

        # Start the Prometheus metrics server
        logger.info("Starting Prometheus metrics server")
        start_http_server(8000)
        logger.info("Prometheus metrics server started")

        # Start background threads
        logger.info("Starting background threads")
        threading.Thread(target=self.acquire_tenants_loop, daemon=True).start()
        threading.Thread(target=self.heartbeat_loop, daemon=True).start()
        logger.info("Background threads started")

    def get_pod_id(self) -> str:
        """Return the ``HOSTNAME`` environment variable, or ``"unknown_pod"``."""
        import os

        pod_id = os.environ.get("HOSTNAME", "unknown_pod")
        logger.info(f"Retrieved pod ID: {pod_id}")
        return pod_id

    def acquire_tenants_loop(self) -> None:
        """Run ``acquire_tenants`` and refresh the gauge until shutdown."""
        logger.info("Starting tenant acquisition loop")
        while self.running:
            try:
                self.acquire_tenants()
                active_tenants_gauge.set(len(self.tenant_ids))
                logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
            except Exception as e:
                logger.exception(f"Error in tenant acquisition: {e}")
            self._stop_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)

    def heartbeat_loop(self) -> None:
        """Run ``send_heartbeats`` every heartbeat interval until shutdown."""
        logger.info("Starting heartbeat loop")
        while self.running:
            try:
                self.send_heartbeats()
                logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants")
            except Exception as e:
                logger.exception(f"Error in heartbeat loop: {e}")
            self._stop_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)

    def acquire_tenants(self) -> None:
        """Bring every tenant's socket client in line with its stored tokens.

        A failure for one tenant is logged and does not stop the pass over
        the remaining tenants.
        """
        tenant_ids = get_all_tenant_ids()
        logger.debug(f"Found {len(tenant_ids)} total tenants in Postgres")

        for tenant_id in tenant_ids:
            with get_session_with_tenant(tenant_id) as db_session:
                try:
                    try:
                        self._sync_tenant(tenant_id, db_session)
                    except KvKeyNotFoundError:
                        logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
                        self._drop_tenant(tenant_id)
                except Exception as e:
                    logger.exception(f"Error handling tenant {tenant_id}: {e}")

    def _sync_tenant(self, tenant_id, db_session) -> None:
        """Connect, reconnect or disconnect a single tenant."""
        context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
        try:
            slack_bot_tokens = fetch_tokens()
        finally:
            # Always restore the context var, also when fetch_tokens raises
            CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)

        if not slack_bot_tokens:
            logger.debug(f"No Slack bot token found for tenant {tenant_id}")
            self._drop_tenant(tenant_id)
            return

        if (
            tenant_id in self.slack_bot_tokens
            and slack_bot_tokens == self.slack_bot_tokens[tenant_id]
        ):
            # Already connected with these tokens
            return

        if tenant_id in self.slack_bot_tokens:
            logger.info(
                f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
            )
        else:
            # Initial setup for this tenant
            search_settings = get_current_search_settings(db_session)
            embedding_model = EmbeddingModel.from_db_model(
                search_settings=search_settings,
                server_host=MODEL_SERVER_HOST,
                server_port=MODEL_SERVER_PORT,
            )
            warm_up_bi_encoder(embedding_model=embedding_model)

        # Close the existing socket client (if any) before replacing it
        self._close_client(tenant_id)
        try:
            # Start a new socket client for the tenant
            self.start_socket_client(tenant_id, slack_bot_tokens)
        except Exception:
            # Forget the tokens so the next pass retries the connection
            self.slack_bot_tokens.pop(tenant_id, None)
            raise
        # Only remember the tokens once the connection is actually up
        self.slack_bot_tokens[tenant_id] = slack_bot_tokens

    def _close_client(self, tenant_id: str) -> None:
        """Close the tenant's client, if any, and stop counting it as active."""
        self.tenant_ids.discard(tenant_id)
        client = self.socket_clients.pop(tenant_id, None)
        if client is not None:
            # close() is a plain synchronous call on this client, like
            # connect() in start_socket_client; it must not be wrapped in
            # asyncio.run, which only accepts coroutines
            client.close()

    def _drop_tenant(self, tenant_id: str) -> None:
        """Close the tenant's client and forget its tokens."""
        try:
            self._close_client(tenant_id)
        finally:
            self.slack_bot_tokens.pop(tenant_id, None)

    def send_heartbeats(self) -> None:
        """Write ``tenant_heartbeat:<tenant>:<pod>`` to Redis for each active tenant."""
        current_time = int(time.time())
        logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants")
        # Iterate a snapshot: the acquisition thread mutates the set
        for tenant_id in tuple(self.tenant_ids):
            heartbeat_key = f"tenant_heartbeat:{tenant_id}:{self.pod_id}"
            self.redis_client.set(
                heartbeat_key, current_time, ex=TENANT_HEARTBEAT_EXPIRATION
            )

    def start_socket_client(
        self, tenant_id: str, slack_bot_tokens: "SlackBotTokens"
    ) -> None:
        """Create, connect and register the socket client for ``tenant_id``.

        The client is registered in ``socket_clients`` / ``tenant_ids`` only
        after ``connect()`` returned without raising.
        """
        logger.info(f"Starting socket client for tenant {tenant_id}")
        socket_client = _get_socket_client(slack_bot_tokens, tenant_id)

        # Append the event handler
        socket_client.socket_mode_request_listeners.append(process_slack_event)

        # Establish a WebSocket connection to the Socket Mode servers
        logger.info(f"Connecting socket client for tenant {tenant_id}")
        socket_client.connect()
        self.socket_clients[tenant_id] = socket_client
        self.tenant_ids.add(tenant_id)
        logger.info(f"Started SocketModeClient for tenant {tenant_id}")

    def stop_socket_clients(self) -> None:
        """Close every registered socket client; one failure does not stop the rest."""
        logger.info(f"Stopping {len(self.socket_clients)} socket clients")
        # Snapshot: the acquisition thread may still be mutating the dict
        for tenant_id, client in list(self.socket_clients.items()):
            try:
                client.close()
            except Exception as e:
                logger.exception(
                    f"Error stopping SocketModeClient for tenant {tenant_id}: {e}"
                )
                continue
            logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")

    def shutdown(self, signum, frame) -> None:
        """Signal handler: stop the loops, close all clients and exit with status 0.

        Also called directly as ``shutdown(None, None)`` by the entry point.
        """
        logger.info("Shutting down gracefully")
        self.running = False
        self._stop_event.set()
        self.stop_socket_clients()
        logger.info("Shutdown complete")
        sys.exit(0)
+415,7 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) -> ) query_event_id, _, _ = decompose_action_id(feedback_id) - logger.notice(f"Successfully handled QA feedback for event: {query_event_id}") + logger.info(f"Successfully handled QA feedback for event: {query_event_id}") def build_request_details( @@ -269,14 +437,14 @@ def build_request_details( msg = remove_danswer_bot_tag(msg, client=client.web_client) if DANSWER_BOT_REPHRASE_MESSAGE: - logger.notice(f"Rephrasing Slack message. Original message: {msg}") + logger.info(f"Rephrasing Slack message. Original message: {msg}") try: msg = rephrase_slack_message(msg) - logger.notice(f"Rephrased message: {msg}") + logger.info(f"Rephrased message: {msg}") except Exception as e: logger.error(f"Error while trying to rephrase the Slack message: {e}") else: - logger.notice(f"Received Slack message: {msg}") + logger.info(f"Received Slack message: {msg}") if tagged: logger.debug("User tagged DanswerBot") @@ -477,94 +645,21 @@ def _get_socket_client( ) -def _initialize_socket_client(socket_client: TenantSocketModeClient) -> None: - socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore - - # Establish a WebSocket connection to the Socket Mode servers - logger.notice(f"Listening for messages from Slack {socket_client.tenant_id }...") - socket_client.connect() - - -# Follow the guide (https://docs.danswer.dev/slack_bot_setup) to set up -# the slack bot in your workspace, and then add the bot to any channels you want to -# try and answer questions for. Running this file will setup Danswer to listen to all -# messages in those channels and attempt to answer them. As of now, it will only respond -# to messages sent directly in the channel - it will not respond to messages sent within a -# thread. -# -# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC -# without issue. 
if __name__ == "__main__":
    # Process-wide setup must finish before TenantHandler is constructed:
    # its constructor immediately starts background threads that fetch
    # tokens and connect Slack clients, so events could otherwise arrive
    # before this setup has run.
    set_is_ee_based_on_env_variable()

    logger.info("Verifying query preprocessing (NLTK) data is downloaded")
    download_nltk_data()

    # Initialize the tenant handler which will manage tenant connections
    logger.info("Starting TenantHandler")
    tenant_handler = TenantHandler()

    try:
        # Keep the main thread alive; the handler's signal handlers end the
        # process on SIGTERM/SIGINT
        while tenant_handler.running:
            time.sleep(1)

    except Exception:
        logger.exception("Fatal error in main thread")
        tenant_handler.shutdown(None, None)