Slackbot optimization (#3696)

* initial pass

* update

* nit

* nit

* bot -> app

* nit

* quick update

* various improvements

* k

* k

* nit
This commit is contained in:
pablonyx 2025-01-20 11:46:52 -08:00 committed by GitHub
parent fe3eae3680
commit cc4953b560
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -14,6 +14,7 @@ from typing import Set
from prometheus_client import Gauge
from prometheus_client import start_http_server
from redis.lock import Lock
from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
@ -122,6 +123,9 @@ class SlackbotHandler:
self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {}
# Store Redis lock objects here so we can release them properly
self.redis_locks: Dict[str | None, Lock] = {}
self.running = True
self.pod_id = self.get_pod_id()
self._shutdown_event = Event()
@ -159,10 +163,15 @@ class SlackbotHandler:
while not self._shutdown_event.is_set():
try:
self.acquire_tenants()
# After we finish acquiring and managing Slack bots,
# set the gauge to the number of active tenants (those with Slack bots).
active_tenants_gauge.labels(namespace=POD_NAMESPACE, pod=POD_NAME).set(
len(self.tenant_ids)
)
logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
logger.debug(
f"Current active tenants with Slack bots: {len(self.tenant_ids)}"
)
except Exception as e:
logger.exception(f"Error in Slack acquisition: {e}")
self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)
@ -171,7 +180,9 @@ class SlackbotHandler:
while not self._shutdown_event.is_set():
try:
self.send_heartbeats()
logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants")
logger.debug(
f"Sent heartbeats for {len(self.tenant_ids)} active tenants"
)
except Exception as e:
logger.exception(f"Error in heartbeat loop: {e}")
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
@ -179,17 +190,21 @@ class SlackbotHandler:
def _manage_clients_per_tenant(
self, db_session: Session, tenant_id: str | None, bot: SlackBot
) -> None:
"""
- If the tokens are missing or empty, close the socket client and remove them.
- If the tokens have changed, close the existing socket client and reconnect.
- If the tokens are new, warm up the model and start a new socket client.
"""
slack_bot_tokens = SlackBotTokens(
bot_token=bot.bot_token,
app_token=bot.app_token,
)
tenant_bot_pair = (tenant_id, bot.id)
# If the tokens are not set, we need to close the socket client and delete the tokens
# for the tenant and app
# If the tokens are missing or empty, close the socket client and remove them.
if not slack_bot_tokens:
logger.debug(
f"No Slack bot token found for tenant {tenant_id}, bot {bot.id}"
f"No Slack bot tokens found for tenant={tenant_id}, bot {bot.id}"
)
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())
@ -204,9 +219,10 @@ class SlackbotHandler:
if not tokens_exist or tokens_changed:
if tokens_exist:
logger.info(
f"Slack Bot tokens have changed for tenant {tenant_id}, bot {bot.id} - reconnecting"
f"Slack Bot tokens changed for tenant={tenant_id}, bot {bot.id}; reconnecting"
)
else:
# Warm up the model if needed
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
@ -217,77 +233,168 @@ class SlackbotHandler:
self.slack_bot_tokens[tenant_bot_pair] = slack_bot_tokens
# Close any existing connection first
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())
self.start_socket_client(bot.id, tenant_id, slack_bot_tokens)
def acquire_tenants(self) -> None:
tenant_ids = get_all_tenant_ids()
"""
- Attempt to acquire a Redis lock for each tenant.
- If acquired, check if that tenant actually has Slack bots.
- If yes, store them in self.tenant_ids and manage the socket connections.
- If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope).
"""
all_tenants = get_all_tenant_ids()
for tenant_id in tenant_ids:
# 1) Try to acquire locks for new tenants
for tenant_id in all_tenants:
if (
DISALLOWED_SLACK_BOT_TENANT_LIST is not None
and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
):
logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping")
logger.debug(f"Tenant {tenant_id} is disallowed; skipping.")
continue
# Already acquired in a previous loop iteration?
if tenant_id in self.tenant_ids:
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
continue
# Respect max tenant limit per pod
if len(self.tenant_ids) >= MAX_TENANTS_PER_POD:
logger.info(
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}) Not acquiring any more tenants"
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}); not acquiring more."
)
break
redis_client = get_redis_client(tenant_id=tenant_id)
pod_id = self.pod_id
acquired = redis_client.set(
OnyxRedisLocks.SLACK_BOT_LOCK,
pod_id,
nx=True,
ex=TENANT_LOCK_EXPIRATION,
# Acquire a Redis lock (non-blocking)
rlock = redis_client.lock(
OnyxRedisLocks.SLACK_BOT_LOCK, timeout=TENANT_LOCK_EXPIRATION
)
if not acquired and not DEV_MODE:
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
lock_acquired = rlock.acquire(blocking=False)
if not lock_acquired and not DEV_MODE:
logger.debug(
f"Another pod holds the lock for tenant {tenant_id}, skipping."
)
continue
logger.debug(f"Acquired lock for tenant {tenant_id}")
if lock_acquired:
logger.debug(f"Acquired lock for tenant {tenant_id}.")
self.redis_locks[tenant_id] = rlock
else:
# DEV_MODE will skip the lock acquisition guard
logger.debug(
f"Running in DEV_MODE. Not enforcing lock for {tenant_id}."
)
self.tenant_ids.add(tenant_id)
for tenant_id in self.tenant_ids:
# Now check if this tenant actually has Slack bots
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
try:
with get_session_with_tenant(tenant_id) as db_session:
bots: list[SlackBot] = []
try:
bots = fetch_slack_bots(db_session=db_session)
bots = list(fetch_slack_bots(db_session=db_session))
except KvKeyNotFoundError:
# No Slackbot tokens, pass
pass
except Exception as e:
logger.exception(
f"Error fetching Slack bots for tenant {tenant_id}: {e}"
)
if bots:
# Mark as active tenant
self.tenant_ids.add(tenant_id)
for bot in bots:
self._manage_clients_per_tenant(
db_session=db_session,
tenant_id=tenant_id,
bot=bot,
)
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if (tenant_id, bot.id) in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id, bot.id].close())
del self.socket_clients[tenant_id, bot.id]
del self.slack_bot_tokens[tenant_id, bot.id]
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
else:
# If no Slack bots, release lock immediately (unless in DEV_MODE)
if lock_acquired and not DEV_MODE:
rlock.release()
del self.redis_locks[tenant_id]
logger.debug(
f"No Slack bots for tenant {tenant_id}; lock released (if held)."
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
# 2) Make sure tenants we're handling still have Slack bots
for tenant_id in list(self.tenant_ids):
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
redis_client = get_redis_client(tenant_id=tenant_id)
try:
with get_session_with_tenant(tenant_id) as db_session:
# Attempt to fetch Slack bots
try:
bots = list(fetch_slack_bots(db_session=db_session))
except KvKeyNotFoundError:
# No Slackbot tokens, pass (and remove below)
bots = []
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
bots = []
if not bots:
logger.info(
f"Tenant {tenant_id} no longer has Slack bots. Removing."
)
self._remove_tenant(tenant_id)
# NOTE: We release the lock here (in the same scope it was acquired)
if tenant_id in self.redis_locks and not DEV_MODE:
try:
self.redis_locks[tenant_id].release()
del self.redis_locks[tenant_id]
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(
f"Error releasing lock for tenant {tenant_id}: {e}"
)
else:
# Manage or reconnect Slack bot sockets
for bot in bots:
self._manage_clients_per_tenant(
db_session=db_session,
tenant_id=tenant_id,
bot=bot,
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def _remove_tenant(self, tenant_id: str | None) -> None:
"""
Helper to remove a tenant from `self.tenant_ids` and close any socket clients.
(Lock release now happens in `acquire_tenants()`, not here.)
"""
# Close all socket clients for this tenant
for (t_id, slack_bot_id), client in list(self.socket_clients.items()):
if t_id == tenant_id:
asyncio.run(client.close())
del self.socket_clients[(t_id, slack_bot_id)]
del self.slack_bot_tokens[(t_id, slack_bot_id)]
logger.info(
f"Stopped SocketModeClient for tenant: {t_id}, app: {slack_bot_id}"
)
# Remove from active set
if tenant_id in self.tenant_ids:
self.tenant_ids.remove(tenant_id)
def send_heartbeats(self) -> None:
current_time = int(time.time())
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants")
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} active tenants")
for tenant_id in self.tenant_ids:
redis_client = get_redis_client(tenant_id=tenant_id)
heartbeat_key = f"{OnyxRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}"
@ -315,6 +422,7 @@ class SlackbotHandler:
)
socket_client.connect()
self.socket_clients[tenant_id, slack_bot_id] = socket_client
# Ensure tenant is tracked as active
self.tenant_ids.add(tenant_id)
logger.info(
f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
@ -322,7 +430,7 @@ class SlackbotHandler:
def stop_socket_clients(self) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for (tenant_id, slack_bot_id), client in self.socket_clients.items():
for (tenant_id, slack_bot_id), client in list(self.socket_clients.items()):
asyncio.run(client.close())
logger.info(
f"Stopped SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
@ -340,17 +448,19 @@ class SlackbotHandler:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
self.stop_socket_clients()
# Release locks for all tenants
# Release locks for all tenants we currently hold
logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
for tenant_id in self.tenant_ids:
try:
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(OnyxRedisLocks.SLACK_BOT_LOCK)
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
for tenant_id in list(self.tenant_ids):
if tenant_id in self.redis_locks:
try:
self.redis_locks[tenant_id].release()
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
finally:
del self.redis_locks[tenant_id]
# Wait for background threads to finish (with timeout)
# Wait for background threads to finish (with a timeout)
logger.info("Waiting for background threads to finish...")
self.acquire_thread.join(timeout=5)
self.heartbeat_thread.join(timeout=5)