cloud check for migrations (#3734)

* cloud check for migrations

* fix table declaration

* change back interval

* Fix usage of POSTGRES_DEFAULT_SCHEMA

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
This commit is contained in:
rkuo-danswer
2025-01-22 14:41:28 -08:00
committed by GitHub
parent a215ea9143
commit 69c60feda4
3 changed files with 132 additions and 0 deletions

View File

@@ -29,6 +29,16 @@ cloud_tasks_to_schedule = [
"expires": BEAT_EXPIRES_DEFAULT, "expires": BEAT_EXPIRES_DEFAULT,
}, },
}, },
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
] ]
# tasks that run in either self-hosted on cloud # tasks that run in either self-hosted on cloud

View File

@@ -1,6 +1,8 @@
import json import json
import time
from collections.abc import Callable from collections.abc import Callable
from datetime import timedelta from datetime import timedelta
from itertools import islice
from typing import Any from typing import Any
from celery import shared_task from celery import shared_task
@@ -10,13 +12,17 @@ from pydantic import BaseModel
from redis import Redis from redis import Redis
from redis.lock import Lock as RedisLock from redis.lock import Lock as RedisLock
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_db_current_time from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_tenant from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingStatus from onyx.db.enums import IndexingStatus
@@ -27,6 +33,7 @@ from onyx.db.models import IndexAttempt
from onyx.db.models import SyncRecord from onyx.db.models import SyncRecord
from onyx.db.models import UserGroup from onyx.db.models import UserGroup
from onyx.redis.redis_pool import get_redis_client from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.telemetry import optional_telemetry from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType from onyx.utils.telemetry import RecordType
@@ -456,3 +463,116 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
lock_monitoring.release() lock_monitoring.release()
task_logger.info("Background monitoring task finished") task_logger.info("Background monitoring task finished")
@shared_task(
name=OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
)
def cloud_check_alembic() -> bool | None:
"""A task to verify that all tenants are on the same alembic revision.
This check is expected to fail if a cloud alembic migration is currently running
across all tenants.
TODO: have the cloud migration script set an activity signal that this check
uses to know it doesn't make sense to run a check at the present time.
"""
time_start = time.monotonic()
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.CLOUD_CHECK_ALEMBIC_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
last_lock_time = time.monotonic()
tenant_to_revision: dict[str, str | None] = {}
revision_counts: dict[str, int] = {}
out_of_date_tenants: dict[str, str | None] = {}
top_revision: str = ""
try:
# map each tenant_id to its revision
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
current_time = time.monotonic()
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
lock_beat.reacquire()
last_lock_time = current_time
if tenant_id is None:
continue
with get_session_with_tenant(tenant_id=None) as session:
result = session.execute(
text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1')
)
result_scalar: str | None = result.scalar_one_or_none()
tenant_to_revision[tenant_id] = result_scalar
# get the total count of each revision
for k, v in tenant_to_revision.items():
if v is None:
continue
revision_counts[v] = revision_counts.get(v, 0) + 1
# get the revision with the most counts
sorted_revision_counts = sorted(
revision_counts.items(), key=lambda item: item[1], reverse=True
)
if len(sorted_revision_counts) == 0:
task_logger.error(
f"cloud_check_alembic - No revisions found for {len(tenant_ids)} tenant ids!"
)
else:
top_revision, _ = sorted_revision_counts[0]
# build a list of out of date tenants
for k, v in tenant_to_revision.items():
if v == top_revision:
continue
out_of_date_tenants[k] = v
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception during cloud alembic check")
raise
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error("cloud_check_alembic - Lock not owned on completion")
redis_lock_dump(lock_beat, redis_client)
if len(out_of_date_tenants) > 0:
task_logger.error(
f"Found out of date tenants: "
f"num_out_of_date_tenants={len(out_of_date_tenants)} "
f"num_tenants={len(tenant_ids)} "
f"revision={top_revision}"
)
for k, v in islice(out_of_date_tenants.items(), 5):
task_logger.info(f"Out of date tenant: tenant={k} revision={v}")
else:
task_logger.info(
f"All tenants are up to date: num_tenants={len(tenant_ids)} revision={top_revision}"
)
time_elapsed = time.monotonic() - time_start
task_logger.info(
f"cloud_check_alembic finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
)
return True

View File

@@ -295,6 +295,7 @@ class OnyxRedisLocks:
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled" ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
CLOUD_CHECK_INDEXING_BEAT_LOCK = "da_lock:cloud_check_indexing_beat" CLOUD_CHECK_INDEXING_BEAT_LOCK = "da_lock:cloud_check_indexing_beat"
CLOUD_CHECK_ALEMBIC_BEAT_LOCK = "da_lock:cloud_check_alembic"
class OnyxRedisSignals: class OnyxRedisSignals:
@@ -344,6 +345,7 @@ class OnyxCeleryTask:
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task" AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
CLOUD_CHECK_FOR_INDEXING = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_for_indexing" CLOUD_CHECK_FOR_INDEXING = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_for_indexing"
CLOUD_CHECK_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_alembic"
REDIS_SOCKET_KEEPALIVE_OPTIONS = {} REDIS_SOCKET_KEEPALIVE_OPTIONS = {}