add multi tenant alembic (#2589)

This commit is contained in:
pablodanswer 2024-10-05 14:59:15 -07:00 committed by GitHub
parent 493c3d7314
commit 28e65669b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -9,9 +9,9 @@ from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore from celery.backends.database.session import ResultModelBase # type: ignore
from sqlalchemy.schema import SchemaItem from sqlalchemy.schema import SchemaItem
from sqlalchemy.sql import text
# this is the Alembic Config object, which provides # Alembic Config object
# access to the values within the .ini file in use.
config = context.config config = context.config
# Interpret the config file for Python logging. # Interpret the config file for Python logging.
@ -21,16 +21,26 @@ if config.config_file_name is not None and config.attributes.get(
): ):
fileConfig(config.config_file_name) fileConfig(config.config_file_name)
# add your model's MetaData object here # Add your model's MetaData object here
# for 'autogenerate' support # for 'autogenerate' support
# from myapp import mymodel # from myapp import mymodel
# target_metadata = mymodel.Base.metadata # target_metadata = mymodel.Base.metadata
target_metadata = [Base.metadata, ResultModelBase.metadata] target_metadata = [Base.metadata, ResultModelBase.metadata]
# other values from the config, defined by the needs of env.py,
# can be acquired: def get_schema_options() -> tuple[str, bool]:
# my_important_option = config.get_main_option("my_important_option") x_args_raw = context.get_x_argument()
# ... etc. x_args = {}
for arg in x_args_raw:
for pair in arg.split(","):
if "=" in pair:
key, value = pair.split("=", 1)
x_args[key] = value
schema_name = x_args.get("schema", "public")
create_schema = x_args.get("create_schema", "true").lower() == "true"
return schema_name, create_schema
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"} EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
@ -54,17 +64,20 @@ def run_migrations_offline() -> None:
and not an Engine, though an Engine is acceptable and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation here as well. By skipping the Engine creation
we don't even need a DBAPI to be available. we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the Calls to context.execute() here emit the given string to the
script output. script output.
""" """
url = build_connection_string() url = build_connection_string()
schema, _ = get_schema_options()
context.configure( context.configure(
url=url, url=url,
target_metadata=target_metadata, # type: ignore target_metadata=target_metadata, # type: ignore
literal_binds=True, literal_binds=True,
include_object=include_object,
dialect_opts={"paramstyle": "named"}, dialect_opts={"paramstyle": "named"},
version_table_schema=schema,
include_schemas=True,
) )
with context.begin_transaction(): with context.begin_transaction():
@ -72,22 +85,28 @@ def run_migrations_offline() -> None:
def do_run_migrations(connection: Connection) -> None: def do_run_migrations(connection: Connection) -> None:
schema, create_schema = get_schema_options()
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
connection.execute(text("COMMIT"))
connection.execute(text(f'SET search_path TO "{schema}"'))
context.configure( context.configure(
connection=connection, connection=connection,
target_metadata=target_metadata, # type: ignore target_metadata=target_metadata, # type: ignore
include_object=include_object, version_table_schema=schema,
) # type: ignore include_schemas=True,
compare_type=True,
compare_server_default=True,
)
with context.begin_transaction(): with context.begin_transaction():
context.run_migrations() context.run_migrations()
async def run_async_migrations() -> None: async def run_async_migrations() -> None:
"""In this scenario we need to create an Engine """Run migrations in 'online' mode."""
and associate a connection with the context.
"""
connectable = create_async_engine( connectable = create_async_engine(
build_connection_string(), build_connection_string(),
poolclass=pool.NullPool, poolclass=pool.NullPool,
@ -101,7 +120,6 @@ async def run_async_migrations() -> None:
def run_migrations_online() -> None: def run_migrations_online() -> None:
"""Run migrations in 'online' mode.""" """Run migrations in 'online' mode."""
asyncio.run(run_async_migrations()) asyncio.run(run_async_migrations())