Minor tenant ID improvements (#2850)

* add migration dockerfile

* address edge case

* k

* k

* k

* nit

* k

* k

* k

* k

* remove

* k

* add comment
This commit is contained in:
pablodanswer
2024-10-20 16:48:00 -07:00
committed by GitHub
parent 7ab0063dc6
commit a24b465663
11 changed files with 98 additions and 53 deletions

View File

@@ -13,7 +13,7 @@ from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
from danswer.background.celery.celery_app import get_all_tenant_ids
from danswer.db.engine import get_all_tenant_ids
# Alembic Config object
config = context.config
@@ -61,7 +61,7 @@ def get_schema_options() -> tuple[str, bool, bool]:
create_schema = x_args.get("create_schema", "true").lower() == "true"
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
if MULTI_TENANT and schema_name == "public":
if MULTI_TENANT and schema_name == "public" and not upgrade_all_tenants:
raise ValueError(
"Cannot run default migrations in public schema when multi-tenancy is enabled. "
"Please specify a tenant-specific schema."

View File

@@ -9,6 +9,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
def get_invited_users() -> list[str]:
try:
store = get_kv_store()
return cast(list, store.load(KV_USER_STORE_KEY))
except KvKeyNotFoundError:
return list()

View File

@@ -316,6 +316,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_in_whitelist(account_email, tenant_id)
verify_email_domain(account_email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db

View File

@@ -28,7 +28,6 @@ from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.celery_utils import get_all_tenant_ids
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerRedisLocks
@@ -37,6 +36,7 @@ from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import SqlEngine
from danswer.db.search_settings import get_current_search_settings

View File

@@ -211,6 +211,7 @@ def handle_regular_answer(
use_citations=use_citations,
danswerbot_flow=True,
)
if not answer.error_msg:
return answer
else:

View File

@@ -7,7 +7,6 @@ from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from danswer.background.celery.celery_app import get_all_tenant_ids
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
@@ -47,6 +46,7 @@ from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import TenantSocketModeClient
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import get_session_with_tenant
from danswer.db.search_settings import get_current_search_settings
from danswer.key_value_store.interface import KvKeyNotFoundError
@@ -57,6 +57,7 @@ from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import current_tenant_id
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SLACK_CHANNEL_ID
@@ -345,7 +346,9 @@ def process_message(
respond_every_channel: bool = DANSWER_BOT_RESPOND_EVERY_CHANNEL,
notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
) -> None:
logger.debug(f"Received Slack request of type: '{req.type}'")
logger.debug(
f"Received Slack request of type: '{req.type}' for tenant, {client.tenant_id}"
)
# Throw out requests that can't or shouldn't be handled
if not prefilter_requests(req, client):
@@ -357,51 +360,59 @@ def process_message(
client=client.web_client, channel_id=channel
)
with get_session_with_tenant(client.tenant_id) as db_session:
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
# Set the current tenant ID at the beginning for all DB calls within this thread
if client.tenant_id:
logger.info(f"Setting tenant ID to {client.tenant_id}")
token = current_tenant_id.set(client.tenant_id)
try:
with get_session_with_tenant(client.tenant_id) as db_session:
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
# Be careful about this default, don't want to accidentally spam every channel
# Users should be able to DM slack bot in their private channels though
if (
slack_bot_config is None
and not respond_every_channel
# Can't have configs for DMs so don't toss them out
and not is_dm
# If /DanswerBot (is_bot_msg) or @DanswerBot (bypass_filters)
# always respond with the default configs
and not (details.is_bot_msg or details.bypass_filters)
):
return
# Be careful about this default, don't want to accidentally spam every channel
# Users should be able to DM slack bot in their private channels though
if (
slack_bot_config is None
and not respond_every_channel
# Can't have configs for DMs so don't toss them out
and not is_dm
# If /DanswerBot (is_bot_msg) or @DanswerBot (bypass_filters)
# always respond with the default configs
and not (details.is_bot_msg or details.bypass_filters)
):
return
follow_up = bool(
slack_bot_config
and slack_bot_config.channel_config
and slack_bot_config.channel_config.get("follow_up_tags") is not None
)
feedback_reminder_id = schedule_feedback_reminder(
details=details, client=client.web_client, include_followup=follow_up
)
follow_up = bool(
slack_bot_config
and slack_bot_config.channel_config
and slack_bot_config.channel_config.get("follow_up_tags") is not None
)
feedback_reminder_id = schedule_feedback_reminder(
details=details, client=client.web_client, include_followup=follow_up
)
failed = handle_message(
message_info=details,
slack_bot_config=slack_bot_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
tenant_id=client.tenant_id,
)
failed = handle_message(
message_info=details,
slack_bot_config=slack_bot_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
tenant_id=client.tenant_id,
)
if failed:
if feedback_reminder_id:
remove_scheduled_feedback_reminder(
client=client.web_client,
channel=details.sender,
msg_id=feedback_reminder_id,
)
# Skipping answering due to pre-filtering is not considered a failure
if notify_no_answer:
apologize_for_fail(details, client)
if failed:
if feedback_reminder_id:
remove_scheduled_feedback_reminder(
client=client.web_client,
channel=details.sender,
msg_id=feedback_reminder_id,
)
# Skipping answering due to pre-filtering is not considered a failure
if notify_no_answer:
apologize_for_fail(details, client)
finally:
if client.tenant_id:
current_tenant_id.reset(token)
def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
@@ -499,7 +510,9 @@ if __name__ == "__main__":
for tenant_id in tenant_ids:
with get_session_with_tenant(tenant_id) as db_session:
try:
token = current_tenant_id.set(tenant_id or "public")
latest_slack_bot_tokens = fetch_tokens()
current_tenant_id.reset(token)
if (
tenant_id not in slack_bot_tokens
@@ -533,6 +546,11 @@ if __name__ == "__main__":
socket_client = _get_socket_client(
latest_slack_bot_tokens, tenant_id
)
# Initialize socket client for this tenant. Each tenant has its own
# socket client, allowing for multiple concurrent connections (one
# per tenant) with the tenant ID wrapped in the socket model client.
# Each `connect` stores websocket connection in a separate thread.
_initialize_socket_client(socket_client)
socket_clients[tenant_id] = socket_client

View File

@@ -38,6 +38,7 @@ from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.configs.constants import TENANT_ID_PREFIX
from danswer.utils.logger import setup_logger
from shared_configs.configs import current_tenant_id
@@ -188,6 +189,29 @@ class SqlEngine:
return cls._app_name
def get_all_tenant_ids() -> list[str] | list[None]:
if not MULTI_TENANT:
return [None]
with get_session_with_tenant(tenant_id="public") as session:
result = session.execute(
text(
"""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
)
)
tenant_ids = [row[0] for row in result]
valid_tenants = [
tenant
for tenant in tenant_ids
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
]
return valid_tenants
def build_connection_string(
*,
db_api: str = ASYNC_DB_API,
@@ -332,9 +356,8 @@ def get_session_with_tenant(
cursor.close()
def get_session_generator_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
def get_session_generator_with_tenant() -> Generator[Session, None, None]:
tenant_id = current_tenant_id.get()
with get_session_with_tenant(tenant_id) as session:
yield session

View File

@@ -95,10 +95,11 @@ def upsert_llm_provider(
group_ids=llm_provider.groups,
db_session=db_session,
)
full_llm_provider = FullLLMProvider.from_model(existing_llm_provider)
db_session.commit()
return FullLLMProvider.from_model(existing_llm_provider)
return full_llm_provider
def fetch_existing_embedding_providers(

View File

@@ -18,7 +18,6 @@ from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from shared_configs.configs import current_tenant_id
logger = setup_logger()

View File

@@ -1,11 +1,11 @@
from datetime import timedelta
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_utils import get_all_tenant_ids
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.chat import delete_chat_sessions_older_than
from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import get_session_with_tenant
from danswer.server.settings.store import load_settings
from danswer.utils.logger import setup_logger

View File

@@ -2,7 +2,8 @@ apiVersion: v1
kind: ConfigMap
metadata:
name: env-configmap
data:
data:
# Auth Setting, also check the secrets file
AUTH_TYPE: "disabled" # Change this for production uses unless Danswer is only accessible behind VPN
ENCRYPTION_KEY_SECRET: "" # This should not be specified directly in the yaml, this is just for reference