diff --git a/backend/alembic/env.py b/backend/alembic/env.py index afa5a9669c17..c89d34552276 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -1,10 +1,11 @@ +from sqlalchemy.engine.base import Connection from typing import Any import asyncio from logging.config import fileConfig +import logging from alembic import context from sqlalchemy import pool -from sqlalchemy.engine import Connection from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.sql import text @@ -12,6 +13,7 @@ from danswer.configs.app_configs import MULTI_TENANT from danswer.db.engine import build_connection_string from danswer.db.models import Base from celery.backends.database.session import ResultModelBase # type: ignore +from danswer.background.celery.celery_app import get_all_tenant_ids # Alembic Config object config = context.config @@ -22,14 +24,32 @@ if config.config_file_name is not None and config.attributes.get( ): fileConfig(config.config_file_name) -# Add your model's MetaData object here -# for 'autogenerate' support -# from myapp import mymodel -# target_metadata = mymodel.Base.metadata +# Add your model's MetaData object here for 'autogenerate' support target_metadata = [Base.metadata, ResultModelBase.metadata] +EXCLUDE_TABLES = {"kombu_queue", "kombu_message"} -def get_schema_options() -> tuple[str, bool]: +# Set up logging +logger = logging.getLogger(__name__) + + +def include_object( + object: Any, name: str, type_: str, reflected: bool, compare_to: Any +) -> bool: + """ + Determines whether a database object should be included in migrations. + Excludes specified tables from migrations. + """ + if type_ == "table" and name in EXCLUDE_TABLES: + return False + return True + + +def get_schema_options() -> tuple[str, bool, bool]: + """ + Parses command-line options passed via '-x' in Alembic commands. + Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options. + """ x_args_raw = context.get_x_argument() x_args = {} for arg in x_args_raw: @@ -39,49 +59,7 @@ def get_schema_options() -> tuple[str, bool]: x_args[key.strip()] = value.strip() schema_name = x_args.get("schema", "public") create_schema = x_args.get("create_schema", "true").lower() == "true" - return schema_name, create_schema - - -EXCLUDE_TABLES = {"kombu_queue", "kombu_message"} - - -def include_object( - object: Any, name: str, type_: str, reflected: bool, compare_to: Any -) -> bool: - if type_ == "table" and name in EXCLUDE_TABLES: - return False - return True - - -def run_migrations_offline() -> None: - """Run migrations in 'offline' mode. - This configures the context with just a URL - and not an Engine, though an Engine is acceptable - here as well. By skipping the Engine creation - we don't even need a DBAPI to be available. - Calls to context.execute() here emit the given string to the - script output. - """ - schema_name, _ = get_schema_options() - url = build_connection_string() - - context.configure( - url=url, - target_metadata=target_metadata, # type: ignore - literal_binds=True, - include_object=include_object, - version_table_schema=schema_name, - include_schemas=True, - script_location=config.get_main_option("script_location"), - dialect_opts={"paramstyle": "named"}, - ) - - with context.begin_transaction(): - context.run_migrations() - - -def do_run_migrations(connection: Connection) -> None: - schema_name, create_schema = get_schema_options() + upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true" if MULTI_TENANT and schema_name == "public": raise ValueError( @@ -89,6 +67,17 @@ def do_run_migrations(connection: Connection) -> None: "Please specify a tenant-specific schema." ) + return schema_name, create_schema, upgrade_all_tenants + + +def do_run_migrations( + connection: Connection, schema_name: str, create_schema: bool +) -> None: + """ + Executes migrations in the specified schema. + """ + logger.info(f"About to migrate schema: {schema_name}") + if create_schema: connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"')) connection.execute(text("COMMIT")) @@ -112,18 +101,98 @@ def do_run_migrations(connection: Connection) -> None: async def run_async_migrations() -> None: - connectable = create_async_engine( + """ + Determines whether to run migrations for a single schema or all schemas, + and executes migrations accordingly. + """ + schema_name, create_schema, upgrade_all_tenants = get_schema_options() + + engine = create_async_engine( build_connection_string(), poolclass=pool.NullPool, ) - async with connectable.connect() as connection: - await connection.run_sync(do_run_migrations) + if upgrade_all_tenants: + # Run migrations for all tenant schemas sequentially + tenant_schemas = get_all_tenant_ids() - await connectable.dispose() + for schema in tenant_schemas: + try: + logger.info(f"Migrating schema: {schema}") + async with engine.connect() as connection: + await connection.run_sync( + do_run_migrations, + schema_name=schema, + create_schema=create_schema, + ) + except Exception as e: + logger.error(f"Error migrating schema {schema}: {e}") + raise + else: + try: + logger.info(f"Migrating schema: {schema_name}") + async with engine.connect() as connection: + await connection.run_sync( + do_run_migrations, + schema_name=schema_name, + create_schema=create_schema, + ) + except Exception as e: + logger.error(f"Error migrating schema {schema_name}: {e}") + raise + + await engine.dispose() + + +def run_migrations_offline() -> None: + """ + Run migrations in 'offline' mode. + """ + schema_name, _, upgrade_all_tenants = get_schema_options() + url = build_connection_string() + + if upgrade_all_tenants: + # Run offline migrations for all tenant schemas + engine = create_async_engine(url) + tenant_schemas = get_all_tenant_ids() + engine.sync_engine.dispose() + + for schema in tenant_schemas: + logger.info(f"Migrating schema: {schema}") + context.configure( + url=url, + target_metadata=target_metadata, # type: ignore + literal_binds=True, + include_object=include_object, + version_table_schema=schema, + include_schemas=True, + script_location=config.get_main_option("script_location"), + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + else: + logger.info(f"Migrating schema: {schema_name}") + context.configure( + url=url, + target_metadata=target_metadata, # type: ignore + literal_binds=True, + include_object=include_object, + version_table_schema=schema_name, + include_schemas=True, + script_location=config.get_main_option("script_location"), + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() def run_migrations_online() -> None: + """ + Runs migrations in 'online' mode using an asynchronous engine. + """ asyncio.run(run_async_migrations())