Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-07-25 04:13:25 +02:00)

Tenants on standby (#4218)

* add tenants on standby feature
* k
* fix alembic
* k
* k
@@ -0,0 +1,33 @@
+"""add new available tenant table
+
+Revision ID: 3b45e0018bf1
+Revises: ac842f85f932
+Create Date: 2025-03-06 09:55:18.229910
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "3b45e0018bf1"
+down_revision = "ac842f85f932"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # Create the available_tenant table
+    op.create_table(
+        "available_tenant",
+        sa.Column("tenant_id", sa.String(), nullable=False),
+        sa.Column("alembic_version", sa.String(), nullable=False),
+        sa.Column("date_created", sa.DateTime(), nullable=False),
+        sa.PrimaryKeyConstraint("tenant_id"),
+    )
+
+
+def downgrade() -> None:
+    # Drop the available_tenant table
+    op.drop_table("available_tenant")
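For reference, the op.create_table(...) call above boils down to a single CREATE TABLE statement. A minimal sketch (not part of the commit), assuming the PostgreSQL dialect, that prints the equivalent DDL with plain SQLAlchemy:

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.schema import CreateTable

metadata = sa.MetaData()

# Mirrors the columns created by the migration above
available_tenant = sa.Table(
    "available_tenant",
    metadata,
    sa.Column("tenant_id", sa.String(), primary_key=True, nullable=False),
    sa.Column("alembic_version", sa.String(), nullable=False),
    sa.Column("date_created", sa.DateTime(), nullable=False),
)

# Prints roughly:
#   CREATE TABLE available_tenant (
#       tenant_id VARCHAR NOT NULL,
#       alembic_version VARCHAR NOT NULL,
#       date_created TIMESTAMP WITHOUT TIME ZONE NOT NULL,
#       PRIMARY KEY (tenant_id)
#   )
print(CreateTable(available_tenant).compile(dialect=postgresql.dialect()))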
@@ -28,11 +28,12 @@ from onyx.auth.users import exceptions
 from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
 from onyx.configs.app_configs import DEV_MODE
 from onyx.configs.constants import MilestoneRecordType
+from onyx.db.engine import get_session_with_shared_schema
 from onyx.db.engine import get_session_with_tenant
-from onyx.db.engine import get_sqlalchemy_engine
 from onyx.db.llm import update_default_provider
 from onyx.db.llm import upsert_cloud_embedding_provider
 from onyx.db.llm import upsert_llm_provider
+from onyx.db.models import AvailableTenant
 from onyx.db.models import IndexModelStatus
 from onyx.db.models import SearchSettings
 from onyx.db.models import UserTenantMapping
@@ -62,42 +63,72 @@ async def get_or_provision_tenant(
     This function should only be called after we have verified we want this user's tenant to exist.
     It returns the tenant ID associated with the email, creating a new tenant if necessary.
     """
+    # Early return for non-multi-tenant mode
     if not MULTI_TENANT:
         return POSTGRES_DEFAULT_SCHEMA
 
     if referral_source and request:
         await submit_to_hubspot(email, referral_source, request)
 
+    # First, check if the user already has a tenant
+    tenant_id: str | None = None
     try:
         tenant_id = get_tenant_id_for_email(email)
-    except exceptions.UserNotExists:
-        # If tenant does not exist and in Multi tenant mode, provision a new tenant
-        try:
-            tenant_id = await create_tenant(email, referral_source)
-        except Exception as e:
-            logger.error(f"Tenant provisioning failed: {e}")
-            raise HTTPException(status_code=500, detail="Failed to provision tenant.")
-
-    if not tenant_id:
-        raise HTTPException(
-            status_code=401, detail="User does not belong to an organization"
-        )
-
-    return tenant_id
+        return tenant_id
+    except exceptions.UserNotExists:
+        # User doesn't exist, so we need to create a new tenant or assign an existing one
+        pass
+
+    try:
+        # Try to get a pre-provisioned tenant
+        tenant_id = await get_available_tenant()
+
+        if tenant_id:
+            # If we have a pre-provisioned tenant, assign it to the user
+            await assign_tenant_to_user(tenant_id, email, referral_source)
+            logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
+            return tenant_id
+        else:
+            # If no pre-provisioned tenant is available, create a new one on-demand
+            tenant_id = await create_tenant(email, referral_source)
+            return tenant_id
+
+    except Exception as e:
+        # If we've encountered an error, log and raise an exception
+        error_msg = "Failed to provision tenant"
+        logger.error(error_msg, exc_info=e)
+        raise HTTPException(
+            status_code=500,
+            detail="Failed to provision tenant. Please try again later.",
+        )
 
 
 async def create_tenant(email: str, referral_source: str | None = None) -> str:
+    """
+    Create a new tenant on-demand when no pre-provisioned tenants are available.
+    This is the fallback method when we can't use a pre-provisioned tenant.
+    """
     tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
+    logger.info(f"Creating new tenant {tenant_id} for user {email}")
+
     try:
         # Provision tenant on data plane
         await provision_tenant(tenant_id, email)
-        # Notify control plane
-        if not DEV_MODE:
+        # Notify control plane if not already done in provision_tenant
+        if not DEV_MODE and referral_source:
             await notify_control_plane(tenant_id, email, referral_source)
     except Exception as e:
-        logger.error(f"Tenant provisioning failed: {e}")
-        await rollback_tenant_provisioning(tenant_id)
+        logger.exception(f"Tenant provisioning failed: {str(e)}")
+        # Attempt to rollback the tenant provisioning
+        try:
+            await rollback_tenant_provisioning(tenant_id)
+        except Exception:
+            logger.exception(f"Failed to rollback tenant provisioning for {tenant_id}")
         raise HTTPException(status_code=500, detail="Failed to provision tenant.")
 
     return tenant_id
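The net effect of this rewrite is that signup now prefers the standby pool and only falls back to full on-demand provisioning when the pool is empty. A toy sketch of that decision order, with stub functions standing in for the real get_available_tenant / create_tenant above:

import asyncio

# Stand-in for the available_tenant table
_POOL: list[str] = ["tenant_11111111-2222-3333-4444-555555555555"]


async def get_available_tenant() -> str | None:
    # Real version claims a row with FOR UPDATE SKIP LOCKED (see the sketch further below)
    return _POOL.pop(0) if _POOL else None


async def create_tenant(email: str) -> str:
    # Real version runs migrations, seeds settings, and notifies the control plane
    return f"tenant_created_on_demand_for_{email}"


async def get_or_provision(email: str) -> str:
    tenant_id = await get_available_tenant()
    if tenant_id:
        return tenant_id  # fast path: pre-provisioned tenant
    return await create_tenant(email)  # slow path: provision from scratch


print(asyncio.run(get_or_provision("first@example.com")))   # pool hit
print(asyncio.run(get_or_provision("second@example.com")))  # pool empty -> on-demand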
@@ -111,54 +142,25 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
         )
 
     logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
-    token = None
 
     try:
+        # Create the schema for the tenant
         if not create_schema_if_not_exists(tenant_id):
             logger.debug(f"Created schema for tenant {tenant_id}")
         else:
             logger.debug(f"Schema already exists for tenant {tenant_id}")
 
-        token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
+        # Set up the tenant with all necessary configurations
+        await setup_tenant(tenant_id)
 
-        # Await the Alembic migrations
-        await asyncio.to_thread(run_alembic_migrations, tenant_id)
+        # Assign the tenant to the user
+        await assign_tenant_to_user(tenant_id, email)
 
-        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
-            configure_default_api_keys(db_session)
-
-            current_search_settings = (
-                db_session.query(SearchSettings)
-                .filter_by(status=IndexModelStatus.FUTURE)
-                .first()
-            )
-            cohere_enabled = (
-                current_search_settings is not None
-                and current_search_settings.provider_type == EmbeddingProvider.COHERE
-            )
-            setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
-
-        add_users_to_tenant([email], tenant_id)
-
-        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
-            create_milestone_and_report(
-                user=None,
-                distinct_id=tenant_id,
-                event_type=MilestoneRecordType.TENANT_CREATED,
-                properties={
-                    "email": email,
-                },
-                db_session=db_session,
-            )
-
     except Exception as e:
         logger.exception(f"Failed to create tenant {tenant_id}")
         raise HTTPException(
             status_code=500, detail=f"Failed to create tenant: {str(e)}"
         )
-    finally:
-        if token is not None:
-            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
 
 
 async def notify_control_plane(
@@ -189,20 +191,74 @@ async def notify_control_plane(
 
 
 async def rollback_tenant_provisioning(tenant_id: str) -> None:
-    # Logic to rollback tenant provisioning on data plane
+    """
+    Logic to rollback tenant provisioning on data plane.
+    Handles each step independently to ensure maximum cleanup even if some steps fail.
+    """
     logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
-    try:
-        # Drop the tenant's schema to rollback provisioning
-        drop_schema(tenant_id)
 
-        # Remove tenant mapping
-        with Session(get_sqlalchemy_engine()) as db_session:
+    # Track if any part of the rollback fails
+    rollback_errors = []
+
+    # 1. Try to drop the tenant's schema
+    try:
+        drop_schema(tenant_id)
+        logger.info(f"Successfully dropped schema for tenant {tenant_id}")
+    except Exception as e:
+        error_msg = f"Failed to drop schema for tenant {tenant_id}: {str(e)}"
+        logger.error(error_msg)
+        rollback_errors.append(error_msg)
+
+    # 2. Try to remove tenant mapping
+    try:
+        with get_session_with_shared_schema() as db_session:
+            db_session.begin()
+            try:
                 db_session.query(UserTenantMapping).filter(
                     UserTenantMapping.tenant_id == tenant_id
                 ).delete()
                 db_session.commit()
+                logger.info(
+                    f"Successfully removed user mappings for tenant {tenant_id}"
+                )
+            except Exception as e:
+                db_session.rollback()
+                raise e
     except Exception as e:
-        logger.error(f"Failed to rollback tenant provisioning: {e}")
+        error_msg = f"Failed to remove user mappings for tenant {tenant_id}: {str(e)}"
+        logger.error(error_msg)
+        rollback_errors.append(error_msg)
+
+    # 3. If this tenant was in the available tenants table, remove it
+    try:
+        with get_session_with_shared_schema() as db_session:
+            db_session.begin()
+            try:
+                available_tenant = (
+                    db_session.query(AvailableTenant)
+                    .filter(AvailableTenant.tenant_id == tenant_id)
+                    .first()
+                )
+
+                if available_tenant:
+                    db_session.delete(available_tenant)
+                    db_session.commit()
+                    logger.info(
+                        f"Removed tenant {tenant_id} from available tenants table"
+                    )
+            except Exception as e:
+                db_session.rollback()
+                raise e
+    except Exception as e:
+        error_msg = f"Failed to remove tenant {tenant_id} from available tenants table: {str(e)}"
+        logger.error(error_msg)
+        rollback_errors.append(error_msg)
+
+    # Log summary of rollback operation
+    if rollback_errors:
+        logger.error(f"Tenant rollback completed with {len(rollback_errors)} errors")
+    else:
+        logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}")
 
 
 def configure_default_api_keys(db_session: Session) -> None:
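The rewritten rollback deliberately trades atomicity for coverage: each cleanup step runs in its own try/except so that one failure cannot prevent the remaining steps from running. The shape of that pattern, reduced to a self-contained sketch:

from collections.abc import Callable


def best_effort_cleanup(steps: list[tuple[str, Callable[[], None]]]) -> list[str]:
    """Run every step even if earlier ones fail; return collected errors."""
    errors: list[str] = []
    for name, step in steps:
        try:
            step()
        except Exception as e:
            errors.append(f"{name}: {e}")
    return errors


def _fail() -> None:
    raise RuntimeError("db down")


# Hypothetical step names mirroring the three steps above
errors = best_effort_cleanup(
    [
        ("drop schema", lambda: None),
        ("remove user mappings", _fail),
        ("remove available_tenant row", lambda: None),
    ]
)
print(errors)  # ['remove user mappings: db down'] -- steps 1 and 3 still ran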
@@ -399,3 +455,111 @@ def get_tenant_by_domain_from_control_plane(
     except Exception as e:
         logger.error(f"Error fetching tenant by domain: {str(e)}")
         return None
+
+
+async def get_available_tenant() -> str | None:
+    """
+    Get an available pre-provisioned tenant from the AvailableTenant table.
+    Returns the tenant_id if one is available, None otherwise.
+    Uses row-level locking to prevent race conditions when multiple processes
+    try to get an available tenant simultaneously.
+    """
+    if not MULTI_TENANT:
+        return None
+
+    with get_session_with_shared_schema() as db_session:
+        try:
+            db_session.begin()
+
+            # Get the oldest available tenant with FOR UPDATE lock to prevent race conditions
+            available_tenant = (
+                db_session.query(AvailableTenant)
+                .order_by(AvailableTenant.date_created)
+                .with_for_update(skip_locked=True)  # Skip locked rows to avoid blocking
+                .first()
+            )
+
+            if available_tenant:
+                tenant_id = available_tenant.tenant_id
+                # Remove the tenant from the available tenants table
+                db_session.delete(available_tenant)
+                db_session.commit()
+                logger.info(f"Using pre-provisioned tenant {tenant_id}")
+                return tenant_id
+            else:
+                db_session.rollback()
+                return None
+        except Exception:
+            logger.exception("Error getting available tenant")
+            db_session.rollback()
+            return None
+
+
+async def setup_tenant(tenant_id: str) -> None:
+    """
+    Set up a tenant with all necessary configurations.
+    This is a centralized function that handles all tenant setup logic.
+    """
+    token = None
+    try:
+        token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
+
+        # Run Alembic migrations
+        await asyncio.to_thread(run_alembic_migrations, tenant_id)
+
+        # Configure the tenant with default settings
+        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+            # Configure default API keys
+            configure_default_api_keys(db_session)
+
+            # Set up Onyx with appropriate settings
+            current_search_settings = (
+                db_session.query(SearchSettings)
+                .filter_by(status=IndexModelStatus.FUTURE)
+                .first()
+            )
+            cohere_enabled = (
+                current_search_settings is not None
+                and current_search_settings.provider_type == EmbeddingProvider.COHERE
+            )
+            setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
+
+    except Exception:
+        logger.exception(f"Failed to set up tenant {tenant_id}")
+        raise
+    finally:
+        if token is not None:
+            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
+
+
+async def assign_tenant_to_user(
+    tenant_id: str, email: str, referral_source: str | None = None
+) -> None:
+    """
+    Assign a tenant to a user and perform necessary operations.
+    Uses transaction handling to ensure atomicity and includes retry logic
+    for control plane notifications.
+    """
+    # First, add the user to the tenant in a transaction
+    try:
+        add_users_to_tenant([email], tenant_id)
+
+        # Create milestone record in the same transaction context as the tenant assignment
+        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+            create_milestone_and_report(
+                user=None,
+                distinct_id=tenant_id,
+                event_type=MilestoneRecordType.TENANT_CREATED,
+                properties={
+                    "email": email,
+                },
+                db_session=db_session,
+            )
+    except Exception:
+        logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
+        raise Exception("Failed to assign tenant to user")
+
+    # Notify control plane with retry logic
+    if not DEV_MODE:
+        await notify_control_plane(tenant_id, email, referral_source)
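get_available_tenant is effectively a single-consumer queue pop over a Postgres table: with_for_update(skip_locked=True) renders as FOR UPDATE SKIP LOCKED, so two workers racing for a tenant each lock a different row instead of blocking on each other. A condensed sketch of that claim step, assuming a Session bound to PostgreSQL (SKIP LOCKED is a Postgres feature) and the AvailableTenant model added later in this commit:

from sqlalchemy.orm import Session


def claim_oldest_tenant(db_session: Session, AvailableTenant) -> str | None:
    """Atomically claim (and consume) the oldest pre-provisioned tenant."""
    row = (
        db_session.query(AvailableTenant)
        .order_by(AvailableTenant.date_created)  # oldest row first
        .with_for_update(skip_locked=True)       # SELECT ... FOR UPDATE SKIP LOCKED
        .first()
    )
    if row is None:
        return None  # pool is empty, or every row is locked by another worker

    tenant_id = row.tenant_id
    db_session.delete(row)  # claiming consumes the row
    db_session.commit()     # releases the row lock
    return tenant_id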
@@ -74,3 +74,21 @@ def drop_schema(tenant_id: str) -> None:
         text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
         {"schema_name": tenant_id},
     )
+
+
+def get_current_alembic_version(tenant_id: str) -> str:
+    """Get the current Alembic version for a tenant."""
+    from alembic.runtime.migration import MigrationContext
+    from sqlalchemy import text
+
+    engine = get_sqlalchemy_engine()
+
+    # Set the search path to the tenant's schema
+    with engine.connect() as connection:
+        connection.execute(text(f'SET search_path TO "{tenant_id}"'))
+
+        # Get the current version from the alembic_version table
+        context = MigrationContext.configure(connection)
+        current_rev = context.get_current_revision()
+
+    return current_rev or "head"
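Under the hood, MigrationContext.get_current_revision() reads the single row Alembic keeps in the schema's alembic_version table, so the helper above is roughly equivalent to this plain-SQL sketch (the connection URL is an assumption, not Onyx's):

from sqlalchemy import create_engine, text


def alembic_version_for_schema(schema: str) -> str | None:
    engine = create_engine("postgresql://localhost/onyx")  # assumed URL
    with engine.connect() as conn:
        # Same search_path trick as above: scope queries to one tenant schema
        conn.execute(text(f'SET search_path TO "{schema}"'))
        return conn.execute(
            text("SELECT version_num FROM alembic_version")
        ).scalar_one_or_none()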
@@ -67,15 +67,39 @@ def user_owns_a_tenant(email: str) -> bool:
 
 
 def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
+    """
+    Add users to a tenant with proper transaction handling.
+    Checks if users already have a tenant mapping to avoid duplicates.
+    """
     with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
         try:
+            # Start a transaction
+            db_session.begin()
+
             for email in emails:
-                db_session.add(
-                    UserTenantMapping(email=email, tenant_id=tenant_id, active=False)
-                )
+                # Check if the user already has a mapping to this tenant
+                existing_mapping = (
+                    db_session.query(UserTenantMapping)
+                    .filter(
+                        UserTenantMapping.email == email,
+                        UserTenantMapping.tenant_id == tenant_id,
+                    )
+                    .with_for_update()
+                    .first()
+                )
+
+                if not existing_mapping:
+                    # Only add if mapping doesn't exist
+                    db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
+
+            # Commit the transaction
+            db_session.commit()
+            logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
         except Exception:
             logger.exception(f"Failed to add users to tenant {tenant_id}")
-            db_session.commit()
+            db_session.rollback()
+            raise
 
 
 def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
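Besides the duplicate-mapping check, this hunk fixes a real bug: the old except branch called db_session.commit(), persisting whatever partial state had just caused the failure, and swallowed the exception. The standard shape commits only on success and re-raises. As a sketch:

from collections.abc import Callable

from sqlalchemy.orm import Session


def run_in_transaction(db_session: Session, work: Callable[[Session], None]) -> None:
    """Commit on success; roll back and re-raise on any failure."""
    try:
        work(db_session)
        db_session.commit()
    except Exception:
        db_session.rollback()  # the old code committed here instead
        raise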
@@ -112,5 +112,6 @@ celery_app.autodiscover_tasks(
         "onyx.background.celery.tasks.connector_deletion",
         "onyx.background.celery.tasks.doc_permission_syncing",
         "onyx.background.celery.tasks.indexing",
+        "onyx.background.celery.tasks.tenant_provisioning",
     ]
 )
|
|||||||
celery_app.autodiscover_tasks(
|
celery_app.autodiscover_tasks(
|
||||||
[
|
[
|
||||||
"onyx.background.celery.tasks.monitoring",
|
"onyx.background.celery.tasks.monitoring",
|
||||||
|
"onyx.background.celery.tasks.tenant_provisioning",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@@ -167,6 +167,16 @@ beat_cloud_tasks: list[dict] = [
             "expires": BEAT_EXPIRES_DEFAULT,
         },
     },
+    {
+        "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants",
+        "task": OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
+        "schedule": timedelta(minutes=10),
+        "options": {
+            "queue": OnyxCeleryQueues.MONITORING,
+            "priority": OnyxCeleryPriority.HIGH,
+            "expires": BEAT_EXPIRES_DEFAULT,
+        },
+    },
 ]
 
 # tasks that only run self hosted
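For readers less familiar with celery beat: the dict added above is a standard beat schedule entry, so every ten minutes beat enqueues the CHECK_AVAILABLE_TENANTS task onto the monitoring queue. A minimal standalone equivalent, where the broker URL and task name are placeholders rather than Onyx's values:

from datetime import timedelta

from celery import Celery

app = Celery("sketch", broker="redis://localhost:6379/0")  # assumed broker

app.conf.beat_schedule = {
    "check-available-tenants": {
        "task": "check_available_tenants",  # placeholder task name
        "schedule": timedelta(minutes=10),  # same cadence as the added entry
        "options": {"queue": "monitoring", "expires": 60},
    },
}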
@@ -0,0 +1,199 @@
+"""
+Periodic tasks for tenant pre-provisioning.
+"""
+import asyncio
+import datetime
+import uuid
+
+from celery import shared_task
+from celery import Task
+from redis.lock import Lock as RedisLock
+
+from ee.onyx.server.tenants.provisioning import setup_tenant
+from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
+from ee.onyx.server.tenants.schema_management import get_current_alembic_version
+from onyx.background.celery.apps.app_base import task_logger
+from onyx.configs.app_configs import JOB_TIMEOUT
+from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
+from onyx.configs.constants import OnyxCeleryPriority
+from onyx.configs.constants import OnyxCeleryQueues
+from onyx.configs.constants import OnyxCeleryTask
+from onyx.configs.constants import OnyxRedisLocks
+from onyx.db.engine import get_session_with_shared_schema
+from onyx.db.models import AvailableTenant
+from onyx.redis.redis_pool import get_redis_client
+from shared_configs.configs import MULTI_TENANT
+from shared_configs.configs import TENANT_ID_PREFIX
+
+# Default number of pre-provisioned tenants to maintain
+DEFAULT_TARGET_AVAILABLE_TENANTS = 5
+
+# Soft time limit for tenant pre-provisioning tasks (in seconds)
+_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5  # 5 minutes
+# Hard time limit for tenant pre-provisioning tasks (in seconds)
+_TENANT_PROVISIONING_TIME_LIMIT = 60 * 10  # 10 minutes
+
+
+@shared_task(
+    name=OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
+    queue=OnyxCeleryQueues.MONITORING,
+    ignore_result=True,
+    soft_time_limit=JOB_TIMEOUT,
+    trail=False,
+    bind=True,
+)
+def check_available_tenants(self: Task) -> None:
+    """
+    Check if we have enough pre-provisioned tenants available.
+    If not, trigger the pre-provisioning of new tenants.
+    """
+    task_logger.info("STARTING CHECK_AVAILABLE_TENANTS")
+    if not MULTI_TENANT:
+        task_logger.info(
+            "Multi-tenancy is not enabled, skipping tenant pre-provisioning"
+        )
+        return
+
+    r = get_redis_client()
+    lock_check: RedisLock = r.lock(
+        OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
+        timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
+    )
+
+    # These tasks should never overlap
+    if not lock_check.acquire(blocking=False):
+        task_logger.info(
+            "Skipping check_available_tenants task because it is already running"
+        )
+        return
+
+    try:
+        # Get the current count of available tenants
+        with get_session_with_shared_schema() as db_session:
+            available_tenants_count = db_session.query(AvailableTenant).count()
+
+        # Get the target number of available tenants
+        target_available_tenants = getattr(
+            TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
+        )
+
+        # Calculate how many new tenants we need to provision
+        tenants_to_provision = max(
+            0, target_available_tenants - available_tenants_count
+        )
+
+        task_logger.info(
+            f"Available tenants: {available_tenants_count}, "
+            f"Target: {target_available_tenants}, "
+            f"To provision: {tenants_to_provision}"
+        )
+
+        # Trigger pre-provisioning tasks for each tenant needed
+        for _ in range(tenants_to_provision):
+            from celery import current_app
+
+            current_app.send_task(
+                OnyxCeleryTask.PRE_PROVISION_TENANT,
+                priority=OnyxCeleryPriority.LOW,
+            )
+
+    except Exception:
+        task_logger.exception("Error in check_available_tenants task")
+
+    finally:
+        lock_check.release()
+
+
+@shared_task(
+    name=OnyxCeleryTask.PRE_PROVISION_TENANT,
+    ignore_result=True,
+    soft_time_limit=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
+    time_limit=_TENANT_PROVISIONING_TIME_LIMIT,
+    queue=OnyxCeleryQueues.MONITORING,
+    bind=True,
+)
+def pre_provision_tenant(self: Task) -> None:
+    """
+    Pre-provision a new tenant and store it in the AvailableTenant table.
+    This function fully sets up the tenant with all necessary configurations,
+    so it's ready to be assigned to a user immediately.
+    """
+    # The MULTI_TENANT check is now done at the caller level (check_available_tenants)
+    # rather than inside this function
+
+    r = get_redis_client()
+    lock_provision: RedisLock = r.lock(
+        OnyxRedisLocks.PRE_PROVISION_TENANT_LOCK,
+        timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
+    )
+
+    # Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
+    if not lock_provision.acquire(blocking=False):
+        task_logger.debug(
+            "Skipping pre_provision_tenant task because it is already running"
+        )
+        return
+
+    tenant_id: str | None = None
+    try:
+        # Generate a new tenant ID
+        tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
+        task_logger.info(f"Pre-provisioning tenant: {tenant_id}")
+
+        # Create the schema for the new tenant
+        schema_created = create_schema_if_not_exists(tenant_id)
+        if schema_created:
+            task_logger.debug(f"Created schema for tenant: {tenant_id}")
+        else:
+            task_logger.debug(f"Schema already exists for tenant: {tenant_id}")
+
+        # Set up the tenant with all necessary configurations
+        task_logger.debug(f"Setting up tenant configuration: {tenant_id}")
+        asyncio.run(setup_tenant(tenant_id))
+        task_logger.debug(f"Tenant configuration completed: {tenant_id}")
+
+        # Get the current Alembic version
+        alembic_version = get_current_alembic_version(tenant_id)
+        task_logger.debug(
+            f"Tenant {tenant_id} using Alembic version: {alembic_version}"
+        )
+
+        # Store the pre-provisioned tenant in the database
+        task_logger.debug(f"Storing pre-provisioned tenant in database: {tenant_id}")
+        with get_session_with_shared_schema() as db_session:
+            # Use a transaction to ensure atomicity
+            db_session.begin()
+            try:
+                new_tenant = AvailableTenant(
+                    tenant_id=tenant_id,
+                    alembic_version=alembic_version,
+                    date_created=datetime.datetime.now(),
+                )
+                db_session.add(new_tenant)
+                db_session.commit()
+                task_logger.info(f"Successfully pre-provisioned tenant: {tenant_id}")
+            except Exception:
+                db_session.rollback()
+                task_logger.error(
+                    f"Failed to store pre-provisioned tenant: {tenant_id}",
+                    exc_info=True,
+                )
+                raise
+
+    except Exception:
+        task_logger.error("Error in pre_provision_tenant task", exc_info=True)
+        # If we have a tenant_id, attempt to rollback any partially completed provisioning
+        if tenant_id:
+            task_logger.info(
+                f"Rolling back failed tenant provisioning for: {tenant_id}"
+            )
+            try:
+                from ee.onyx.server.tenants.provisioning import (
+                    rollback_tenant_provisioning,
+                )
+
+                asyncio.run(rollback_tenant_provisioning(tenant_id))
+            except Exception:
+                task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
+    finally:
+        lock_provision.release()
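Both tasks guard themselves with the same non-overlap idiom: acquire a Redis lock with blocking=False and bail out immediately if another worker already holds it, releasing in a finally block. (One quirk worth noting in the hunk above: getattr(TARGET_AVAILABLE_TENANTS, "value", ...) on a plain int always falls through to the default, which happens to equal the env default of 5.) Stripped of the Onyx wiring, the lock idiom looks like this, assuming a local Redis:

import redis

r = redis.Redis()  # assumed local Redis instance

# timeout bounds how long a crashed holder can wedge the lock
lock = r.lock("da_lock:pre_provision_tenant", timeout=300)

if lock.acquire(blocking=False):  # don't queue behind another worker
    try:
        pass  # ... do the provisioning work ...
    finally:
        lock.release()
else:
    print("already running elsewhere, skipping")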
@@ -643,3 +643,6 @@ MOCK_LLM_RESPONSE = (
 
 
 DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20
+
+# Number of pre-provisioned tenants to maintain
+TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5"))
@@ -322,6 +322,8 @@ class OnyxRedisLocks:
         "da_lock:check_connector_external_group_sync_beat"
     )
     MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
+    CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
+    PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant"
 
     CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
         "da_lock:connector_doc_permissions_sync"
@@ -384,6 +386,7 @@ class OnyxCeleryTask:
     CLOUD_MONITOR_CELERY_QUEUES = (
         f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_queues"
    )
+    CHECK_AVAILABLE_TENANTS = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_available_tenants"
 
     CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
     CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
@@ -400,6 +403,9 @@ class OnyxCeleryTask:
     MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
     MONITOR_CELERY_QUEUES = "monitor_celery_queues"
 
+    # Tenant pre-provisioning
+    PRE_PROVISION_TENANT = "pre_provision_tenant"
+
     KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
     CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
         "connector_permission_sync_generator_task"
@@ -2309,6 +2309,17 @@ class UserTenantMapping(Base):
         return value.lower() if value else value
 
 
+class AvailableTenant(Base):
+    """
+    These entries will only exist ephemerally and are meant to be picked up by new users on registration.
+    """
+
+    __tablename__ = "available_tenant"
+
+    tenant_id: Mapped[str] = mapped_column(String, primary_key=True, nullable=False)
+    alembic_version: Mapped[str] = mapped_column(String, nullable=False)
+    date_created: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False)
+
+
 # This is a mapping from tenant IDs to anonymous user paths
 class TenantAnonymousUserPath(Base):
     __tablename__ = "tenant_anonymous_user_path"
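Tying the model back to the tasks: pre_provision_tenant inserts a row here and get_available_tenant deletes it on claim, so the table only ever holds tenants that are fully built but unowned. A sketch of the producer side, assuming a shared-schema Session and the model above:

import datetime

from sqlalchemy.orm import Session


def stage_available_tenant(
    db_session: Session, AvailableTenant, tenant_id: str, alembic_version: str
) -> None:
    """Publish a fully built tenant into the standby pool."""
    db_session.add(
        AvailableTenant(
            tenant_id=tenant_id,
            alembic_version=alembic_version,  # recorded so stale pools can be detected
            date_created=datetime.datetime.now(),  # claim order is oldest-first
        )
    )
    db_session.commit()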