Slackbot optimization (#3696)

* initial pass

* update

* nit

* nit

* bot -> app

* nit

* quick update

* various improvements

* k

* k

* nit
This commit is contained in:
pablonyx 2025-01-20 11:46:52 -08:00 committed by GitHub
parent fe3eae3680
commit cc4953b560
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -14,6 +14,7 @@ from typing import Set
from prometheus_client import Gauge from prometheus_client import Gauge
from prometheus_client import start_http_server from prometheus_client import start_http_server
from redis.lock import Lock
from slack_sdk import WebClient from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse from slack_sdk.socket_mode.response import SocketModeResponse
@ -122,6 +123,9 @@ class SlackbotHandler:
self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {} self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {} self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {}
# Store Redis lock objects here so we can release them properly
self.redis_locks: Dict[str | None, Lock] = {}
self.running = True self.running = True
self.pod_id = self.get_pod_id() self.pod_id = self.get_pod_id()
self._shutdown_event = Event() self._shutdown_event = Event()
@ -159,10 +163,15 @@ class SlackbotHandler:
while not self._shutdown_event.is_set(): while not self._shutdown_event.is_set():
try: try:
self.acquire_tenants() self.acquire_tenants()
# After we finish acquiring and managing Slack bots,
# set the gauge to the number of active tenants (those with Slack bots).
active_tenants_gauge.labels(namespace=POD_NAMESPACE, pod=POD_NAME).set( active_tenants_gauge.labels(namespace=POD_NAMESPACE, pod=POD_NAME).set(
len(self.tenant_ids) len(self.tenant_ids)
) )
logger.debug(f"Current active tenants: {len(self.tenant_ids)}") logger.debug(
f"Current active tenants with Slack bots: {len(self.tenant_ids)}"
)
except Exception as e: except Exception as e:
logger.exception(f"Error in Slack acquisition: {e}") logger.exception(f"Error in Slack acquisition: {e}")
self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL) self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)
@ -171,7 +180,9 @@ class SlackbotHandler:
while not self._shutdown_event.is_set(): while not self._shutdown_event.is_set():
try: try:
self.send_heartbeats() self.send_heartbeats()
logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants") logger.debug(
f"Sent heartbeats for {len(self.tenant_ids)} active tenants"
)
except Exception as e: except Exception as e:
logger.exception(f"Error in heartbeat loop: {e}") logger.exception(f"Error in heartbeat loop: {e}")
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL) self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
@ -179,17 +190,21 @@ class SlackbotHandler:
def _manage_clients_per_tenant( def _manage_clients_per_tenant(
self, db_session: Session, tenant_id: str | None, bot: SlackBot self, db_session: Session, tenant_id: str | None, bot: SlackBot
) -> None: ) -> None:
"""
- If the tokens are missing or empty, close the socket client and remove them.
- If the tokens have changed, close the existing socket client and reconnect.
- If the tokens are new, warm up the model and start a new socket client.
"""
slack_bot_tokens = SlackBotTokens( slack_bot_tokens = SlackBotTokens(
bot_token=bot.bot_token, bot_token=bot.bot_token,
app_token=bot.app_token, app_token=bot.app_token,
) )
tenant_bot_pair = (tenant_id, bot.id) tenant_bot_pair = (tenant_id, bot.id)
# If the tokens are not set, we need to close the socket client and delete the tokens # If the tokens are missing or empty, close the socket client and remove them.
# for the tenant and app
if not slack_bot_tokens: if not slack_bot_tokens:
logger.debug( logger.debug(
f"No Slack bot token found for tenant {tenant_id}, bot {bot.id}" f"No Slack bot tokens found for tenant={tenant_id}, bot {bot.id}"
) )
if tenant_bot_pair in self.socket_clients: if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close()) asyncio.run(self.socket_clients[tenant_bot_pair].close())
@ -204,9 +219,10 @@ class SlackbotHandler:
if not tokens_exist or tokens_changed: if not tokens_exist or tokens_changed:
if tokens_exist: if tokens_exist:
logger.info( logger.info(
f"Slack Bot tokens have changed for tenant {tenant_id}, bot {bot.id} - reconnecting" f"Slack Bot tokens changed for tenant={tenant_id}, bot {bot.id}; reconnecting"
) )
else: else:
# Warm up the model if needed
search_settings = get_current_search_settings(db_session) search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model( embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings, search_settings=search_settings,
@ -217,77 +233,168 @@ class SlackbotHandler:
self.slack_bot_tokens[tenant_bot_pair] = slack_bot_tokens self.slack_bot_tokens[tenant_bot_pair] = slack_bot_tokens
# Close any existing connection first
if tenant_bot_pair in self.socket_clients: if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close()) asyncio.run(self.socket_clients[tenant_bot_pair].close())
self.start_socket_client(bot.id, tenant_id, slack_bot_tokens) self.start_socket_client(bot.id, tenant_id, slack_bot_tokens)
def acquire_tenants(self) -> None: def acquire_tenants(self) -> None:
tenant_ids = get_all_tenant_ids() """
- Attempt to acquire a Redis lock for each tenant.
- If acquired, check if that tenant actually has Slack bots.
- If yes, store them in self.tenant_ids and manage the socket connections.
- If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope).
"""
all_tenants = get_all_tenant_ids()
for tenant_id in tenant_ids: # 1) Try to acquire locks for new tenants
for tenant_id in all_tenants:
if ( if (
DISALLOWED_SLACK_BOT_TENANT_LIST is not None DISALLOWED_SLACK_BOT_TENANT_LIST is not None
and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
): ):
logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping") logger.debug(f"Tenant {tenant_id} is disallowed; skipping.")
continue continue
# Already acquired in a previous loop iteration?
if tenant_id in self.tenant_ids: if tenant_id in self.tenant_ids:
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
continue continue
# Respect max tenant limit per pod
if len(self.tenant_ids) >= MAX_TENANTS_PER_POD: if len(self.tenant_ids) >= MAX_TENANTS_PER_POD:
logger.info( logger.info(
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}) Not acquiring any more tenants" f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}); not acquiring more."
) )
break break
redis_client = get_redis_client(tenant_id=tenant_id) redis_client = get_redis_client(tenant_id=tenant_id)
pod_id = self.pod_id # Acquire a Redis lock (non-blocking)
acquired = redis_client.set( rlock = redis_client.lock(
OnyxRedisLocks.SLACK_BOT_LOCK, OnyxRedisLocks.SLACK_BOT_LOCK, timeout=TENANT_LOCK_EXPIRATION
pod_id,
nx=True,
ex=TENANT_LOCK_EXPIRATION,
) )
if not acquired and not DEV_MODE: lock_acquired = rlock.acquire(blocking=False)
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
if not lock_acquired and not DEV_MODE:
logger.debug(
f"Another pod holds the lock for tenant {tenant_id}, skipping."
)
continue continue
logger.debug(f"Acquired lock for tenant {tenant_id}") if lock_acquired:
logger.debug(f"Acquired lock for tenant {tenant_id}.")
self.redis_locks[tenant_id] = rlock
else:
# DEV_MODE will skip the lock acquisition guard
logger.debug(
f"Running in DEV_MODE. Not enforcing lock for {tenant_id}."
)
self.tenant_ids.add(tenant_id) # Now check if this tenant actually has Slack bots
for tenant_id in self.tenant_ids:
token = CURRENT_TENANT_ID_CONTEXTVAR.set( token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA tenant_id or POSTGRES_DEFAULT_SCHEMA
) )
try: try:
with get_session_with_tenant(tenant_id) as db_session: with get_session_with_tenant(tenant_id) as db_session:
bots: list[SlackBot] = []
try: try:
bots = fetch_slack_bots(db_session=db_session) bots = list(fetch_slack_bots(db_session=db_session))
except KvKeyNotFoundError:
# No Slackbot tokens, pass
pass
except Exception as e:
logger.exception(
f"Error fetching Slack bots for tenant {tenant_id}: {e}"
)
if bots:
# Mark as active tenant
self.tenant_ids.add(tenant_id)
for bot in bots: for bot in bots:
self._manage_clients_per_tenant( self._manage_clients_per_tenant(
db_session=db_session, db_session=db_session,
tenant_id=tenant_id, tenant_id=tenant_id,
bot=bot, bot=bot,
) )
else:
except KvKeyNotFoundError: # If no Slack bots, release lock immediately (unless in DEV_MODE)
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}") if lock_acquired and not DEV_MODE:
if (tenant_id, bot.id) in self.socket_clients: rlock.release()
asyncio.run(self.socket_clients[tenant_id, bot.id].close()) del self.redis_locks[tenant_id]
del self.socket_clients[tenant_id, bot.id] logger.debug(
del self.slack_bot_tokens[tenant_id, bot.id] f"No Slack bots for tenant {tenant_id}; lock released (if held)."
except Exception as e: )
logger.exception(f"Error handling tenant {tenant_id}: {e}")
finally: finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token) CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
# 2) Make sure tenants we're handling still have Slack bots
for tenant_id in list(self.tenant_ids):
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
redis_client = get_redis_client(tenant_id=tenant_id)
try:
with get_session_with_tenant(tenant_id) as db_session:
# Attempt to fetch Slack bots
try:
bots = list(fetch_slack_bots(db_session=db_session))
except KvKeyNotFoundError:
# No Slackbot tokens, pass (and remove below)
bots = []
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
bots = []
if not bots:
logger.info(
f"Tenant {tenant_id} no longer has Slack bots. Removing."
)
self._remove_tenant(tenant_id)
# NOTE: We release the lock here (in the same scope it was acquired)
if tenant_id in self.redis_locks and not DEV_MODE:
try:
self.redis_locks[tenant_id].release()
del self.redis_locks[tenant_id]
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(
f"Error releasing lock for tenant {tenant_id}: {e}"
)
else:
# Manage or reconnect Slack bot sockets
for bot in bots:
self._manage_clients_per_tenant(
db_session=db_session,
tenant_id=tenant_id,
bot=bot,
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def _remove_tenant(self, tenant_id: str | None) -> None:
"""
Helper to remove a tenant from `self.tenant_ids` and close any socket clients.
(Lock release now happens in `acquire_tenants()`, not here.)
"""
# Close all socket clients for this tenant
for (t_id, slack_bot_id), client in list(self.socket_clients.items()):
if t_id == tenant_id:
asyncio.run(client.close())
del self.socket_clients[(t_id, slack_bot_id)]
del self.slack_bot_tokens[(t_id, slack_bot_id)]
logger.info(
f"Stopped SocketModeClient for tenant: {t_id}, app: {slack_bot_id}"
)
# Remove from active set
if tenant_id in self.tenant_ids:
self.tenant_ids.remove(tenant_id)
def send_heartbeats(self) -> None: def send_heartbeats(self) -> None:
current_time = int(time.time()) current_time = int(time.time())
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants") logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} active tenants")
for tenant_id in self.tenant_ids: for tenant_id in self.tenant_ids:
redis_client = get_redis_client(tenant_id=tenant_id) redis_client = get_redis_client(tenant_id=tenant_id)
heartbeat_key = f"{OnyxRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}" heartbeat_key = f"{OnyxRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}"
@ -315,6 +422,7 @@ class SlackbotHandler:
) )
socket_client.connect() socket_client.connect()
self.socket_clients[tenant_id, slack_bot_id] = socket_client self.socket_clients[tenant_id, slack_bot_id] = socket_client
# Ensure tenant is tracked as active
self.tenant_ids.add(tenant_id) self.tenant_ids.add(tenant_id)
logger.info( logger.info(
f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}" f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
@ -322,7 +430,7 @@ class SlackbotHandler:
def stop_socket_clients(self) -> None: def stop_socket_clients(self) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients") logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for (tenant_id, slack_bot_id), client in self.socket_clients.items(): for (tenant_id, slack_bot_id), client in list(self.socket_clients.items()):
asyncio.run(client.close()) asyncio.run(client.close())
logger.info( logger.info(
f"Stopped SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}" f"Stopped SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
@ -340,17 +448,19 @@ class SlackbotHandler:
logger.info(f"Stopping {len(self.socket_clients)} socket clients") logger.info(f"Stopping {len(self.socket_clients)} socket clients")
self.stop_socket_clients() self.stop_socket_clients()
# Release locks for all tenants # Release locks for all tenants we currently hold
logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants") logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
for tenant_id in self.tenant_ids: for tenant_id in list(self.tenant_ids):
try: if tenant_id in self.redis_locks:
redis_client = get_redis_client(tenant_id=tenant_id) try:
redis_client.delete(OnyxRedisLocks.SLACK_BOT_LOCK) self.redis_locks[tenant_id].release()
logger.info(f"Released lock for tenant {tenant_id}") logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e: except Exception as e:
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}") logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
finally:
del self.redis_locks[tenant_id]
# Wait for background threads to finish (with timeout) # Wait for background threads to finish (with a timeout)
logger.info("Waiting for background threads to finish...") logger.info("Waiting for background threads to finish...")
self.acquire_thread.join(timeout=5) self.acquire_thread.join(timeout=5)
self.heartbeat_thread.join(timeout=5) self.heartbeat_thread.join(timeout=5)