Push multi tenancy for slackbot (#2828)

* push multi tenancy for slackbot

* move to utils

* k

* k

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
This commit is contained in:
pablodanswer
2024-10-17 14:04:48 -07:00
committed by GitHub
parent e48086b1c2
commit b169f78699
5 changed files with 111 additions and 85 deletions

View File

@ -4,9 +4,7 @@ from typing import cast
from slack_sdk import WebClient from slack_sdk import WebClient
from slack_sdk.models.blocks import SectionBlock from slack_sdk.models.blocks import SectionBlock
from slack_sdk.models.views import View from slack_sdk.models.views import View
from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.request import SocketModeRequest
from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType from danswer.configs.constants import SearchFeedbackType
@ -35,20 +33,22 @@ from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_feedback_visibility from danswer.danswerbot.slack.utils import get_feedback_visibility
from danswer.danswerbot.slack.utils import read_slack_thread from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import respond_in_thread from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import TenantSocketModeClient
from danswer.danswerbot.slack.utils import update_emote_react from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine from danswer.db.engine import get_session_with_tenant
from danswer.db.feedback import create_chat_message_feedback from danswer.db.feedback import create_chat_message_feedback
from danswer.db.feedback import create_doc_retrieval_feedback from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index from danswer.document_index.factory import get_default_document_index
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
def handle_doc_feedback_button( def handle_doc_feedback_button(
req: SocketModeRequest, req: SocketModeRequest,
client: SocketModeClient, client: TenantSocketModeClient,
) -> None: ) -> None:
if not (actions := req.payload.get("actions")): if not (actions := req.payload.get("actions")):
logger.error("Missing actions. Unable to build the source feedback view") logger.error("Missing actions. Unable to build the source feedback view")
@ -81,7 +81,7 @@ def handle_doc_feedback_button(
def handle_generate_answer_button( def handle_generate_answer_button(
req: SocketModeRequest, req: SocketModeRequest,
client: SocketModeClient, client: TenantSocketModeClient,
) -> None: ) -> None:
channel_id = req.payload["channel"]["id"] channel_id = req.payload["channel"]["id"]
channel_name = req.payload["channel"]["name"] channel_name = req.payload["channel"]["name"]
@ -116,7 +116,7 @@ def handle_generate_answer_button(
thread_ts=thread_ts, thread_ts=thread_ts,
) )
with Session(get_sqlalchemy_engine()) as db_session: with get_session_with_tenant(client.tenant_id) as db_session:
slack_bot_config = get_slack_bot_config_for_channel( slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session channel_name=channel_name, db_session=db_session
) )
@ -136,6 +136,7 @@ def handle_generate_answer_button(
slack_bot_config=slack_bot_config, slack_bot_config=slack_bot_config,
receiver_ids=None, receiver_ids=None,
client=client.web_client, client=client.web_client,
tenant_id=client.tenant_id,
channel=channel_id, channel=channel_id,
logger=logger, logger=logger,
feedback_reminder_id=None, feedback_reminder_id=None,
@ -150,12 +151,11 @@ def handle_slack_feedback(
user_id_to_post_confirmation: str, user_id_to_post_confirmation: str,
channel_id_to_post_confirmation: str, channel_id_to_post_confirmation: str,
thread_ts_to_post_confirmation: str, thread_ts_to_post_confirmation: str,
tenant_id: str | None,
) -> None: ) -> None:
engine = get_sqlalchemy_engine()
message_id, doc_id, doc_rank = decompose_action_id(feedback_id) message_id, doc_id, doc_rank = decompose_action_id(feedback_id)
with Session(engine) as db_session: with get_session_with_tenant(tenant_id) as db_session:
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]: if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
create_chat_message_feedback( create_chat_message_feedback(
is_positive=feedback_type == LIKE_BLOCK_ACTION_ID, is_positive=feedback_type == LIKE_BLOCK_ACTION_ID,
@ -232,7 +232,7 @@ def handle_slack_feedback(
def handle_followup_button( def handle_followup_button(
req: SocketModeRequest, req: SocketModeRequest,
client: SocketModeClient, client: TenantSocketModeClient,
) -> None: ) -> None:
action_id = None action_id = None
if actions := req.payload.get("actions"): if actions := req.payload.get("actions"):
@ -252,7 +252,7 @@ def handle_followup_button(
tag_ids: list[str] = [] tag_ids: list[str] = []
group_ids: list[str] = [] group_ids: list[str] = []
with Session(get_sqlalchemy_engine()) as db_session: with get_session_with_tenant(client.tenant_id) as db_session:
channel_name, is_dm = get_channel_name_from_id( channel_name, is_dm = get_channel_name_from_id(
client=client.web_client, channel_id=channel_id client=client.web_client, channel_id=channel_id
) )
@ -295,7 +295,7 @@ def handle_followup_button(
def get_clicker_name( def get_clicker_name(
req: SocketModeRequest, req: SocketModeRequest,
client: SocketModeClient, client: TenantSocketModeClient,
) -> str: ) -> str:
clicker_name = req.payload.get("user", {}).get("name", "Someone") clicker_name = req.payload.get("user", {}).get("name", "Someone")
clicker_real_name = None clicker_real_name = None
@ -316,7 +316,7 @@ def get_clicker_name(
def handle_followup_resolved_button( def handle_followup_resolved_button(
req: SocketModeRequest, req: SocketModeRequest,
client: SocketModeClient, client: TenantSocketModeClient,
immediate: bool = False, immediate: bool = False,
) -> None: ) -> None:
channel_id = req.payload["container"]["channel_id"] channel_id = req.payload["container"]["channel_id"]

View File

@ -2,7 +2,6 @@ import datetime
from slack_sdk import WebClient from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError from slack_sdk.errors import SlackApiError
from sqlalchemy.orm import Session
from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_REMINDER from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
@ -19,7 +18,7 @@ from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
from danswer.danswerbot.slack.utils import respond_in_thread from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import slack_usage_report from danswer.danswerbot.slack.utils import slack_usage_report
from danswer.danswerbot.slack.utils import update_emote_react from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine from danswer.db.engine import get_session_with_tenant
from danswer.db.models import SlackBotConfig from danswer.db.models import SlackBotConfig
from danswer.db.users import add_non_web_user_if_not_exists from danswer.db.users import add_non_web_user_if_not_exists
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
@ -110,6 +109,7 @@ def handle_message(
slack_bot_config: SlackBotConfig | None, slack_bot_config: SlackBotConfig | None,
client: WebClient, client: WebClient,
feedback_reminder_id: str | None, feedback_reminder_id: str | None,
tenant_id: str | None,
) -> bool: ) -> bool:
"""Potentially respond to the user message depending on filters and if an answer was generated """Potentially respond to the user message depending on filters and if an answer was generated
@ -135,7 +135,9 @@ def handle_message(
action = "slack_tag_message" action = "slack_tag_message"
elif is_bot_dm: elif is_bot_dm:
action = "slack_dm_message" action = "slack_dm_message"
slack_usage_report(action=action, sender_id=sender_id, client=client) slack_usage_report(
action=action, sender_id=sender_id, client=client, tenant_id=tenant_id
)
document_set_names: list[str] | None = None document_set_names: list[str] | None = None
persona = slack_bot_config.persona if slack_bot_config else None persona = slack_bot_config.persona if slack_bot_config else None
@ -209,7 +211,7 @@ def handle_message(
except SlackApiError as e: except SlackApiError as e:
logger.error(f"Was not able to react to user message due to: {e}") logger.error(f"Was not able to react to user message due to: {e}")
with Session(get_sqlalchemy_engine()) as db_session: with get_session_with_tenant(tenant_id) as db_session:
if message_info.email: if message_info.email:
add_non_web_user_if_not_exists(db_session, message_info.email) add_non_web_user_if_not_exists(db_session, message_info.email)
@ -235,5 +237,6 @@ def handle_message(
channel=channel, channel=channel,
logger=logger, logger=logger,
feedback_reminder_id=feedback_reminder_id, feedback_reminder_id=feedback_reminder_id,
tenant_id=tenant_id,
) )
return issue_with_regular_answer return issue_with_regular_answer

View File

@ -9,7 +9,6 @@ from retry import retry
from slack_sdk import WebClient from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock from slack_sdk.models.blocks import SectionBlock
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
@ -33,7 +32,7 @@ from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import SlackRateLimiter from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine from danswer.db.engine import get_session_with_tenant
from danswer.db.models import Persona from danswer.db.models import Persona
from danswer.db.models import SlackBotConfig from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType from danswer.db.models import SlackBotResponseType
@ -88,6 +87,7 @@ def handle_regular_answer(
channel: str, channel: str,
logger: DanswerLoggingAdapter, logger: DanswerLoggingAdapter,
feedback_reminder_id: str | None, feedback_reminder_id: str | None,
tenant_id: str | None,
num_retries: int = DANSWER_BOT_NUM_RETRIES, num_retries: int = DANSWER_BOT_NUM_RETRIES,
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT, answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE, thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
@ -104,8 +104,7 @@ def handle_regular_answer(
user = None user = None
if message_info.is_bot_dm: if message_info.is_bot_dm:
if message_info.email: if message_info.email:
engine = get_sqlalchemy_engine() with get_session_with_tenant(tenant_id) as db_session:
with Session(engine) as db_session:
user = get_user_by_email(message_info.email, db_session) user = get_user_by_email(message_info.email, db_session)
document_set_names: list[str] | None = None document_set_names: list[str] | None = None
@ -152,7 +151,7 @@ def handle_regular_answer(
max_document_tokens: int | None = None max_document_tokens: int | None = None
max_history_tokens: int | None = None max_history_tokens: int | None = None
with Session(get_sqlalchemy_engine()) as db_session: with get_session_with_tenant(tenant_id) as db_session:
if len(new_message_request.messages) > 1: if len(new_message_request.messages) > 1:
if new_message_request.persona_config: if new_message_request.persona_config:
raise RuntimeError("Slack bot does not support persona config") raise RuntimeError("Slack bot does not support persona config")
@ -246,7 +245,7 @@ def handle_regular_answer(
) )
# Always apply reranking settings if it exists, this is the non-streaming flow # Always apply reranking settings if it exists, this is the non-streaming flow
with Session(get_sqlalchemy_engine()) as db_session: with get_session_with_tenant(tenant_id) as db_session:
saved_search_settings = get_current_search_settings(db_session) saved_search_settings = get_current_search_settings(db_session)
# This includes throwing out answer via reflexion # This includes throwing out answer via reflexion

View File

@ -4,11 +4,10 @@ from typing import Any
from typing import cast from typing import cast
from slack_sdk import WebClient from slack_sdk import WebClient
from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import get_all_tenant_ids
from danswer.configs.constants import MessageType from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
@ -47,7 +46,8 @@ from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import remove_danswer_bot_tag from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import rephrase_slack_message from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.engine import get_sqlalchemy_engine from danswer.danswerbot.slack.utils import TenantSocketModeClient
from danswer.db.engine import get_session_with_tenant
from danswer.db.search_settings import get_current_search_settings from danswer.db.search_settings import get_current_search_settings
from danswer.key_value_store.interface import KvKeyNotFoundError from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
@ -80,7 +80,7 @@ _SLACK_GREETINGS_TO_IGNORE = {
_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT" _OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool: def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
"""True to keep going, False to ignore this Slack request""" """True to keep going, False to ignore this Slack request"""
if req.type == "events_api": if req.type == "events_api":
# Verify channel is valid # Verify channel is valid
@ -153,8 +153,7 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
client=client.web_client, channel_id=channel client=client.web_client, channel_id=channel
) )
engine = get_sqlalchemy_engine() with get_session_with_tenant(client.tenant_id) as db_session:
with Session(engine) as db_session:
slack_bot_config = get_slack_bot_config_for_channel( slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session channel_name=channel_name, db_session=db_session
) )
@ -221,7 +220,7 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
return True return True
def process_feedback(req: SocketModeRequest, client: SocketModeClient) -> None: def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
if actions := req.payload.get("actions"): if actions := req.payload.get("actions"):
action = cast(dict[str, Any], actions[0]) action = cast(dict[str, Any], actions[0])
feedback_type = cast(str, action.get("action_id")) feedback_type = cast(str, action.get("action_id"))
@ -243,6 +242,7 @@ def process_feedback(req: SocketModeRequest, client: SocketModeClient) -> None:
user_id_to_post_confirmation=user_id, user_id_to_post_confirmation=user_id,
channel_id_to_post_confirmation=channel_id, channel_id_to_post_confirmation=channel_id,
thread_ts_to_post_confirmation=thread_ts, thread_ts_to_post_confirmation=thread_ts,
tenant_id=client.tenant_id,
) )
query_event_id, _, _ = decompose_action_id(feedback_id) query_event_id, _, _ = decompose_action_id(feedback_id)
@ -250,7 +250,7 @@ def process_feedback(req: SocketModeRequest, client: SocketModeClient) -> None:
def build_request_details( def build_request_details(
req: SocketModeRequest, client: SocketModeClient req: SocketModeRequest, client: TenantSocketModeClient
) -> SlackMessageInfo: ) -> SlackMessageInfo:
if req.type == "events_api": if req.type == "events_api":
event = cast(dict[str, Any], req.payload["event"]) event = cast(dict[str, Any], req.payload["event"])
@ -329,7 +329,7 @@ def build_request_details(
def apologize_for_fail( def apologize_for_fail(
details: SlackMessageInfo, details: SlackMessageInfo,
client: SocketModeClient, client: TenantSocketModeClient,
) -> None: ) -> None:
respond_in_thread( respond_in_thread(
client=client.web_client, client=client.web_client,
@ -341,7 +341,7 @@ def apologize_for_fail(
def process_message( def process_message(
req: SocketModeRequest, req: SocketModeRequest,
client: SocketModeClient, client: TenantSocketModeClient,
respond_every_channel: bool = DANSWER_BOT_RESPOND_EVERY_CHANNEL, respond_every_channel: bool = DANSWER_BOT_RESPOND_EVERY_CHANNEL,
notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER, notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
) -> None: ) -> None:
@ -357,8 +357,7 @@ def process_message(
client=client.web_client, channel_id=channel client=client.web_client, channel_id=channel
) )
engine = get_sqlalchemy_engine() with get_session_with_tenant(client.tenant_id) as db_session:
with Session(engine) as db_session:
slack_bot_config = get_slack_bot_config_for_channel( slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session channel_name=channel_name, db_session=db_session
) )
@ -390,6 +389,7 @@ def process_message(
slack_bot_config=slack_bot_config, slack_bot_config=slack_bot_config,
client=client.web_client, client=client.web_client,
feedback_reminder_id=feedback_reminder_id, feedback_reminder_id=feedback_reminder_id,
tenant_id=client.tenant_id,
) )
if failed: if failed:
@ -404,12 +404,12 @@ def process_message(
apologize_for_fail(details, client) apologize_for_fail(details, client)
def acknowledge_message(req: SocketModeRequest, client: SocketModeClient) -> None: def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
response = SocketModeResponse(envelope_id=req.envelope_id) response = SocketModeResponse(envelope_id=req.envelope_id)
client.send_socket_mode_response(response) client.send_socket_mode_response(response)
def action_routing(req: SocketModeRequest, client: SocketModeClient) -> None: def action_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
if actions := req.payload.get("actions"): if actions := req.payload.get("actions"):
action = cast(dict[str, Any], actions[0]) action = cast(dict[str, Any], actions[0])
@ -429,13 +429,13 @@ def action_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
return handle_generate_answer_button(req, client) return handle_generate_answer_button(req, client)
def view_routing(req: SocketModeRequest, client: SocketModeClient) -> None: def view_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
if view := req.payload.get("view"): if view := req.payload.get("view"):
if view["callback_id"] == VIEW_DOC_FEEDBACK_ID: if view["callback_id"] == VIEW_DOC_FEEDBACK_ID:
return process_feedback(req, client) return process_feedback(req, client)
def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> None: def process_slack_event(client: TenantSocketModeClient, req: SocketModeRequest) -> None:
# Always respond right away, if Slack doesn't receive these frequently enough # Always respond right away, if Slack doesn't receive these frequently enough
# it will assume the Bot is DEAD!!! :( # it will assume the Bot is DEAD!!! :(
acknowledge_message(req, client) acknowledge_message(req, client)
@ -453,21 +453,24 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> Non
logger.error(f"Slack request payload: {req.payload}") logger.error(f"Slack request payload: {req.payload}")
def _get_socket_client(slack_bot_tokens: SlackBotTokens) -> SocketModeClient: def _get_socket_client(
slack_bot_tokens: SlackBotTokens, tenant_id: str | None
) -> TenantSocketModeClient:
# For more info on how to set this up, checkout the docs: # For more info on how to set this up, checkout the docs:
# https://docs.danswer.dev/slack_bot_setup # https://docs.danswer.dev/slack_bot_setup
return SocketModeClient( return TenantSocketModeClient(
# This app-level token will be used only for establishing a connection # This app-level token will be used only for establishing a connection
app_token=slack_bot_tokens.app_token, app_token=slack_bot_tokens.app_token,
web_client=WebClient(token=slack_bot_tokens.bot_token), web_client=WebClient(token=slack_bot_tokens.bot_token),
tenant_id=tenant_id,
) )
def _initialize_socket_client(socket_client: SocketModeClient) -> None: def _initialize_socket_client(socket_client: TenantSocketModeClient) -> None:
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers # Establish a WebSocket connection to the Socket Mode servers
logger.notice("Listening for messages from Slack...") logger.notice(f"Listening for messages from Slack {socket_client.tenant_id }...")
socket_client.connect() socket_client.connect()
@ -481,8 +484,8 @@ def _initialize_socket_client(socket_client: SocketModeClient) -> None:
# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC # NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
# without issue. # without issue.
if __name__ == "__main__": if __name__ == "__main__":
slack_bot_tokens: SlackBotTokens | None = None slack_bot_tokens: dict[str | None, SlackBotTokens] = {}
socket_client: SocketModeClient | None = None socket_clients: dict[str | None, TenantSocketModeClient] = {}
set_is_ee_based_on_env_variable() set_is_ee_based_on_env_variable()
@ -490,47 +493,60 @@ if __name__ == "__main__":
download_nltk_data() download_nltk_data()
while True: while True:
try:
tenant_ids = get_all_tenant_ids() # Function to retrieve all tenant IDs
for tenant_id in tenant_ids:
with get_session_with_tenant(tenant_id) as db_session:
try: try:
latest_slack_bot_tokens = fetch_tokens() latest_slack_bot_tokens = fetch_tokens()
if latest_slack_bot_tokens != slack_bot_tokens: if (
if slack_bot_tokens is not None: tenant_id not in slack_bot_tokens
logger.notice("Slack Bot tokens have changed - reconnecting") or latest_slack_bot_tokens != slack_bot_tokens[tenant_id]
):
if tenant_id in slack_bot_tokens:
logger.notice(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else: else:
# This happens on the very first time the listener process comes up # Initial setup for this tenant
# or the tokens have updated (set up for the first time) search_settings = get_current_search_settings(
with Session(get_sqlalchemy_engine()) as db_session: db_session
search_settings = get_current_search_settings(db_session) )
embedding_model = EmbeddingModel.from_db_model( embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings, search_settings=search_settings,
server_host=MODEL_SERVER_HOST, server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT, server_port=MODEL_SERVER_PORT,
) )
warm_up_bi_encoder(embedding_model=embedding_model)
warm_up_bi_encoder( slack_bot_tokens[tenant_id] = latest_slack_bot_tokens
embedding_model=embedding_model,
)
slack_bot_tokens = latest_slack_bot_tokens
# potentially may cause a message to be dropped, but it is complicated # potentially may cause a message to be dropped, but it is complicated
# to avoid + (1) if the user is changing tokens, they are likely okay with some # to avoid + (1) if the user is changing tokens, they are likely okay with some
# "migration downtime" and (2) if a single message is lost it is okay # "migration downtime" and (2) if a single message is lost it is okay
# as this should be a very rare occurrence # as this should be a very rare occurrence
if socket_client: if tenant_id in socket_clients:
socket_client.close() socket_clients[tenant_id].close()
socket_client = _get_socket_client(slack_bot_tokens) socket_client = _get_socket_client(
latest_slack_bot_tokens, tenant_id
)
_initialize_socket_client(socket_client) _initialize_socket_client(socket_client)
# Let the handlers run in the background + re-check for token updates every 60 seconds socket_clients[tenant_id] = socket_client
Event().wait(timeout=60)
except KvKeyNotFoundError: except KvKeyNotFoundError:
# try again every 30 seconds. This is needed since the user may add tokens logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
# via the UI at any point in the programs lifecycle - if we just allow it to if tenant_id in socket_clients:
# fail, then the user will need to restart the containers after adding tokens socket_clients[tenant_id].disconnect()
logger.debug( del socket_clients[tenant_id]
"Missing Slack Bot tokens - waiting 60 seconds and trying again" del slack_bot_tokens[tenant_id]
)
if socket_client: # Wait before checking for updates
socket_client.disconnect() Event().wait(timeout=60)
except Exception:
logger.exception("An error occurred outside of main event loop")
time.sleep(60) time.sleep(60)

View File

@ -12,7 +12,7 @@ from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError from slack_sdk.errors import SlackApiError
from slack_sdk.models.blocks import Block from slack_sdk.models.blocks import Block
from slack_sdk.models.metadata import Metadata from slack_sdk.models.metadata import Metadata
from sqlalchemy.orm import Session from slack_sdk.socket_mode import SocketModeClient
from danswer.configs.app_configs import DISABLE_TELEMETRY from danswer.configs.app_configs import DISABLE_TELEMETRY
from danswer.configs.constants import ID_SEPARATOR from danswer.configs.constants import ID_SEPARATOR
@ -31,7 +31,7 @@ from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import SlackTextCleaner from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.danswerbot.slack.constants import FeedbackVisibility from danswer.danswerbot.slack.constants import FeedbackVisibility
from danswer.danswerbot.slack.tokens import fetch_tokens from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.db.engine import get_sqlalchemy_engine from danswer.db.engine import get_session_with_tenant
from danswer.db.users import get_user_by_email from danswer.db.users import get_user_by_email
from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms from danswer.llm.factory import get_default_llms
@ -489,7 +489,9 @@ def read_slack_thread(
return thread_messages return thread_messages
def slack_usage_report(action: str, sender_id: str | None, client: WebClient) -> None: def slack_usage_report(
action: str, sender_id: str | None, client: WebClient, tenant_id: str | None
) -> None:
if DISABLE_TELEMETRY: if DISABLE_TELEMETRY:
return return
@ -501,7 +503,7 @@ def slack_usage_report(action: str, sender_id: str | None, client: WebClient) ->
logger.warning("Unable to find sender email") logger.warning("Unable to find sender email")
if sender_email is not None: if sender_email is not None:
with Session(get_sqlalchemy_engine()) as db_session: with get_session_with_tenant(tenant_id) as db_session:
danswer_user = get_user_by_email(email=sender_email, db_session=db_session) danswer_user = get_user_by_email(email=sender_email, db_session=db_session)
optional_telemetry( optional_telemetry(
@ -577,3 +579,9 @@ def get_feedback_visibility() -> FeedbackVisibility:
return FeedbackVisibility(DANSWER_BOT_FEEDBACK_VISIBILITY.lower()) return FeedbackVisibility(DANSWER_BOT_FEEDBACK_VISIBILITY.lower())
except ValueError: except ValueError:
return FeedbackVisibility.PRIVATE return FeedbackVisibility.PRIVATE
class TenantSocketModeClient(SocketModeClient):
def __init__(self, tenant_id: str | None, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.tenant_id = tenant_id