rkuo-danswer c9a609b7d8
Bugfix/slack bot channel config (#4585)
* friendlier handling of slack channel retrieval

* retry on downgrade_postgres deadlock

* fix comment

* text

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-04-23 20:00:03 +00:00


import logging
import time
from types import SimpleNamespace

import psycopg2
import requests
from alembic import command
from alembic.config import Config

from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PASSWORD
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import SYNC_DB_API
from onyx.db.search_settings import get_current_search_settings
from onyx.db.swap_index import check_and_perform_index_swap
from onyx.document_index.document_index_utils import get_multipass_config
from onyx.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa.index import VespaIndex
from onyx.indexing.models import IndexingSetting
from onyx.setup import setup_postgres
from onyx.setup import setup_vespa
from onyx.utils.logger import setup_logger
from tests.integration.common_utils.timeout import run_with_timeout_multiproc

logger = setup_logger()


def _run_migrations(
    database_url: str,
    config_name: str,
    direction: str = "upgrade",
    revision: str = "head",
    schema: str = "public",
) -> None:
    # hide info logs emitted during migration
    logging.getLogger("alembic").setLevel(logging.CRITICAL)

    # Create an Alembic configuration object
    alembic_cfg = Config("alembic.ini")
    alembic_cfg.set_section_option("logger_alembic", "level", "WARN")
    alembic_cfg.attributes["configure_logger"] = False
    alembic_cfg.config_ini_section = config_name
    alembic_cfg.cmd_opts = SimpleNamespace()  # type: ignore
    alembic_cfg.cmd_opts.x = [f"schema={schema}"]  # type: ignore

    # Set the SQLAlchemy URL in the Alembic configuration
    alembic_cfg.set_main_option("sqlalchemy.url", database_url)

    # Run the migration
    if direction == "upgrade":
        command.upgrade(alembic_cfg, revision)
    elif direction == "downgrade":
        command.downgrade(alembic_cfg, revision)
    else:
        raise ValueError(
            f"Invalid direction: {direction}. Must be 'upgrade' or 'downgrade'."
        )

    # restore alembic's default log level once the migration is done
    logging.getLogger("alembic").setLevel(logging.INFO)
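
# Example usage (illustrative; the DSN below is a placeholder, not a real
# deployment value):
#
#     _run_migrations(
#         "postgresql+psycopg2://postgres:password@localhost:5432/postgres",
#         "alembic",
#         direction="upgrade",
#         revision="head",
#     )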


def downgrade_postgres(
    database: str = "postgres",
    schema: str = "public",
    config_name: str = "alembic",
    revision: str = "base",
    clear_data: bool = False,
) -> None:
    """Downgrade Postgres to the given revision (the base state by default)."""
    if clear_data:
        if revision != "base":
            raise ValueError("Clearing data without rolling back to base state")

        conn = psycopg2.connect(
            dbname=database,
            user=POSTGRES_USER,
            password=POSTGRES_PASSWORD,
            host=POSTGRES_HOST,
            port=POSTGRES_PORT,
            application_name="downgrade_postgres",
        )
        conn.autocommit = True  # Need autocommit for dropping schema
        cur = conn.cursor()

        # Terminate connections idling in a transaction before dropping
        cur.execute(
            f"""
            SELECT pg_terminate_backend(pg_stat_activity.pid)
            FROM pg_stat_activity
            WHERE pg_stat_activity.datname = '{database}'
            AND pg_stat_activity.state = 'idle in transaction'
            AND pid <> pg_backend_pid();
            """
        )

        # Drop and recreate the schema - this removes ALL objects
        cur.execute(f"DROP SCHEMA {schema} CASCADE;")
        cur.execute(f"CREATE SCHEMA {schema};")

        # Restore default privileges
        cur.execute(f"GRANT ALL ON SCHEMA {schema} TO postgres;")
        cur.execute(f"GRANT ALL ON SCHEMA {schema} TO public;")
        cur.close()
        conn.close()
        return

    # Downgrade to the requested revision via Alembic
    conn_str = build_connection_string(
        db=database,
        user=POSTGRES_USER,
        password=POSTGRES_PASSWORD,
        host=POSTGRES_HOST,
        port=POSTGRES_PORT,
        db_api=SYNC_DB_API,
    )
    _run_migrations(
        conn_str,
        config_name,
        direction="downgrade",
        revision=revision,
    )
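
# Example usage (illustrative; assumes a local Postgres reachable via the
# POSTGRES_* settings imported above):
#
#     # wipe everything by dropping and recreating the schema
#     downgrade_postgres(revision="base", clear_data=True)
#
#     # or walk back a single migration without dropping the schema
#     downgrade_postgres(revision="-1")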


def upgrade_postgres(
    database: str = "postgres", config_name: str = "alembic", revision: str = "head"
) -> None:
    """Upgrade Postgres database to the given revision (latest by default)."""
    conn_str = build_connection_string(
        db=database,
        user=POSTGRES_USER,
        password=POSTGRES_PASSWORD,
        host=POSTGRES_HOST,
        port=POSTGRES_PORT,
        db_api=SYNC_DB_API,
        app_name="upgrade_postgres",
    )
    _run_migrations(
        conn_str,
        config_name,
        direction="upgrade",
        revision=revision,
    )
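
# Example usage (illustrative; "head" applies all migrations, while a concrete
# revision id -- shown here as a placeholder -- stops at that migration):
#
#     upgrade_postgres(revision="head")
#     upgrade_postgres(revision="<alembic_revision_id>")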


def drop_multitenant_postgres(
    database: str = "postgres",
) -> None:
    """Drop all tenant schemas and public tables from the Postgres database."""
    # this seems to hang due to locking issues, so run with a timeout and a few retries
    NUM_TRIES = 10
    TIMEOUT = 40
    success = False
    for attempt in range(NUM_TRIES):
        logger.info(
            f"drop_multitenant_postgres_task starting... ({attempt + 1}/{NUM_TRIES})"
        )
        try:
            run_with_timeout_multiproc(
                drop_multitenant_postgres_task,
                TIMEOUT,
                kwargs={
                    "dbname": database,
                },
            )
            success = True
            break
        except TimeoutError:
            logger.warning(
                f"drop_multitenant_postgres_task timed out, retrying... ({attempt + 1}/{NUM_TRIES})"
            )
        except RuntimeError:
            logger.warning(
                f"drop_multitenant_postgres_task raised an exception, retrying... ({attempt + 1}/{NUM_TRIES})"
            )

    if not success:
        raise RuntimeError(
            f"drop_multitenant_postgres_task failed after {NUM_TRIES} attempts."
        )
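
# For context: run_with_timeout_multiproc (imported from the test utils above)
# bounds each attempt. A minimal sketch of the contract it is assumed to
# satisfy -- run the callable in a child process, raise TimeoutError when the
# deadline passes, and surface a non-zero exit as RuntimeError:
#
#     import multiprocessing
#
#     def run_with_timeout_multiproc(func, timeout, kwargs):
#         proc = multiprocessing.Process(target=func, kwargs=kwargs)
#         proc.start()
#         proc.join(timeout)
#         if proc.is_alive():
#             proc.terminate()
#             raise TimeoutError(f"{func.__name__} timed out after {timeout}s")
#         if proc.exitcode != 0:
#             raise RuntimeError(f"{func.__name__} exited with {proc.exitcode}")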


def drop_multitenant_postgres_task(dbname: str) -> None:
    """Drop every tenant_* schema and all public tables in the given database."""
    conn = psycopg2.connect(
        dbname=dbname,
        user=POSTGRES_USER,
        password=POSTGRES_PASSWORD,
        host=POSTGRES_HOST,
        port=POSTGRES_PORT,
        connect_timeout=10,
        application_name="drop_multitenant_postgres_task",
    )
    conn.autocommit = True
    cur = conn.cursor()

    logger.info("Selecting tenant schemas.")

    # Get all tenant schemas
    cur.execute(
        """
        SELECT schema_name
        FROM information_schema.schemata
        WHERE schema_name LIKE 'tenant_%'
        """
    )
    tenant_schemas = cur.fetchall()

    # Drop all tenant schemas
    logger.info("Dropping all tenant schemas.")
    for schema in tenant_schemas:
        # Terminate connections idling in a transaction before dropping
        cur.execute(
            f"""
            SELECT pg_terminate_backend(pg_stat_activity.pid)
            FROM pg_stat_activity
            WHERE pg_stat_activity.datname = '{dbname}'
            AND pg_stat_activity.state = 'idle in transaction'
            AND pid <> pg_backend_pid();
            """
        )
        schema_name = schema[0]
        cur.execute(f'DROP SCHEMA "{schema_name}" CASCADE')

    # Drop tables in the public schema
    logger.info("Selecting public schema tables.")
    cur.execute(
        """
        SELECT tablename FROM pg_tables
        WHERE schemaname = 'public'
        """
    )
    public_tables = cur.fetchall()

    logger.info("Dropping public schema tables.")
    for table in public_tables:
        table_name = table[0]
        cur.execute(f'DROP TABLE IF EXISTS public."{table_name}" CASCADE')

    cur.close()
    conn.close()
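
# Note: this task is normally invoked through drop_multitenant_postgres above,
# which wraps it in timeout/retry handling. Calling it directly (illustrative):
#
#     drop_multitenant_postgres_task(dbname="postgres")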


def reset_postgres(
    database: str = "postgres",
    config_name: str = "alembic",
    setup_onyx: bool = True,
) -> None:
    """Reset the Postgres database."""
    # this seems to hang due to locking issues, so run with a timeout and a few retries
    NUM_TRIES = 10
    TIMEOUT = 40
    success = False
    for attempt in range(NUM_TRIES):
        logger.info(f"Downgrading Postgres... ({attempt + 1}/{NUM_TRIES})")
        try:
            run_with_timeout_multiproc(
                downgrade_postgres,
                TIMEOUT,
                kwargs={
                    "database": database,
                    "config_name": config_name,
                    "revision": "base",
                    "clear_data": True,
                },
            )
            success = True
            break
        except TimeoutError:
            logger.warning(
                f"Postgres downgrade timed out, retrying... ({attempt + 1}/{NUM_TRIES})"
            )
        except RuntimeError:
            logger.warning(
                f"Postgres downgrade raised an exception, retrying... ({attempt + 1}/{NUM_TRIES})"
            )

    if not success:
        raise RuntimeError(f"Postgres downgrade failed after {NUM_TRIES} attempts.")

    logger.info("Upgrading Postgres...")
    upgrade_postgres(database=database, config_name=config_name, revision="head")

    if setup_onyx:
        logger.info("Setting up Postgres...")
        with get_session_context_manager() as db_session:
            setup_postgres(db_session)
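
# Example usage (illustrative; mirrors how integration tests are assumed to
# reset state between runs):
#
#     reset_postgres()  # downgrade to base, upgrade to head, run setup_postgres
#     reset_postgres(setup_onyx=False)  # rebuild the schema, skip app setup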


def reset_vespa() -> None:
    """Wipe all data from the Vespa index."""
    with get_session_context_manager() as db_session:
        # swap to the correct default model
        check_and_perform_index_swap(db_session)

        search_settings = get_current_search_settings(db_session)
        multipass_config = get_multipass_config(search_settings)
        index_name = search_settings.index_name

    success = setup_vespa(
        document_index=VespaIndex(
            index_name=index_name,
            secondary_index_name=None,
            large_chunks_enabled=multipass_config.enable_large_chunks,
            secondary_large_chunks_enabled=None,
        ),
        index_setting=IndexingSetting.from_db_model(search_settings),
        secondary_index_setting=None,
    )
    if not success:
        raise RuntimeError("Could not connect to Vespa within the specified timeout.")

    # deletion may fail transiently, so retry a few times
    for _ in range(5):
        try:
            continuation = None
            should_continue = True
            while should_continue:
                params = {"selection": "true", "cluster": "danswer_index"}
                if continuation:
                    params = {**params, "continuation": continuation}

                response = requests.delete(
                    DOCUMENT_ID_ENDPOINT.format(index_name=index_name), params=params
                )
                response.raise_for_status()

                response_json = response.json()
                continuation = response_json.get("continuation")
                should_continue = bool(continuation)

            break
        except Exception as e:
            logger.warning(f"Error deleting documents: {e}")
            time.sleep(5)
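
# For reference: the delete-by-selection loop above talks to Vespa's
# /document/v1 API. A response with work remaining looks roughly like this
# (illustrative shape, not captured from a live cluster):
#
#     {"pathId": "...", "documentCount": 1024, "continuation": "AAAAEAAA..."}
#
# The loop re-issues the DELETE with the returned continuation token until the
# response no longer includes one.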


def reset_postgres_multitenant() -> None:
    """Reset the Postgres database for all tenants in a multitenant setup."""
    drop_multitenant_postgres()
    reset_postgres(config_name="schema_private", setup_onyx=False)


def reset_vespa_multitenant() -> None:
    """Wipe all data from the Vespa index for all tenants."""
    for tenant_id in get_all_tenant_ids():
        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
            # swap to the correct default model for each tenant
            check_and_perform_index_swap(db_session)

            search_settings = get_current_search_settings(db_session)
            multipass_config = get_multipass_config(search_settings)
            index_name = search_settings.index_name

        success = setup_vespa(
            document_index=VespaIndex(
                index_name=index_name,
                secondary_index_name=None,
                large_chunks_enabled=multipass_config.enable_large_chunks,
                secondary_large_chunks_enabled=None,
            ),
            index_setting=IndexingSetting.from_db_model(search_settings),
            secondary_index_setting=None,
        )
        if not success:
            raise RuntimeError(
                f"Could not connect to Vespa for tenant {tenant_id} within the specified timeout."
            )

        # deletion may fail transiently, so retry a few times
        for _ in range(5):
            try:
                continuation = None
                should_continue = True
                while should_continue:
                    params = {"selection": "true", "cluster": "danswer_index"}
                    if continuation:
                        params = {**params, "continuation": continuation}

                    response = requests.delete(
                        DOCUMENT_ID_ENDPOINT.format(index_name=index_name),
                        params=params,
                    )
                    response.raise_for_status()

                    response_json = response.json()
                    continuation = response_json.get("continuation")
                    should_continue = bool(continuation)

                break
            except Exception as e:
                logger.warning(f"Error deleting documents for tenant {tenant_id}: {e}")
                time.sleep(5)


def reset_all() -> None:
    """Reset both Postgres and Vespa to a clean state."""
    logger.info("Resetting Postgres...")
    reset_postgres()
    logger.info("Resetting Vespa...")
    reset_vespa()


def reset_all_multitenant() -> None:
    """Reset both Postgres and Vespa for all tenants."""
    logger.info("Resetting Postgres for all tenants...")
    reset_postgres_multitenant()
    logger.info("Resetting Vespa for all tenants...")
    reset_vespa_multitenant()
    logger.info("Finished resetting all.")