diff --git a/backend/ee/onyx/server/tenants/provisioning.py b/backend/ee/onyx/server/tenants/provisioning.py index 44003b892..1cce6f795 100644 --- a/backend/ee/onyx/server/tenants/provisioning.py +++ b/backend/ee/onyx/server/tenants/provisioning.py @@ -67,10 +67,12 @@ async def get_or_provision_tenant( if referral_source and request: await submit_to_hubspot(email, referral_source, request) + tenant_id: str | None = None try: # First, check if the user already has a tenant tenant_id = get_tenant_id_for_email(email) return tenant_id + except exceptions.UserNotExists: # User doesn't exist, so we need to create a new tenant or assign an existing one try: @@ -78,25 +80,8 @@ async def get_or_provision_tenant( tenant_id = await get_available_tenant() if tenant_id: - # If we have a pre-provisioned tenant, just add the user to it - add_users_to_tenant([email], tenant_id) - - # Create milestone record - with get_session_with_tenant(tenant_id=tenant_id) as db_session: - create_milestone_and_report( - user=None, - distinct_id=tenant_id, - event_type=MilestoneRecordType.TENANT_CREATED, - properties={ - "email": email, - }, - db_session=db_session, - ) - - # Notify control plane - if not DEV_MODE: - await notify_control_plane(tenant_id, email, referral_source) - + # If we have a pre-provisioned tenant, assign it to the user + await assign_tenant_to_user(tenant_id, email, referral_source) logger.info( f"Assigned pre-provisioned tenant {tenant_id} to user {email}" ) @@ -125,9 +110,11 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str: try: # Provision tenant on data plane await provision_tenant(tenant_id, email) - # Notify control plane - if not DEV_MODE: + + # Notify control plane if not already done in provision_tenant + if not DEV_MODE and referral_source: await notify_control_plane(tenant_id, email, referral_source) + except Exception as e: logger.error(f"Tenant provisioning failed: {e}") await rollback_tenant_provisioning(tenant_id) @@ -145,56 +132,25 @@ async def provision_tenant(tenant_id: str, email: str) -> None: ) logger.debug(f"Provisioning tenant {tenant_id} for user {email}") - token = None try: + # Create the schema for the tenant if not create_schema_if_not_exists(tenant_id): logger.debug(f"Created schema for tenant {tenant_id}") else: logger.debug(f"Schema already exists for tenant {tenant_id}") - token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) + # Set up the tenant with all necessary configurations + await setup_tenant(tenant_id) - # Await the Alembic migrations - await asyncio.to_thread(run_alembic_migrations, tenant_id) - - with get_session_with_tenant(tenant_id=tenant_id) as db_session: - configure_default_api_keys(db_session) - - current_search_settings = ( - db_session.query(SearchSettings) - .filter_by(status=IndexModelStatus.FUTURE) - .first() - ) - cohere_enabled = ( - current_search_settings is not None - and current_search_settings.provider_type == EmbeddingProvider.COHERE - ) - setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled) - - # Add the user to the tenant - add_users_to_tenant([email], tenant_id) - - # Create milestone record - with get_session_with_tenant(tenant_id=tenant_id) as db_session: - create_milestone_and_report( - user=None, - distinct_id=tenant_id, - event_type=MilestoneRecordType.TENANT_CREATED, - properties={ - "email": email, - }, - db_session=db_session, - ) + # Assign the tenant to the user + await assign_tenant_to_user(tenant_id, email) except Exception as e: logger.exception(f"Failed to create tenant {tenant_id}") raise HTTPException( status_code=500, detail=f"Failed to create tenant: {str(e)}" ) - finally: - if token is not None: - CURRENT_TENANT_ID_CONTEXTVAR.reset(token) async def notify_control_plane( @@ -422,3 +378,66 @@ async def get_available_tenant() -> str | None: logger.error(f"Error getting available tenant: {e}") return None + + +async def setup_tenant(tenant_id: str) -> None: + """ + Set up a tenant with all necessary configurations. + This is a centralized function that handles all tenant setup logic. + """ + token = None + try: + token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) + + # Run Alembic migrations + await asyncio.to_thread(run_alembic_migrations, tenant_id) + + # Configure the tenant with default settings + with get_session_with_tenant(tenant_id=tenant_id) as db_session: + # Configure default API keys + configure_default_api_keys(db_session) + + # Set up Onyx with appropriate settings + current_search_settings = ( + db_session.query(SearchSettings) + .filter_by(status=IndexModelStatus.FUTURE) + .first() + ) + cohere_enabled = ( + current_search_settings is not None + and current_search_settings.provider_type == EmbeddingProvider.COHERE + ) + setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled) + + except Exception as e: + logger.exception(f"Failed to set up tenant {tenant_id}") + raise e + finally: + if token is not None: + CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + + +async def assign_tenant_to_user( + tenant_id: str, email: str, referral_source: str | None = None +) -> None: + """ + Assign a tenant to a user and perform necessary operations. + """ + # Add the user to the tenant + add_users_to_tenant([email], tenant_id) + + # Create milestone record + with get_session_with_tenant(tenant_id=tenant_id) as db_session: + create_milestone_and_report( + user=None, + distinct_id=tenant_id, + event_type=MilestoneRecordType.TENANT_CREATED, + properties={ + "email": email, + }, + db_session=db_session, + ) + + # Notify control plane + if not DEV_MODE: + await notify_control_plane(tenant_id, email, referral_source) diff --git a/backend/ee/onyx/server/tenants/schema_management.py b/backend/ee/onyx/server/tenants/schema_management.py index 7a83b0246..1e9952df9 100644 --- a/backend/ee/onyx/server/tenants/schema_management.py +++ b/backend/ee/onyx/server/tenants/schema_management.py @@ -85,7 +85,7 @@ def get_current_alembic_version(tenant_id: str) -> str: # Set the search path to the tenant's schema with engine.connect() as connection: - connection.execute(text(f"SET search_path TO {tenant_id}")) + connection.execute(text(f'SET search_path TO "{tenant_id}"')) # Get the current version from the alembic_version table context = MigrationContext.configure(connection) diff --git a/backend/onyx/background/celery/apps/light.py b/backend/onyx/background/celery/apps/light.py index b6b99ca4c..07b68963c 100644 --- a/backend/onyx/background/celery/apps/light.py +++ b/backend/onyx/background/celery/apps/light.py @@ -111,5 +111,6 @@ celery_app.autodiscover_tasks( "onyx.background.celery.tasks.vespa", "onyx.background.celery.tasks.connector_deletion", "onyx.background.celery.tasks.doc_permission_syncing", + "onyx.background.celery.tasks.periodic.tenant_provisioning", ] ) diff --git a/backend/onyx/background/celery/apps/primary.py b/backend/onyx/background/celery/apps/primary.py index 0d6679155..98bfd74d9 100644 --- a/backend/onyx/background/celery/apps/primary.py +++ b/backend/onyx/background/celery/apps/primary.py @@ -49,6 +49,19 @@ celery_app = Celery(__name__) celery_app.config_from_object("onyx.background.celery.configs.primary") celery_app.Task = app_base.TenantAwareTask # type: ignore [misc] +# Import tasks to ensure they are registered with Celery +import onyx.background.celery.tasks.connector_deletion.tasks # noqa +import onyx.background.celery.tasks.doc_permission_syncing.tasks # noqa +import onyx.background.celery.tasks.external_group_syncing.tasks # noqa +import onyx.background.celery.tasks.indexing.tasks # noqa +import onyx.background.celery.tasks.llm_model_update.tasks # noqa +import onyx.background.celery.tasks.monitoring.tasks # noqa +import onyx.background.celery.tasks.periodic.tasks # noqa +import onyx.background.celery.tasks.periodic.tenant_provisioning # noqa +import onyx.background.celery.tasks.pruning.tasks # noqa +import onyx.background.celery.tasks.shared.tasks # noqa +import onyx.background.celery.tasks.vespa.tasks # noqa + @signals.task_prerun.connect def on_task_prerun( diff --git a/backend/onyx/background/celery/tasks/beat_schedule.py b/backend/onyx/background/celery/tasks/beat_schedule.py index 5ca334b5e..4d5896291 100644 --- a/backend/onyx/background/celery/tasks/beat_schedule.py +++ b/backend/onyx/background/celery/tasks/beat_schedule.py @@ -20,7 +20,7 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds) # hack to slow down task dispatch in the cloud until # we have a better implementation (backpressure, etc) # Note that DynamicTenantScheduler can adjust the runtime value for this via Redis -CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0 +CLOUD_BEAT_MULTIPLIER_DEFAULT = 0.5 # tasks that run in either self-hosted on cloud beat_task_templates: list[dict] = [] @@ -100,15 +100,6 @@ beat_task_templates.extend( "queue": OnyxCeleryQueues.MONITORING, }, }, - { - "name": "check-available-tenants", - "task": OnyxCeleryTask.CHECK_AVAILABLE_TENANTS, - "schedule": timedelta(minutes=5), - "options": { - "priority": OnyxCeleryPriority.MEDIUM, - "expires": BEAT_EXPIRES_DEFAULT, - }, - }, ] ) @@ -176,6 +167,15 @@ beat_cloud_tasks: list[dict] = [ "expires": BEAT_EXPIRES_DEFAULT, }, }, + { + "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants", + "task": OnyxCeleryTask.CHECK_AVAILABLE_TENANTS, + "schedule": timedelta(seconds=10), + "options": { + "priority": OnyxCeleryPriority.MEDIUM, + "expires": BEAT_EXPIRES_DEFAULT, + }, + }, ] # tasks that only run self hosted diff --git a/backend/onyx/background/celery/tasks/periodic/tenant_provisioning.py b/backend/onyx/background/celery/tasks/periodic/tenant_provisioning.py index 1981f0139..62dcb24cd 100644 --- a/backend/onyx/background/celery/tasks/periodic/tenant_provisioning.py +++ b/backend/onyx/background/celery/tasks/periodic/tenant_provisioning.py @@ -3,7 +3,6 @@ Periodic tasks for tenant pre-provisioning. """ import asyncio import datetime -import logging import uuid from celery import shared_task @@ -11,7 +10,7 @@ from celery import Task from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session -from onyx.background.celery.celery_utils import get_redis_client +from onyx.background.celery.apps.app_base import task_logger from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues @@ -19,9 +18,9 @@ from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisLocks from onyx.db.engine import get_sqlalchemy_engine from onyx.db.models import NewAvailableTenant +from onyx.redis.redis_pool import get_redis_client from shared_configs.configs import MULTI_TENANT from shared_configs.configs import TENANT_ID_PREFIX -from shared_configs.enums import EmbeddingProvider # Default number of pre-provisioned tenants to maintain DEFAULT_TARGET_AVAILABLE_TENANTS = 5 @@ -31,8 +30,6 @@ _TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes # Hard time limit for tenant pre-provisioning tasks (in seconds) _TENANT_PROVISIONING_TIME_LIMIT = 60 * 10 # 10 minutes -logger = logging.getLogger(__name__) - @shared_task( name=OnyxCeleryTask.CHECK_AVAILABLE_TENANTS, @@ -47,8 +44,11 @@ def check_available_tenants(self: Task) -> None: Check if we have enough pre-provisioned tenants available. If not, trigger the pre-provisioning of new tenants. """ + task_logger.warning("STARTING CHECK_AVAILABLE_TENANTS") if not MULTI_TENANT: - logger.debug("Multi-tenancy is not enabled, skipping tenant pre-provisioning") + task_logger.warning( + "Multi-tenancy is not enabled, skipping tenant pre-provisioning" + ) return r = get_redis_client() @@ -59,7 +59,7 @@ def check_available_tenants(self: Task) -> None: # These tasks should never overlap if not lock_check.acquire(blocking=False): - logger.debug( + task_logger.warning( "Skipping check_available_tenants task because it is already running" ) return @@ -79,7 +79,7 @@ def check_available_tenants(self: Task) -> None: 0, target_available_tenants - available_tenants_count ) - logger.info( + task_logger.warning( f"Available tenants: {available_tenants_count}, " f"Target: {target_available_tenants}, " f"To provision: {tenants_to_provision}" @@ -92,7 +92,8 @@ def check_available_tenants(self: Task) -> None: ) except Exception as e: - logger.exception(f"Error in check_available_tenants task: {e}") + task_logger.exception(f"Error in check_available_tenants task: {e}") + finally: lock_check.release() @@ -111,10 +112,12 @@ def pre_provision_tenant(self: Task) -> None: This function fully sets up the tenant with all necessary configurations, so it's ready to be assigned to a user immediately. """ + task_logger.warning("STARTING PRE_PROVISION_TENANT") if not MULTI_TENANT: - logger.debug("Multi-tenancy is not enabled, skipping tenant pre-provisioning") + task_logger.warning( + "Multi-tenancy is not enabled, skipping tenant pre-provisioning" + ) return - r = get_redis_client() lock_provision: RedisLock = r.lock( OnyxRedisLocks.PRE_PROVISION_TENANT_LOCK, @@ -123,75 +126,53 @@ def pre_provision_tenant(self: Task) -> None: # Allow multiple pre-provisioning tasks to run, but ensure they don't overlap if not lock_provision.acquire(blocking=False): - logger.debug("Skipping pre_provision_tenant task because it is already running") + task_logger.warning( + "Skipping pre_provision_tenant task because it is already running" + ) return try: # Generate a new tenant ID tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4()) - token = None + task_logger.warning(f"Starting pre-provisioning for tenant {tenant_id}") # Import here to avoid circular imports from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists - from ee.onyx.server.tenants.schema_management import run_alembic_migrations from ee.onyx.server.tenants.schema_management import get_current_alembic_version - from ee.onyx.server.tenants.provisioning import configure_default_api_keys - from onyx.setup import setup_onyx - from onyx.db.models import SearchSettings, IndexModelStatus - from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR - from onyx.db.engine import get_session_with_tenant + from ee.onyx.server.tenants.provisioning import setup_tenant # Create the schema for the new tenant - if not create_schema_if_not_exists(tenant_id): - logger.debug(f"Created schema for tenant {tenant_id}") + schema_created = create_schema_if_not_exists(tenant_id) + if schema_created: + task_logger.warning(f"Created schema for tenant '{tenant_id}'") else: - logger.debug(f"Schema already exists for tenant {tenant_id}") + task_logger.warning(f"Schema already exists for tenant '{tenant_id}'") - try: - # Set the tenant context - token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) + # Set up the tenant with all necessary configurations + task_logger.warning(f"Setting up tenant configuration for '{tenant_id}'") + asyncio.run(setup_tenant(tenant_id)) + task_logger.warning(f"Tenant configuration completed for '{tenant_id}'") - # Run Alembic migrations - asyncio.run(asyncio.to_thread(run_alembic_migrations, tenant_id)) + # Get the current Alembic version + alembic_version = get_current_alembic_version(tenant_id) + task_logger.warning( + f"Tenant '{tenant_id}' using Alembic version: {alembic_version}" + ) - # Configure the tenant with default settings - with get_session_with_tenant(tenant_id=tenant_id) as db_session: - # Configure default API keys - configure_default_api_keys(db_session) + # Store the pre-provisioned tenant in the database + task_logger.warning(f"Storing pre-provisioned tenant '{tenant_id}' in database") + with Session(get_sqlalchemy_engine()) as db_session: + new_tenant = NewAvailableTenant( + tenant_id=tenant_id, + alembic_version=alembic_version, + date_created=datetime.datetime.now(), + ) + db_session.add(new_tenant) + db_session.commit() - # Set up Onyx with appropriate settings - current_search_settings = ( - db_session.query(SearchSettings) - .filter_by(status=IndexModelStatus.FUTURE) - .first() - ) - cohere_enabled = ( - current_search_settings is not None - and current_search_settings.provider_type - == EmbeddingProvider.COHERE - ) - setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled) - - # Get the current Alembic version - alembic_version = get_current_alembic_version(tenant_id) - - # Store the pre-provisioned tenant in the database - with Session(get_sqlalchemy_engine()) as db_session: - new_tenant = NewAvailableTenant( - tenant_id=tenant_id, - alembic_version=alembic_version, - date_created=datetime.datetime.now(), - ) - db_session.add(new_tenant) - db_session.commit() - - logger.info(f"Successfully pre-provisioned tenant {tenant_id}") - - finally: - if token is not None: - CURRENT_TENANT_ID_CONTEXTVAR.reset(token) + task_logger.warning(f"Successfully pre-provisioned tenant {tenant_id}") except Exception as e: - logger.exception(f"Error in pre_provision_tenant task: {e}") + task_logger.exception(f"Error in pre_provision_tenant task: {e}") finally: lock_provision.release() diff --git a/web/src/lib/connectors/connectors.tsx b/web/src/lib/connectors/connectors.tsx index e26631864..dcfc75eeb 100644 --- a/web/src/lib/connectors/connectors.tsx +++ b/web/src/lib/connectors/connectors.tsx @@ -1259,21 +1259,18 @@ export function createConnectorInitialValues( name: "", groups: [], access_type: "public", - ...configuration.values.reduce( - (acc, field) => { - if (field.type === "select") { - acc[field.name] = null; - } else if (field.type === "list") { - acc[field.name] = field.default || []; - } else if (field.type === "checkbox") { - acc[field.name] = field.default || false; - } else if (field.default !== undefined) { - acc[field.name] = field.default; - } - return acc; - }, - {} as { [record: string]: any } - ), + ...configuration.values.reduce((acc, field) => { + if (field.type === "select") { + acc[field.name] = null; + } else if (field.type === "list") { + acc[field.name] = field.default || []; + } else if (field.type === "checkbox") { + acc[field.name] = field.default || false; + } else if (field.default !== undefined) { + acc[field.name] = field.default; + } + return acc; + }, {} as { [record: string]: any }), }; } @@ -1285,28 +1282,25 @@ export function createConnectorValidationSchema( return Yup.object().shape({ access_type: Yup.string().required("Access Type is required"), name: Yup.string().required("Connector Name is required"), - ...configuration.values.reduce( - (acc, field) => { - let schema: any = - field.type === "select" - ? Yup.string() - : field.type === "list" - ? Yup.array().of(Yup.string()) - : field.type === "checkbox" - ? Yup.boolean() - : field.type === "file" - ? Yup.mixed() - : Yup.string(); + ...configuration.values.reduce((acc, field) => { + let schema: any = + field.type === "select" + ? Yup.string() + : field.type === "list" + ? Yup.array().of(Yup.string()) + : field.type === "checkbox" + ? Yup.boolean() + : field.type === "file" + ? Yup.mixed() + : Yup.string(); - if (!field.optional) { - schema = schema.required(`${field.label} is required`); - } + if (!field.optional) { + schema = schema.required(`${field.label} is required`); + } - acc[field.name] = schema; - return acc; - }, - {} as Record - ), + acc[field.name] = schema; + return acc; + }, {} as Record), // These are advanced settings indexingStart: Yup.string().nullable(), pruneFreq: Yup.number().min(0, "Prune frequency must be non-negative"),