This commit is contained in:
pablonyx 2025-03-06 11:59:33 -08:00
parent 55b0b02068
commit b8b20585e1
7 changed files with 174 additions and 166 deletions

View File

@ -67,10 +67,12 @@ async def get_or_provision_tenant(
if referral_source and request:
await submit_to_hubspot(email, referral_source, request)
tenant_id: str | None = None
try:
# First, check if the user already has a tenant
tenant_id = get_tenant_id_for_email(email)
return tenant_id
except exceptions.UserNotExists:
# User doesn't exist, so we need to create a new tenant or assign an existing one
try:
@ -78,25 +80,8 @@ async def get_or_provision_tenant(
tenant_id = await get_available_tenant()
if tenant_id:
# If we have a pre-provisioned tenant, just add the user to it
add_users_to_tenant([email], tenant_id)
# Create milestone record
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
# Notify control plane
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
# If we have a pre-provisioned tenant, assign it to the user
await assign_tenant_to_user(tenant_id, email, referral_source)
logger.info(
f"Assigned pre-provisioned tenant {tenant_id} to user {email}"
)
@ -125,9 +110,11 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
try:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane
if not DEV_MODE:
# Notify control plane if not already done in provision_tenant
if not DEV_MODE and referral_source:
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
@ -145,56 +132,25 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
)
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
token = None
try:
# Create the schema for the tenant
if not create_schema_if_not_exists(tenant_id):
logger.debug(f"Created schema for tenant {tenant_id}")
else:
logger.debug(f"Schema already exists for tenant {tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Set up the tenant with all necessary configurations
await setup_tenant(tenant_id)
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
# Add the user to the tenant
add_users_to_tenant([email], tenant_id)
# Create milestone record
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
# Assign the tenant to the user
await assign_tenant_to_user(tenant_id, email)
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def notify_control_plane(
@ -422,3 +378,66 @@ async def get_available_tenant() -> str | None:
logger.error(f"Error getting available tenant: {e}")
return None
async def setup_tenant(tenant_id: str) -> None:
    """
    Set up a tenant with all necessary configurations.

    This is a centralized function that handles all tenant setup logic:
    it runs the Alembic migrations for the tenant, configures the default
    API keys, and calls ``setup_onyx`` inside a tenant-scoped session.
    ``CURRENT_TENANT_ID_CONTEXTVAR`` is set to ``tenant_id`` for the duration
    of the call and reset afterwards, whether or not setup succeeds.

    Args:
        tenant_id: Identifier of the tenant to set up.

    Raises:
        Exception: Any error raised during setup is logged and then
            re-raised unchanged.
    """
    token = None
    try:
        token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

        # Run Alembic migrations in a worker thread so the event loop is
        # not blocked while they execute
        await asyncio.to_thread(run_alembic_migrations, tenant_id)

        # Configure the tenant with default settings
        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
            # Configure default API keys
            configure_default_api_keys(db_session)

            # Cohere is considered enabled only when a FUTURE search
            # settings row exists and its provider is Cohere
            current_search_settings = (
                db_session.query(SearchSettings)
                .filter_by(status=IndexModelStatus.FUTURE)
                .first()
            )
            cohere_enabled = (
                current_search_settings is not None
                and current_search_settings.provider_type == EmbeddingProvider.COHERE
            )
            setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
    except Exception:
        logger.exception(f"Failed to set up tenant {tenant_id}")
        # Bare ``raise`` re-raises the active exception as-is
        raise
    finally:
        # Only reset if the context variable was actually set above
        if token is not None:
            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def assign_tenant_to_user(
    tenant_id: str, email: str, referral_source: str | None = None
) -> None:
    """
    Attach a user to a tenant and run the follow-up bookkeeping.

    Steps, in order: add the email to the tenant, record a
    ``TENANT_CREATED`` milestone using a session scoped to that tenant,
    and, unless ``DEV_MODE`` is set, notify the control plane.

    Args:
        tenant_id: Identifier of the tenant the user is assigned to.
        email: Email address of the user being assigned.
        referral_source: Optional referral source forwarded to the
            control plane notification.
    """
    # Step 1: membership
    add_users_to_tenant([email], tenant_id)

    # Step 2: milestone record
    milestone_properties = {
        "email": email,
    }
    with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
        create_milestone_and_report(
            user=None,
            distinct_id=tenant_id,
            event_type=MilestoneRecordType.TENANT_CREATED,
            properties=milestone_properties,
            db_session=tenant_session,
        )

    # Step 3: control plane notification is skipped in dev mode
    if DEV_MODE:
        return
    await notify_control_plane(tenant_id, email, referral_source)

View File

@ -85,7 +85,7 @@ def get_current_alembic_version(tenant_id: str) -> str:
# Set the search path to the tenant's schema
with engine.connect() as connection:
connection.execute(text(f"SET search_path TO {tenant_id}"))
connection.execute(text(f'SET search_path TO "{tenant_id}"'))
# Get the current version from the alembic_version table
context = MigrationContext.configure(connection)

View File

@ -111,5 +111,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.periodic.tenant_provisioning",
]
)

View File

@ -49,6 +49,19 @@ celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.primary")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
# Import tasks to ensure they are registered with Celery
import onyx.background.celery.tasks.connector_deletion.tasks # noqa
import onyx.background.celery.tasks.doc_permission_syncing.tasks # noqa
import onyx.background.celery.tasks.external_group_syncing.tasks # noqa
import onyx.background.celery.tasks.indexing.tasks # noqa
import onyx.background.celery.tasks.llm_model_update.tasks # noqa
import onyx.background.celery.tasks.monitoring.tasks # noqa
import onyx.background.celery.tasks.periodic.tasks # noqa
import onyx.background.celery.tasks.periodic.tenant_provisioning # noqa
import onyx.background.celery.tasks.pruning.tasks # noqa
import onyx.background.celery.tasks.shared.tasks # noqa
import onyx.background.celery.tasks.vespa.tasks # noqa
@signals.task_prerun.connect
def on_task_prerun(

View File

@ -20,7 +20,7 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# hack to slow down task dispatch in the cloud until
# we have a better implementation (backpressure, etc)
# Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
CLOUD_BEAT_MULTIPLIER_DEFAULT = 0.5
# tasks that run in either self-hosted or cloud
beat_task_templates: list[dict] = []
@ -100,15 +100,6 @@ beat_task_templates.extend(
"queue": OnyxCeleryQueues.MONITORING,
},
},
{
"name": "check-available-tenants",
"task": OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
"schedule": timedelta(minutes=5),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
)
@ -176,6 +167,15 @@ beat_cloud_tasks: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants",
"task": OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
"schedule": timedelta(seconds=10),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
# tasks that only run self hosted

View File

@ -3,7 +3,6 @@ Periodic tasks for tenant pre-provisioning.
"""
import asyncio
import datetime
import logging
import uuid
from celery import shared_task
@ -11,7 +10,7 @@ from celery import Task
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.celery_utils import get_redis_client
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@ -19,9 +18,9 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.models import NewAvailableTenant
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.enums import EmbeddingProvider
# Default number of pre-provisioned tenants to maintain
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
@ -31,8 +30,6 @@ _TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
# Hard time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 10 # 10 minutes
logger = logging.getLogger(__name__)
@shared_task(
name=OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
@ -47,8 +44,11 @@ def check_available_tenants(self: Task) -> None:
Check if we have enough pre-provisioned tenants available.
If not, trigger the pre-provisioning of new tenants.
"""
task_logger.warning("STARTING CHECK_AVAILABLE_TENANTS")
if not MULTI_TENANT:
logger.debug("Multi-tenancy is not enabled, skipping tenant pre-provisioning")
task_logger.warning(
"Multi-tenancy is not enabled, skipping tenant pre-provisioning"
)
return
r = get_redis_client()
@ -59,7 +59,7 @@ def check_available_tenants(self: Task) -> None:
# These tasks should never overlap
if not lock_check.acquire(blocking=False):
logger.debug(
task_logger.warning(
"Skipping check_available_tenants task because it is already running"
)
return
@ -79,7 +79,7 @@ def check_available_tenants(self: Task) -> None:
0, target_available_tenants - available_tenants_count
)
logger.info(
task_logger.warning(
f"Available tenants: {available_tenants_count}, "
f"Target: {target_available_tenants}, "
f"To provision: {tenants_to_provision}"
@ -92,7 +92,8 @@ def check_available_tenants(self: Task) -> None:
)
except Exception as e:
logger.exception(f"Error in check_available_tenants task: {e}")
task_logger.exception(f"Error in check_available_tenants task: {e}")
finally:
lock_check.release()
@ -111,10 +112,12 @@ def pre_provision_tenant(self: Task) -> None:
This function fully sets up the tenant with all necessary configurations,
so it's ready to be assigned to a user immediately.
"""
task_logger.warning("STARTING PRE_PROVISION_TENANT")
if not MULTI_TENANT:
logger.debug("Multi-tenancy is not enabled, skipping tenant pre-provisioning")
task_logger.warning(
"Multi-tenancy is not enabled, skipping tenant pre-provisioning"
)
return
r = get_redis_client()
lock_provision: RedisLock = r.lock(
OnyxRedisLocks.PRE_PROVISION_TENANT_LOCK,
@ -123,75 +126,53 @@ def pre_provision_tenant(self: Task) -> None:
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
if not lock_provision.acquire(blocking=False):
logger.debug("Skipping pre_provision_tenant task because it is already running")
task_logger.warning(
"Skipping pre_provision_tenant task because it is already running"
)
return
try:
# Generate a new tenant ID
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
token = None
task_logger.warning(f"Starting pre-provisioning for tenant {tenant_id}")
# Import here to avoid circular imports
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
from ee.onyx.server.tenants.provisioning import configure_default_api_keys
from onyx.setup import setup_onyx
from onyx.db.models import SearchSettings, IndexModelStatus
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from onyx.db.engine import get_session_with_tenant
from ee.onyx.server.tenants.provisioning import setup_tenant
# Create the schema for the new tenant
if not create_schema_if_not_exists(tenant_id):
logger.debug(f"Created schema for tenant {tenant_id}")
schema_created = create_schema_if_not_exists(tenant_id)
if schema_created:
task_logger.warning(f"Created schema for tenant '{tenant_id}'")
else:
logger.debug(f"Schema already exists for tenant {tenant_id}")
task_logger.warning(f"Schema already exists for tenant '{tenant_id}'")
try:
# Set the tenant context
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Set up the tenant with all necessary configurations
task_logger.warning(f"Setting up tenant configuration for '{tenant_id}'")
asyncio.run(setup_tenant(tenant_id))
task_logger.warning(f"Tenant configuration completed for '{tenant_id}'")
# Run Alembic migrations
asyncio.run(asyncio.to_thread(run_alembic_migrations, tenant_id))
# Get the current Alembic version
alembic_version = get_current_alembic_version(tenant_id)
task_logger.warning(
f"Tenant '{tenant_id}' using Alembic version: {alembic_version}"
)
# Configure the tenant with default settings
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
# Configure default API keys
configure_default_api_keys(db_session)
# Store the pre-provisioned tenant in the database
task_logger.warning(f"Storing pre-provisioned tenant '{tenant_id}' in database")
with Session(get_sqlalchemy_engine()) as db_session:
new_tenant = NewAvailableTenant(
tenant_id=tenant_id,
alembic_version=alembic_version,
date_created=datetime.datetime.now(),
)
db_session.add(new_tenant)
db_session.commit()
# Set up Onyx with appropriate settings
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type
== EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
# Get the current Alembic version
alembic_version = get_current_alembic_version(tenant_id)
# Store the pre-provisioned tenant in the database
with Session(get_sqlalchemy_engine()) as db_session:
new_tenant = NewAvailableTenant(
tenant_id=tenant_id,
alembic_version=alembic_version,
date_created=datetime.datetime.now(),
)
db_session.add(new_tenant)
db_session.commit()
logger.info(f"Successfully pre-provisioned tenant {tenant_id}")
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
task_logger.warning(f"Successfully pre-provisioned tenant {tenant_id}")
except Exception as e:
logger.exception(f"Error in pre_provision_tenant task: {e}")
task_logger.exception(f"Error in pre_provision_tenant task: {e}")
finally:
lock_provision.release()

View File

@ -1259,21 +1259,18 @@ export function createConnectorInitialValues(
name: "",
groups: [],
access_type: "public",
...configuration.values.reduce(
(acc, field) => {
if (field.type === "select") {
acc[field.name] = null;
} else if (field.type === "list") {
acc[field.name] = field.default || [];
} else if (field.type === "checkbox") {
acc[field.name] = field.default || false;
} else if (field.default !== undefined) {
acc[field.name] = field.default;
}
return acc;
},
{} as { [record: string]: any }
),
...configuration.values.reduce((acc, field) => {
if (field.type === "select") {
acc[field.name] = null;
} else if (field.type === "list") {
acc[field.name] = field.default || [];
} else if (field.type === "checkbox") {
acc[field.name] = field.default || false;
} else if (field.default !== undefined) {
acc[field.name] = field.default;
}
return acc;
}, {} as { [record: string]: any }),
};
}
@ -1285,28 +1282,25 @@ export function createConnectorValidationSchema(
return Yup.object().shape({
access_type: Yup.string().required("Access Type is required"),
name: Yup.string().required("Connector Name is required"),
...configuration.values.reduce(
(acc, field) => {
let schema: any =
field.type === "select"
? Yup.string()
: field.type === "list"
? Yup.array().of(Yup.string())
: field.type === "checkbox"
? Yup.boolean()
: field.type === "file"
? Yup.mixed()
: Yup.string();
...configuration.values.reduce((acc, field) => {
let schema: any =
field.type === "select"
? Yup.string()
: field.type === "list"
? Yup.array().of(Yup.string())
: field.type === "checkbox"
? Yup.boolean()
: field.type === "file"
? Yup.mixed()
: Yup.string();
if (!field.optional) {
schema = schema.required(`${field.label} is required`);
}
if (!field.optional) {
schema = schema.required(`${field.label} is required`);
}
acc[field.name] = schema;
return acc;
},
{} as Record<string, any>
),
acc[field.name] = schema;
return acc;
}, {} as Record<string, any>),
// These are advanced settings
indexingStart: Yup.string().nullable(),
pruneFreq: Yup.number().min(0, "Prune frequency must be non-negative"),