Fix tenant provisioning: provision inline instead of spawning separate tasks, which could result in a race condition (#4604)

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
This commit is contained in:
rkuo-danswer 2025-04-24 19:41:05 -07:00 committed by GitHub
parent 13b71f559f
commit 672f3a1c34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 30 deletions

View File

@ -14,9 +14,8 @@ from ee.onyx.server.tenants.provisioning import setup_tenant
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import get_current_alembic_version from ee.onyx.server.tenants.schema_management import get_current_alembic_version
from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks from onyx.configs.constants import OnyxRedisLocks
@ -39,7 +38,8 @@ _TENANT_PROVISIONING_TIME_LIMIT = 60 * 10 # 10 minutes
name=OnyxCeleryTask.CLOUD_CHECK_AVAILABLE_TENANTS, name=OnyxCeleryTask.CLOUD_CHECK_AVAILABLE_TENANTS,
queue=OnyxCeleryQueues.MONITORING, queue=OnyxCeleryQueues.MONITORING,
ignore_result=True, ignore_result=True,
soft_time_limit=JOB_TIMEOUT, soft_time_limit=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
time_limit=_TENANT_PROVISIONING_TIME_LIMIT,
trail=False, trail=False,
bind=True, bind=True,
) )
@ -55,7 +55,7 @@ def check_available_tenants(self: Task) -> None:
) )
return return
r = get_redis_client() r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_check: RedisLock = r.lock( lock_check: RedisLock = r.lock(
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK, OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT, timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
@ -71,32 +71,28 @@ def check_available_tenants(self: Task) -> None:
try: try:
# Get the current count of available tenants # Get the current count of available tenants
with get_session_with_shared_schema() as db_session: with get_session_with_shared_schema() as db_session:
available_tenants_count = db_session.query(AvailableTenant).count() num_available_tenants = db_session.query(AvailableTenant).count()
# Get the target number of available tenants # Get the target number of available tenants
target_available_tenants = getattr( num_minimum_available_tenants = getattr(
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
) )
# Calculate how many new tenants we need to provision # Calculate how many new tenants we need to provision
tenants_to_provision = max( if num_available_tenants < num_minimum_available_tenants:
0, target_available_tenants - available_tenants_count tenants_to_provision = num_minimum_available_tenants - num_available_tenants
) else:
tenants_to_provision = 0
task_logger.info( task_logger.info(
f"Available tenants: {available_tenants_count}, " f"Available tenants: {num_available_tenants}, "
f"Target: {target_available_tenants}, " f"Target minimum available tenants: {num_minimum_available_tenants}, "
f"To provision: {tenants_to_provision}" f"To provision: {tenants_to_provision}"
) )
# Trigger pre-provisioning tasks for each tenant needed # just provision one tenant each time we run this ... increase if needed.
for _ in range(tenants_to_provision): if tenants_to_provision > 0:
from celery import current_app pre_provision_tenant()
current_app.send_task(
OnyxCeleryTask.PRE_PROVISION_TENANT,
priority=OnyxCeleryPriority.LOW,
)
except Exception: except Exception:
task_logger.exception("Error in check_available_tenants task") task_logger.exception("Error in check_available_tenants task")
@ -105,15 +101,7 @@ def check_available_tenants(self: Task) -> None:
lock_check.release() lock_check.release()
@shared_task( def pre_provision_tenant() -> None:
name=OnyxCeleryTask.PRE_PROVISION_TENANT,
ignore_result=True,
soft_time_limit=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
time_limit=_TENANT_PROVISIONING_TIME_LIMIT,
queue=OnyxCeleryQueues.MONITORING,
bind=True,
)
def pre_provision_tenant(self: Task) -> None:
""" """
Pre-provision a new tenant and store it in the NewAvailableTenant table. Pre-provision a new tenant and store it in the NewAvailableTenant table.
This function fully sets up the tenant with all necessary configurations, This function fully sets up the tenant with all necessary configurations,
@ -122,7 +110,7 @@ def pre_provision_tenant(self: Task) -> None:
# The MULTI_TENANT check is now done at the caller level (check_available_tenants) # The MULTI_TENANT check is now done at the caller level (check_available_tenants)
# rather than inside this function # rather than inside this function
r = get_redis_client() r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_provision: RedisLock = r.lock( lock_provision: RedisLock = r.lock(
OnyxRedisLocks.PRE_PROVISION_TENANT_LOCK, OnyxRedisLocks.PRE_PROVISION_TENANT_LOCK,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT, timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,

View File

@ -406,7 +406,6 @@ class OnyxCeleryTask:
) )
# Tenant pre-provisioning # Tenant pre-provisioning
PRE_PROVISION_TENANT = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_pre_provision_tenant"
UPDATE_USER_FILE_FOLDER_METADATA = "update_user_file_folder_metadata" UPDATE_USER_FILE_FOLDER_METADATA = "update_user_file_folder_metadata"
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task" CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"