diff --git a/backend/onyx/background/celery/tasks/beat_schedule.py b/backend/onyx/background/celery/tasks/beat_schedule.py index 3e7b1c45a..7d0aa9e9c 100644 --- a/backend/onyx/background/celery/tasks/beat_schedule.py +++ b/backend/onyx/background/celery/tasks/beat_schedule.py @@ -29,6 +29,16 @@ cloud_tasks_to_schedule = [ "expires": BEAT_EXPIRES_DEFAULT, }, }, + { + "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic", + "task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC, + "schedule": timedelta(hours=1), + "options": { + "priority": OnyxCeleryPriority.HIGH, + "expires": BEAT_EXPIRES_DEFAULT, + "queue": OnyxCeleryQueues.MONITORING, + }, + }, ] # tasks that run in either self-hosted on cloud diff --git a/backend/onyx/background/celery/tasks/monitoring/tasks.py b/backend/onyx/background/celery/tasks/monitoring/tasks.py index 20782ae7f..921932919 100644 --- a/backend/onyx/background/celery/tasks/monitoring/tasks.py +++ b/backend/onyx/background/celery/tasks/monitoring/tasks.py @@ -1,6 +1,8 @@ import json +import time from collections.abc import Callable from datetime import timedelta +from itertools import islice from typing import Any from celery import shared_task @@ -10,13 +12,17 @@ from pydantic import BaseModel from redis import Redis from redis.lock import Lock as RedisLock from sqlalchemy import select +from sqlalchemy import text from sqlalchemy.orm import Session from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length +from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT +from onyx.configs.constants import ONYX_CLOUD_TENANT_ID from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisLocks +from onyx.db.engine import get_all_tenant_ids from onyx.db.engine import get_db_current_time from onyx.db.engine import get_session_with_tenant from onyx.db.enums import IndexingStatus @@ -27,6 +33,7 @@ from onyx.db.models import IndexAttempt from onyx.db.models import SyncRecord from onyx.db.models import UserGroup from onyx.redis.redis_pool import get_redis_client +from onyx.redis.redis_pool import redis_lock_dump from onyx.utils.telemetry import optional_telemetry from onyx.utils.telemetry import RecordType @@ -456,3 +463,116 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None: lock_monitoring.release() task_logger.info("Background monitoring task finished") + + +@shared_task( + name=OnyxCeleryTask.CLOUD_CHECK_ALEMBIC, +) +def cloud_check_alembic() -> bool | None: + """A task to verify that all tenants are on the same alembic revision. + + This check is expected to fail if a cloud alembic migration is currently running + across all tenants. + + TODO: have the cloud migration script set an activity signal that this check + uses to know it doesn't make sense to run a check at the present time. + """ + time_start = time.monotonic() + + redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID) + + lock_beat: RedisLock = redis_client.lock( + OnyxRedisLocks.CLOUD_CHECK_ALEMBIC_BEAT_LOCK, + timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT, + ) + + # these tasks should never overlap + if not lock_beat.acquire(blocking=False): + return None + + last_lock_time = time.monotonic() + + tenant_to_revision: dict[str, str | None] = {} + revision_counts: dict[str, int] = {} + out_of_date_tenants: dict[str, str | None] = {} + top_revision: str = "" + + try: + # map each tenant_id to its revision + tenant_ids = get_all_tenant_ids() + for tenant_id in tenant_ids: + current_time = time.monotonic() + if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4): + lock_beat.reacquire() + last_lock_time = current_time + + if tenant_id is None: + continue + + with get_session_with_tenant(tenant_id=None) as session: + result = session.execute( + text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1') + ) + + result_scalar: str | None = result.scalar_one_or_none() + tenant_to_revision[tenant_id] = result_scalar + + # get the total count of each revision + for k, v in tenant_to_revision.items(): + if v is None: + continue + + revision_counts[v] = revision_counts.get(v, 0) + 1 + + # get the revision with the most counts + sorted_revision_counts = sorted( + revision_counts.items(), key=lambda item: item[1], reverse=True + ) + + if len(sorted_revision_counts) == 0: + task_logger.error( + f"cloud_check_alembic - No revisions found for {len(tenant_ids)} tenant ids!" + ) + else: + top_revision, _ = sorted_revision_counts[0] + + # build a list of out of date tenants + for k, v in tenant_to_revision.items(): + if v == top_revision: + continue + + out_of_date_tenants[k] = v + + except SoftTimeLimitExceeded: + task_logger.info( + "Soft time limit exceeded, task is being terminated gracefully." + ) + except Exception: + task_logger.exception("Unexpected exception during cloud alembic check") + raise + finally: + if lock_beat.owned(): + lock_beat.release() + else: + task_logger.error("cloud_check_alembic - Lock not owned on completion") + redis_lock_dump(lock_beat, redis_client) + + if len(out_of_date_tenants) > 0: + task_logger.error( + f"Found out of date tenants: " + f"num_out_of_date_tenants={len(out_of_date_tenants)} " + f"num_tenants={len(tenant_ids)} " + f"revision={top_revision}" + ) + for k, v in islice(out_of_date_tenants.items(), 5): + task_logger.info(f"Out of date tenant: tenant={k} revision={v}") + else: + task_logger.info( + f"All tenants are up to date: num_tenants={len(tenant_ids)} revision={top_revision}" + ) + + time_elapsed = time.monotonic() - time_start + task_logger.info( + f"cloud_check_alembic finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}" + ) + return True diff --git a/backend/onyx/configs/constants.py b/backend/onyx/configs/constants.py index b64cfcd21..34156324c 100644 --- a/backend/onyx/configs/constants.py +++ b/backend/onyx/configs/constants.py @@ -295,6 +295,7 @@ class OnyxRedisLocks: ANONYMOUS_USER_ENABLED = "anonymous_user_enabled" CLOUD_CHECK_INDEXING_BEAT_LOCK = "da_lock:cloud_check_indexing_beat" + CLOUD_CHECK_ALEMBIC_BEAT_LOCK = "da_lock:cloud_check_alembic" class OnyxRedisSignals: @@ -344,6 +345,7 @@ class OnyxCeleryTask: AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task" CLOUD_CHECK_FOR_INDEXING = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_for_indexing" + CLOUD_CHECK_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_alembic" REDIS_SOCKET_KEEPALIVE_OPTIONS = {}