mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-10-06 18:14:35 +02:00
Minor tenant ID improvements (#2850)
* add migration dockerfile * address edge case * k * k * k * nit * k * k * k * k * remove * k * add comment
This commit is contained in:
@@ -13,7 +13,7 @@ from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.models import Base
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
from danswer.background.celery.celery_app import get_all_tenant_ids
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
|
||||
# Alembic Config object
|
||||
config = context.config
|
||||
@@ -61,7 +61,7 @@ def get_schema_options() -> tuple[str, bool, bool]:
|
||||
create_schema = x_args.get("create_schema", "true").lower() == "true"
|
||||
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
|
||||
|
||||
if MULTI_TENANT and schema_name == "public":
|
||||
if MULTI_TENANT and schema_name == "public" and not upgrade_all_tenants:
|
||||
raise ValueError(
|
||||
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
|
||||
"Please specify a tenant-specific schema."
|
||||
|
@@ -9,6 +9,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
def get_invited_users() -> list[str]:
|
||||
try:
|
||||
store = get_kv_store()
|
||||
|
||||
return cast(list, store.load(KV_USER_STORE_KEY))
|
||||
except KvKeyNotFoundError:
|
||||
return list()
|
||||
|
@@ -316,6 +316,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
verify_email_in_whitelist(account_email, tenant_id)
|
||||
verify_email_domain(account_email)
|
||||
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
|
||||
self.user_db = tenant_user_db
|
||||
|
@@ -28,7 +28,6 @@ from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||
from danswer.background.celery.celery_utils import get_all_tenant_ids
|
||||
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
@@ -37,6 +36,7 @@ from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
|
@@ -211,6 +211,7 @@ def handle_regular_answer(
|
||||
use_citations=use_citations,
|
||||
danswerbot_flow=True,
|
||||
)
|
||||
|
||||
if not answer.error_msg:
|
||||
return answer
|
||||
else:
|
||||
|
@@ -7,7 +7,6 @@ from slack_sdk import WebClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
|
||||
from danswer.background.celery.celery_app import get_all_tenant_ids
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
|
||||
@@ -47,6 +46,7 @@ from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
|
||||
from danswer.danswerbot.slack.utils import rephrase_slack_message
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import TenantSocketModeClient
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
@@ -57,6 +57,7 @@ from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import SLACK_CHANNEL_ID
|
||||
@@ -345,7 +346,9 @@ def process_message(
|
||||
respond_every_channel: bool = DANSWER_BOT_RESPOND_EVERY_CHANNEL,
|
||||
notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
|
||||
) -> None:
|
||||
logger.debug(f"Received Slack request of type: '{req.type}'")
|
||||
logger.debug(
|
||||
f"Received Slack request of type: '{req.type}' for tenant, {client.tenant_id}"
|
||||
)
|
||||
|
||||
# Throw out requests that can't or shouldn't be handled
|
||||
if not prefilter_requests(req, client):
|
||||
@@ -357,51 +360,59 @@ def process_message(
|
||||
client=client.web_client, channel_id=channel
|
||||
)
|
||||
|
||||
with get_session_with_tenant(client.tenant_id) as db_session:
|
||||
slack_bot_config = get_slack_bot_config_for_channel(
|
||||
channel_name=channel_name, db_session=db_session
|
||||
)
|
||||
# Set the current tenant ID at the beginning for all DB calls within this thread
|
||||
if client.tenant_id:
|
||||
logger.info(f"Setting tenant ID to {client.tenant_id}")
|
||||
token = current_tenant_id.set(client.tenant_id)
|
||||
try:
|
||||
with get_session_with_tenant(client.tenant_id) as db_session:
|
||||
slack_bot_config = get_slack_bot_config_for_channel(
|
||||
channel_name=channel_name, db_session=db_session
|
||||
)
|
||||
|
||||
# Be careful about this default, don't want to accidentally spam every channel
|
||||
# Users should be able to DM slack bot in their private channels though
|
||||
if (
|
||||
slack_bot_config is None
|
||||
and not respond_every_channel
|
||||
# Can't have configs for DMs so don't toss them out
|
||||
and not is_dm
|
||||
# If /DanswerBot (is_bot_msg) or @DanswerBot (bypass_filters)
|
||||
# always respond with the default configs
|
||||
and not (details.is_bot_msg or details.bypass_filters)
|
||||
):
|
||||
return
|
||||
# Be careful about this default, don't want to accidentally spam every channel
|
||||
# Users should be able to DM slack bot in their private channels though
|
||||
if (
|
||||
slack_bot_config is None
|
||||
and not respond_every_channel
|
||||
# Can't have configs for DMs so don't toss them out
|
||||
and not is_dm
|
||||
# If /DanswerBot (is_bot_msg) or @DanswerBot (bypass_filters)
|
||||
# always respond with the default configs
|
||||
and not (details.is_bot_msg or details.bypass_filters)
|
||||
):
|
||||
return
|
||||
|
||||
follow_up = bool(
|
||||
slack_bot_config
|
||||
and slack_bot_config.channel_config
|
||||
and slack_bot_config.channel_config.get("follow_up_tags") is not None
|
||||
)
|
||||
feedback_reminder_id = schedule_feedback_reminder(
|
||||
details=details, client=client.web_client, include_followup=follow_up
|
||||
)
|
||||
follow_up = bool(
|
||||
slack_bot_config
|
||||
and slack_bot_config.channel_config
|
||||
and slack_bot_config.channel_config.get("follow_up_tags") is not None
|
||||
)
|
||||
feedback_reminder_id = schedule_feedback_reminder(
|
||||
details=details, client=client.web_client, include_followup=follow_up
|
||||
)
|
||||
|
||||
failed = handle_message(
|
||||
message_info=details,
|
||||
slack_bot_config=slack_bot_config,
|
||||
client=client.web_client,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
tenant_id=client.tenant_id,
|
||||
)
|
||||
failed = handle_message(
|
||||
message_info=details,
|
||||
slack_bot_config=slack_bot_config,
|
||||
client=client.web_client,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
tenant_id=client.tenant_id,
|
||||
)
|
||||
|
||||
if failed:
|
||||
if feedback_reminder_id:
|
||||
remove_scheduled_feedback_reminder(
|
||||
client=client.web_client,
|
||||
channel=details.sender,
|
||||
msg_id=feedback_reminder_id,
|
||||
)
|
||||
# Skipping answering due to pre-filtering is not considered a failure
|
||||
if notify_no_answer:
|
||||
apologize_for_fail(details, client)
|
||||
if failed:
|
||||
if feedback_reminder_id:
|
||||
remove_scheduled_feedback_reminder(
|
||||
client=client.web_client,
|
||||
channel=details.sender,
|
||||
msg_id=feedback_reminder_id,
|
||||
)
|
||||
# Skipping answering due to pre-filtering is not considered a failure
|
||||
if notify_no_answer:
|
||||
apologize_for_fail(details, client)
|
||||
finally:
|
||||
if client.tenant_id:
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
|
||||
def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
|
||||
@@ -499,7 +510,9 @@ if __name__ == "__main__":
|
||||
for tenant_id in tenant_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try:
|
||||
token = current_tenant_id.set(tenant_id or "public")
|
||||
latest_slack_bot_tokens = fetch_tokens()
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
if (
|
||||
tenant_id not in slack_bot_tokens
|
||||
@@ -533,6 +546,11 @@ if __name__ == "__main__":
|
||||
socket_client = _get_socket_client(
|
||||
latest_slack_bot_tokens, tenant_id
|
||||
)
|
||||
|
||||
# Initialize socket client for this tenant. Each tenant has its own
|
||||
# socket client, allowing for multiple concurrent connections (one
|
||||
# per tenant) with the tenant ID wrapped in the socket model client.
|
||||
# Each `connect` stores websocket connection in a separate thread.
|
||||
_initialize_socket_client(socket_client)
|
||||
|
||||
socket_clients[tenant_id] = socket_client
|
||||
|
@@ -38,6 +38,7 @@ from danswer.configs.app_configs import POSTGRES_USER
|
||||
from danswer.configs.app_configs import SECRET_JWT_KEY
|
||||
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
|
||||
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from danswer.configs.constants import TENANT_ID_PREFIX
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import current_tenant_id
|
||||
|
||||
@@ -188,6 +189,29 @@ class SqlEngine:
|
||||
return cls._app_name
|
||||
|
||||
|
||||
def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
if not MULTI_TENANT:
|
||||
return [None]
|
||||
with get_session_with_tenant(tenant_id="public") as session:
|
||||
result = session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
|
||||
)
|
||||
)
|
||||
tenant_ids = [row[0] for row in result]
|
||||
|
||||
valid_tenants = [
|
||||
tenant
|
||||
for tenant in tenant_ids
|
||||
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
|
||||
]
|
||||
|
||||
return valid_tenants
|
||||
|
||||
|
||||
def build_connection_string(
|
||||
*,
|
||||
db_api: str = ASYNC_DB_API,
|
||||
@@ -332,9 +356,8 @@ def get_session_with_tenant(
|
||||
cursor.close()
|
||||
|
||||
|
||||
def get_session_generator_with_tenant(
|
||||
tenant_id: str | None = None,
|
||||
) -> Generator[Session, None, None]:
|
||||
def get_session_generator_with_tenant() -> Generator[Session, None, None]:
|
||||
tenant_id = current_tenant_id.get()
|
||||
with get_session_with_tenant(tenant_id) as session:
|
||||
yield session
|
||||
|
||||
|
@@ -95,10 +95,11 @@ def upsert_llm_provider(
|
||||
group_ids=llm_provider.groups,
|
||||
db_session=db_session,
|
||||
)
|
||||
full_llm_provider = FullLLMProvider.from_model(existing_llm_provider)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return FullLLMProvider.from_model(existing_llm_provider)
|
||||
return full_llm_provider
|
||||
|
||||
|
||||
def fetch_existing_embedding_providers(
|
||||
|
@@ -18,7 +18,6 @@ from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
|
@@ -1,11 +1,11 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.celery.celery_utils import get_all_tenant_ids
|
||||
from danswer.background.task_utils import build_celery_task_wrapper
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.db.chat import delete_chat_sessions_older_than
|
||||
from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.server.settings.store import load_settings
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
@@ -2,7 +2,8 @@ apiVersion: v1
|
||||
kind: ConfigMap
|
||||
metadata:
|
||||
name: env-configmap
|
||||
data:
|
||||
data:
|
||||
|
||||
# Auth Setting, also check the secrets file
|
||||
AUTH_TYPE: "disabled" # Change this for production uses unless Danswer is only accessible behind VPN
|
||||
ENCRYPTION_KEY_SECRET: "" # This should not be specified directly in the yaml, this is just for reference
|
||||
|
Reference in New Issue
Block a user