functioning

This commit is contained in:
pablodanswer
2024-10-30 21:58:24 -07:00
parent 467ce4e3f3
commit 3eef4e3992
2 changed files with 384 additions and 90 deletions

View File

@@ -0,0 +1,199 @@
import asyncio
import signal
import sys
import threading
import time
from threading import Event
from typing import Dict
from typing import Set
from prometheus_client import Gauge
from prometheus_client import start_http_server
from slack_sdk.socket_mode.aiohttp import SocketModeClient
from slack_sdk.web.async_client import AsyncWebClient
from danswer.danswerbot.slack.listener import _get_socket_client
from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import get_session_with_tenant
from danswer.db.search_settings import get_current_search_settings
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
# Prometheus metric for HPA
active_tenants_gauge = Gauge(
"active_tenants", "Number of active tenants handled by this pod"
)
# Configuration constants
TENANT_LOCK_EXPIRATION = 300 # seconds
TENANT_HEARTBEAT_INTERVAL = 60 # seconds
TENANT_HEARTBEAT_EXPIRATION = 180 # seconds
TENANT_ACQUISITION_INTERVAL = 60 # seconds
class TenantHandler:
def __init__(self):
logger.info("Initializing TenantHandler")
self.redis_client = get_redis_client(tenant_id=None)
self.tenant_ids: Set[str] = set()
self.socket_clients: Dict[str, SocketModeClient] = {}
self.slack_bot_tokens: Dict[str, str] = {}
self.running = True
self.pod_id = self.get_pod_id()
logger.info(f"Pod ID: {self.pod_id}")
# Set up signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, self.shutdown)
signal.signal(signal.SIGINT, self.shutdown)
logger.info("Signal handlers registered")
# Start the Prometheus metrics server
logger.info("Starting Prometheus metrics server")
start_http_server(8000)
logger.info("Prometheus metrics server started")
# Start background threads
logger.info("Starting background threads")
threading.Thread(target=self.acquire_tenants_loop, daemon=True).start()
threading.Thread(target=self.heartbeat_loop, daemon=True).start()
logger.info("Background threads started")
def get_pod_id(self) -> str:
import os
pod_id = os.environ.get("HOSTNAME", "unknown_pod")
logger.info(f"Retrieved pod ID: {pod_id}")
return pod_id
def acquire_tenants_loop(self):
logger.info("Starting tenant acquisition loop")
while self.running:
try:
self.acquire_tenants()
active_tenants_gauge.set(len(self.tenant_ids))
logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
except Exception as e:
logger.exception(f"Error in tenant acquisition: {e}")
Event().wait(timeout=TENANT_ACQUISITION_INTERVAL)
def heartbeat_loop(self):
logger.info("Starting heartbeat loop")
while self.running:
try:
self.send_heartbeats()
logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants")
except Exception as e:
logger.exception(f"Error in heartbeat loop: {e}")
Event().wait(timeout=TENANT_HEARTBEAT_INTERVAL)
def acquire_tenants(self):
tenant_ids = get_all_tenant_ids()
logger.debug(f"Found {len(tenant_ids)} total tenants in Postgres")
for tenant_id in tenant_ids:
with get_session_with_tenant(tenant_id) as db_session:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
latest_slack_bot_token = fetch_tokens()
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
if not latest_slack_bot_token:
logger.debug(f"No Slack bot token found for tenant {tenant_id}")
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
continue
slack_bot_token = latest_slack_bot_token.bot_token
if (
tenant_id not in self.slack_bot_tokens
or slack_bot_token != self.slack_bot_tokens[tenant_id]
):
if tenant_id in self.slack_bot_tokens:
logger.notice(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
# Initial setup for this tenant
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
self.slack_bot_tokens[tenant_id] = slack_bot_token
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
asyncio.run(self.start_socket_client(tenant_id))
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
def send_heartbeats(self):
current_time = int(time.time())
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants")
for tenant_id in self.tenant_ids:
heartbeat_key = f"tenant_heartbeat:{tenant_id}:{self.pod_id}"
self.redis_client.set(
heartbeat_key, current_time, ex=TENANT_HEARTBEAT_EXPIRATION
)
async def start_socket_client(self, tenant_id: str):
logger.info(f"Starting socket client for tenant {tenant_id}")
app_token = self.slack_bot_tokens[tenant_id]
web_client = AsyncWebClient(token=app_token)
socket_client = SocketModeClient(app_token=app_token, web_client=web_client)
socket_client = _get_socket_client(app_token, tenant_id)
@socket_client.socket_mode_request_listeners.append
async def handle_events(event):
logger.debug(f"Received event for tenant {tenant_id}")
logger.info(f"Connecting socket client for tenant {tenant_id}")
await socket_client.connect()
self.socket_clients[tenant_id] = socket_client
self.tenant_ids.add(tenant_id)
logger.info(f"Started SocketModeClient for tenant {tenant_id}")
def stop_socket_clients(self):
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for tenant_id, client in self.socket_clients.items():
asyncio.run(client.close())
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
def shutdown(self, signum, frame):
logger.info("Shutting down gracefully")
self.running = False
self.stop_socket_clients()
logger.info("Shutdown complete")
sys.exit(0)
if __name__ == "__main__":
logger.info("Starting TenantHandler")
handler = TenantHandler()
# Keep the main thread alive
while handler.running:
time.sleep(1)

View File

@@ -1,8 +1,16 @@
import asyncio
import signal
import sys
import threading
import time
from threading import Event
from typing import Any
from typing import cast
from typing import Dict
from typing import Set
from prometheus_client import Gauge
from prometheus_client import start_http_server
from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
@@ -46,6 +54,7 @@ from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import TenantSocketModeClient
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import get_session_with_tenant
from danswer.db.search_settings import get_current_search_settings
@@ -53,6 +62,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.one_shot_answer.models import ThreadMessage
from danswer.redis.redis_pool import get_redis_client
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
@@ -64,6 +74,17 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
# Prometheus metric for HPA
active_tenants_gauge = Gauge(
"active_tenants", "Number of active tenants handled by this pod"
)
# Configuration constants
TENANT_LOCK_EXPIRATION = 300 # seconds
TENANT_HEARTBEAT_INTERVAL = 60 # seconds
TENANT_HEARTBEAT_EXPIRATION = 180 # seconds
TENANT_ACQUISITION_INTERVAL = 60 # seconds
# In rare cases, some users have been experiencing a massive amount of trivial messages coming through
# to the Slack Bot with trivial messages. Adding this to avoid exploding LLM costs while we track down
# the cause.
@@ -77,10 +98,157 @@ _SLACK_GREETINGS_TO_IGNORE = {
":wave:",
}
# this is always (currently) the user id of Slack's official slackbot
# This is always (currently) the user id of Slack's official slackbot
_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
class TenantHandler:
def __init__(self):
logger.info("Initializing TenantHandler")
self.redis_client = get_redis_client(tenant_id=None)
self.tenant_ids: Set[str] = set()
self.socket_clients: Dict[str, TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[str, SlackBotTokens] = {}
self.running = True
self.pod_id = self.get_pod_id()
logger.info(f"Pod ID: {self.pod_id}")
# Set up signal handlers for graceful shutdown
signal.signal(signal.SIGTERM, self.shutdown)
signal.signal(signal.SIGINT, self.shutdown)
logger.info("Signal handlers registered")
# Start the Prometheus metrics server
logger.info("Starting Prometheus metrics server")
start_http_server(8000)
logger.info("Prometheus metrics server started")
# Start background threads
logger.info("Starting background threads")
threading.Thread(target=self.acquire_tenants_loop, daemon=True).start()
threading.Thread(target=self.heartbeat_loop, daemon=True).start()
logger.info("Background threads started")
def get_pod_id(self) -> str:
import os
pod_id = os.environ.get("HOSTNAME", "unknown_pod")
logger.info(f"Retrieved pod ID: {pod_id}")
return pod_id
def acquire_tenants_loop(self):
logger.info("Starting tenant acquisition loop")
while self.running:
try:
self.acquire_tenants()
active_tenants_gauge.set(len(self.tenant_ids))
logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
except Exception as e:
logger.exception(f"Error in tenant acquisition: {e}")
Event().wait(timeout=TENANT_ACQUISITION_INTERVAL)
def heartbeat_loop(self):
logger.info("Starting heartbeat loop")
while self.running:
try:
self.send_heartbeats()
logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants")
except Exception as e:
logger.exception(f"Error in heartbeat loop: {e}")
Event().wait(timeout=TENANT_HEARTBEAT_INTERVAL)
def acquire_tenants(self):
tenant_ids = get_all_tenant_ids()
logger.debug(f"Found {len(tenant_ids)} total tenants in Postgres")
for tenant_id in tenant_ids:
with get_session_with_tenant(tenant_id) as db_session:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
slack_bot_tokens = fetch_tokens()
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
if not slack_bot_tokens:
logger.debug(f"No Slack bot token found for tenant {tenant_id}")
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
continue
if (
tenant_id not in self.slack_bot_tokens
or slack_bot_tokens != self.slack_bot_tokens[tenant_id]
):
if tenant_id in self.slack_bot_tokens:
logger.info(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
# Initial setup for this tenant
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
if tenant_id in self.socket_clients:
# Close the existing socket client
asyncio.run(self.socket_clients[tenant_id].close())
# Start a new socket client for the tenant
self.start_socket_client(tenant_id, slack_bot_tokens)
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
def send_heartbeats(self):
current_time = int(time.time())
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants")
for tenant_id in self.tenant_ids:
heartbeat_key = f"tenant_heartbeat:{tenant_id}:{self.pod_id}"
self.redis_client.set(
heartbeat_key, current_time, ex=TENANT_HEARTBEAT_EXPIRATION
)
def start_socket_client(self, tenant_id: str, slack_bot_tokens: SlackBotTokens):
logger.info(f"Starting socket client for tenant {tenant_id}")
socket_client = _get_socket_client(slack_bot_tokens, tenant_id)
# Append the event handler
socket_client.socket_mode_request_listeners.append(process_slack_event)
# Establish a WebSocket connection to the Socket Mode servers
logger.info(f"Connecting socket client for tenant {tenant_id}")
socket_client.connect()
self.socket_clients[tenant_id] = socket_client
self.tenant_ids.add(tenant_id)
logger.info(f"Started SocketModeClient for tenant {tenant_id}")
def stop_socket_clients(self):
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for tenant_id, client in self.socket_clients.items():
asyncio.run(client.close())
logger.info(f"Stopped SocketModeClient for tenant {tenant_id}")
def shutdown(self, signum, frame):
logger.info("Shutting down gracefully")
self.running = False
self.stop_socket_clients()
logger.info("Shutdown complete")
sys.exit(0)
def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
"""True to keep going, False to ignore this Slack request"""
if req.type == "events_api":
@@ -172,7 +340,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
message_subtype = event.get("subtype")
if message_subtype not in [None, "file_share"]:
channel_specific_logger.info(
f"Ignoring message with subtype '{message_subtype}' since is is a special message type"
f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
)
return False
@@ -247,7 +415,7 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
)
query_event_id, _, _ = decompose_action_id(feedback_id)
logger.notice(f"Successfully handled QA feedback for event: {query_event_id}")
logger.info(f"Successfully handled QA feedback for event: {query_event_id}")
def build_request_details(
@@ -269,14 +437,14 @@ def build_request_details(
msg = remove_danswer_bot_tag(msg, client=client.web_client)
if DANSWER_BOT_REPHRASE_MESSAGE:
logger.notice(f"Rephrasing Slack message. Original message: {msg}")
logger.info(f"Rephrasing Slack message. Original message: {msg}")
try:
msg = rephrase_slack_message(msg)
logger.notice(f"Rephrased message: {msg}")
logger.info(f"Rephrased message: {msg}")
except Exception as e:
logger.error(f"Error while trying to rephrase the Slack message: {e}")
else:
logger.notice(f"Received Slack message: {msg}")
logger.info(f"Received Slack message: {msg}")
if tagged:
logger.debug("User tagged DanswerBot")
@@ -477,94 +645,21 @@ def _get_socket_client(
)
def _initialize_socket_client(socket_client: TenantSocketModeClient) -> None:
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.notice(f"Listening for messages from Slack {socket_client.tenant_id }...")
socket_client.connect()
# Follow the guide (https://docs.danswer.dev/slack_bot_setup) to set up
# the slack bot in your workspace, and then add the bot to any channels you want to
# try and answer questions for. Running this file will setup Danswer to listen to all
# messages in those channels and attempt to answer them. As of now, it will only respond
# to messages sent directly in the channel - it will not respond to messages sent within a
# thread.
#
# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
# without issue.
if __name__ == "__main__":
slack_bot_tokens: dict[str | None, SlackBotTokens] = {}
socket_clients: dict[str | None, TenantSocketModeClient] = {}
# Initialize the tenant handler which will manage tenant connections
logger.info("Starting TenantHandler")
tenant_handler = TenantHandler()
set_is_ee_based_on_env_variable()
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data()
while True:
try:
tenant_ids = get_all_tenant_ids() # Function to retrieve all tenant IDs
try:
# Keep the main thread alive
while tenant_handler.running:
time.sleep(1)
for tenant_id in tenant_ids:
with get_session_with_tenant(tenant_id) as db_session:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
latest_slack_bot_tokens = fetch_tokens()
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
if (
tenant_id not in slack_bot_tokens
or latest_slack_bot_tokens != slack_bot_tokens[tenant_id]
):
if tenant_id in slack_bot_tokens:
logger.notice(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
# Initial setup for this tenant
search_settings = get_current_search_settings(
db_session
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
slack_bot_tokens[tenant_id] = latest_slack_bot_tokens
# potentially may cause a message to be dropped, but it is complicated
# to avoid + (1) if the user is changing tokens, they are likely okay with some
# "migration downtime" and (2) if a single message is lost it is okay
# as this should be a very rare occurrence
if tenant_id in socket_clients:
socket_clients[tenant_id].close()
socket_client = _get_socket_client(
latest_slack_bot_tokens, tenant_id
)
# Initialize socket client for this tenant. Each tenant has its own
# socket client, allowing for multiple concurrent connections (one
# per tenant) with the tenant ID wrapped in the socket model client.
# Each `connect` stores websocket connection in a separate thread.
_initialize_socket_client(socket_client)
socket_clients[tenant_id] = socket_client
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if tenant_id in socket_clients:
socket_clients[tenant_id].disconnect()
del socket_clients[tenant_id]
del slack_bot_tokens[tenant_id]
# Wait before checking for updates
Event().wait(timeout=60)
except Exception:
logger.exception("An error occurred outside of main event loop")
time.sleep(60)
except Exception:
logger.exception("Fatal error in main thread")
tenant_handler.shutdown(None, None)