diff --git a/backend/alembic/env.py b/backend/alembic/env.py
index 8a944689d..f31a1594c 100644
--- a/backend/alembic/env.py
+++ b/backend/alembic/env.py
@@ -4,7 +4,7 @@
 from alembic import context
 from danswer.db.engine import build_connection_string
 from danswer.db.models import Base
-from sqlalchemy import pool
+from sqlalchemy import pool, text
 from sqlalchemy.engine import Connection
 from sqlalchemy.ext.asyncio import create_async_engine
 from celery.backends.database.session import ResultModelBase  # type: ignore
@@ -11,62 +11,84 @@
 
-# this is the Alembic Config object, which provides
-# access to the values within the .ini file in use.
+# Alembic Config object
 config = context.config
 
 # Interpret the config file for Python logging.
-# This line sets up loggers basically.
 if config.config_file_name is not None:
     fileConfig(config.config_file_name)
 
-# add your model's MetaData object here
-# for 'autogenerate' support
-# from myapp import mymodel
-# target_metadata = mymodel.Base.metadata
+# Add your model's MetaData object here
 target_metadata = [Base.metadata, ResultModelBase.metadata]
 
-# other values from the config, defined by the needs of env.py,
-# can be acquired:
-# my_important_option = config.get_main_option("my_important_option")
-# ... etc.
-
 
+def get_schema_options() -> tuple[str, bool]:
+    """Parse ``-x`` arguments for the target schema and creation flag.
+
+    Recognized options:
+        schema=<name>         schema to run migrations in (default: public)
+        create_schema=true    create the schema first if it does not exist
+    """
+    x_args_raw = context.get_x_argument()
+    x_args = {}
+    for arg in x_args_raw:
+        for pair in arg.split(","):
+            if "=" in pair:
+                key, value = pair.split("=", 1)
+                x_args[key] = value
+    schema_name = x_args.get("schema", "public")
+    create_schema = x_args.get("create_schema", "false").lower() == "true"
+    return schema_name, create_schema
+
+
 def run_migrations_offline() -> None:
-    """Run migrations in 'offline' mode.
-
-    This configures the context with just a URL
-    and not an Engine, though an Engine is acceptable
-    here as well. By skipping the Engine creation
-    we don't even need a DBAPI to be available.
-
-    Calls to context.execute() here emit the given string to the
-    script output.
-
-    """
+    """Run migrations in 'offline' mode."""
     url = build_connection_string()
+    schema, create_schema = get_schema_options()
+
+    if create_schema:
+        raise RuntimeError(
+            "Cannot create schema in offline mode. "
+            "Please run migrations online to create the schema."
+        )
+
     context.configure(
         url=url,
-        target_metadata=target_metadata,  # type: ignore
+        target_metadata=target_metadata,
         literal_binds=True,
         dialect_opts={"paramstyle": "named"},
+        version_table_schema=schema,
+        include_schemas=True,
     )
 
     with context.begin_transaction():
         context.run_migrations()
 
 
 def do_run_migrations(connection: Connection) -> None:
-    context.configure(connection=connection, target_metadata=target_metadata)  # type: ignore
+    schema, create_schema = get_schema_options()
+
+    if create_schema:
+        # text() builds a proper SQL expression; the schema name is quoted
+        connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
+        connection.execute(text("COMMIT"))
+
+    # Run the migration inside the target schema
+    connection.execute(text(f'SET search_path TO "{schema}"'))
+
+    context.configure(
+        connection=connection,
+        target_metadata=target_metadata,
+        version_table_schema=schema,
+        include_schemas=True,
+        compare_type=True,
+        compare_server_default=True,
+    )
 
     with context.begin_transaction():
         context.run_migrations()
 
 
 async def run_async_migrations() -> None:
-    """In this scenario we need to create an Engine
-    and associate a connection with the context.
-
-    """
-
+    """Create an async Engine and run migrations through a connection."""
     connectable = create_async_engine(
         build_connection_string(),
         poolclass=pool.NullPool,
diff --git a/backend/alembic/versions/da4c21c69164_chosen_assistants_changed_to_jsonb.py b/backend/alembic/versions/da4c21c69164_chosen_assistants_changed_to_jsonb.py
index 95b53cbeb..5eb0f099f 100644
--- a/backend/alembic/versions/da4c21c69164_chosen_assistants_changed_to_jsonb.py
+++ b/backend/alembic/versions/da4c21c69164_chosen_assistants_changed_to_jsonb.py
@@ -19,7 +19,9 @@
 def upgrade() -> None:
     conn = op.get_bind()
+
     existing_ids_and_chosen_assistants = conn.execute(
-        sa.text("select id, chosen_assistants from public.user")
+        # "user" is a reserved word in PostgreSQL and must be double-quoted.
+        sa.text('SELECT id, chosen_assistants FROM "user"')
     )
     op.drop_column(
         "user",
@@ -37,7 +39,7 @@
     for id, chosen_assistants in existing_ids_and_chosen_assistants:
         conn.execute(
             sa.text(
-                "update public.user set chosen_assistants = :chosen_assistants where id = :id"
+                'UPDATE "user" SET chosen_assistants = :chosen_assistants WHERE id = :id'
             ),
             {"chosen_assistants": json.dumps(chosen_assistants), "id": id},
         )
@@ -46,20 +48,20 @@
 def downgrade() -> None:
     conn = op.get_bind()
     existing_ids_and_chosen_assistants = conn.execute(
-        sa.text("select id, chosen_assistants from public.user")
+        sa.text('SELECT id, chosen_assistants FROM "user"')
     )
     op.drop_column(
         "user",
         "chosen_assistants",
     )
     op.add_column(
         "user",
         sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
     )
     for id, chosen_assistants in existing_ids_and_chosen_assistants:
         conn.execute(
             sa.text(
-                "update public.user set chosen_assistants = :chosen_assistants where id = :id"
+                'UPDATE "user" SET chosen_assistants = :chosen_assistants WHERE id = :id'
             ),
             {"chosen_assistants": chosen_assistants, "id": id},
         )