diff --git a/backend/onyx/background/celery/celery_redis.py b/backend/onyx/background/celery/celery_redis.py index 717af036675..143c661d892 100644 --- a/backend/onyx/background/celery/celery_redis.py +++ b/backend/onyx/background/celery/celery_redis.py @@ -92,7 +92,8 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int: def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]: - """This is a redis specific way to build a list of tasks in a queue. + """This is a redis specific way to build a list of tasks in a queue and return them + as a set. This helps us read the queue once and then efficiently look for missing tasks in the queue. diff --git a/backend/onyx/background/celery/tasks/connector_deletion/tasks.py b/backend/onyx/background/celery/tasks/connector_deletion/tasks.py index f54aea791a3..baaa60135e1 100644 --- a/backend/onyx/background/celery/tasks/connector_deletion/tasks.py +++ b/backend/onyx/background/celery/tasks/connector_deletion/tasks.py @@ -8,16 +8,21 @@ from celery import Celery from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded +from pydantic import ValidationError from redis import Redis from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from onyx.background.celery.apps.app_base import task_logger +from onyx.background.celery.celery_redis import celery_get_queue_length +from onyx.background.celery.celery_redis import celery_get_queued_task_ids from onyx.configs.app_configs import JOB_TIMEOUT from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT +from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisConstants from onyx.configs.constants import OnyxRedisLocks +from onyx.configs.constants import OnyxRedisSignals from onyx.db.connector import fetch_connector_by_id from onyx.db.connector_credential_pair import add_deletion_failure_message from onyx.db.connector_credential_pair import ( @@ -109,6 +114,7 @@ def check_for_connector_deletion_task( ) -> bool | None: r = get_redis_client() r_replica = get_redis_replica_client() + r_celery: Redis = self.app.broker_connection().channel().client # type: ignore lock_beat: RedisLock = r.lock( OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK, @@ -120,6 +126,21 @@ def check_for_connector_deletion_task( return None try: + # we want to run this less frequently than the overall task + lock_beat.reacquire() + if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES): + # clear fences that don't have associated celery tasks in progress + try: + validate_connector_deletion_fences( + tenant_id, r, r_replica, r_celery, lock_beat + ) + except Exception: + task_logger.exception( + "Exception while validating connector deletion fences" + ) + + r.set(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES, 1, ex=300) + # collect cc_pair_ids cc_pair_ids: list[int] = [] with get_session_with_current_tenant() as db_session: @@ -243,6 +264,7 @@ def try_generate_document_cc_pair_cleanup_tasks( return None # set a basic fence to start + redis_connector.delete.set_active() fence_payload = RedisConnectorDeletePayload( num_tasks=None, submitted=datetime.now(timezone.utc), @@ -475,3 +497,171 @@ def monitor_connector_deletion_taskset( ) redis_connector.delete.reset() + + +def validate_connector_deletion_fences( + tenant_id: str | None, + r: Redis, + r_replica: Redis, + r_celery: Redis, + lock_beat: RedisLock, +) -> None: + # building lookup table can be expensive, so we won't bother + # validating until the queue is small + CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN = 1024 + + queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery) + if queue_len > CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN: + return + + queued_upsert_tasks = celery_get_queued_task_ids( + OnyxCeleryQueues.CONNECTOR_DELETION, r_celery + ) + + # validate all existing connector deletion jobs + lock_beat.reacquire() + keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES)) + for key in keys: + key_bytes = cast(bytes, key) + key_str = key_bytes.decode("utf-8") + if not key_str.startswith(RedisConnectorDelete.FENCE_PREFIX): + continue + + validate_connector_deletion_fence( + tenant_id, + key_bytes, + queued_upsert_tasks, + r, + ) + + lock_beat.reacquire() + + return + + +def validate_connector_deletion_fence( + tenant_id: str | None, + key_bytes: bytes, + queued_tasks: set[str], + r: Redis, +) -> None: + """Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist. + This can happen if the indexing worker hard crashes or is terminated. + Being in this bad state means the fence will never clear without help, so this function + gives the help. + + How this works: + 1. This function renews the active signal with a 5 minute TTL under the following conditions + 1.2. When the task is seen in the redis queue + 1.3. When the task is seen in the reserved / prefetched list + + 2. Externally, the active signal is renewed when: + 2.1. The fence is created + 2.2. The indexing watchdog checks the spawned task. + + 3. The TTL allows us to get through the transitions on fence startup + and when the task starts executing. + + More TTL clarification: it is seemingly impossible to exactly query Celery for + whether a task is in the queue or currently executing. + 1. An unknown task id is always returned as state PENDING. + 2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task + and the time it actually starts on the worker. + + queued_tasks: the celery queue of lightweight permission sync tasks + reserved_tasks: prefetched tasks for sync task generator + """ + # if the fence doesn't exist, there's nothing to do + fence_key = key_bytes.decode("utf-8") + cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key) + if cc_pair_id_str is None: + task_logger.warning( + f"validate_connector_deletion_fence - could not parse id from {fence_key}" + ) + return + + cc_pair_id = int(cc_pair_id_str) + # parse out metadata and initialize the helper class with it + redis_connector = RedisConnector(tenant_id, int(cc_pair_id)) + + # check to see if the fence/payload exists + if not redis_connector.delete.fenced: + return + + # in the cloud, the payload format may have changed ... + # it's a little sloppy, but just reset the fence for now if that happens + # TODO: add intentional cleanup/abort logic + try: + payload = redis_connector.delete.payload + except ValidationError: + task_logger.exception( + "validate_connector_deletion_fence - " + "Resetting fence because fence schema is out of date: " + f"cc_pair={cc_pair_id} " + f"fence={fence_key}" + ) + + redis_connector.delete.reset() + return + + if not payload: + return + + # OK, there's actually something for us to validate + + # look up every task in the current taskset in the celery queue + # every entry in the taskset should have an associated entry in the celery task queue + # because we get the celery tasks first, the entries in our own permissions taskset + # should be roughly a subset of the tasks in celery + + # this check isn't very exact, but should be sufficient over a period of time + # A single successful check over some number of attempts is sufficient. + + # TODO: if the number of tasks in celery is much lower than than the taskset length + # we might be able to shortcut the lookup since by definition some of the tasks + # must not exist in celery. + + tasks_scanned = 0 + tasks_not_in_celery = 0 # a non-zero number after completing our check is bad + + for member in r.sscan_iter(redis_connector.delete.taskset_key): + tasks_scanned += 1 + + member_bytes = cast(bytes, member) + member_str = member_bytes.decode("utf-8") + if member_str in queued_tasks: + continue + + tasks_not_in_celery += 1 + + task_logger.info( + "validate_connector_deletion_fence task check: " + f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}" + ) + + # we're active if there are still tasks to run and those tasks all exist in celery + if tasks_scanned > 0 and tasks_not_in_celery == 0: + redis_connector.delete.set_active() + return + + # we may want to enable this check if using the active task list somehow isn't good enough + # if redis_connector_index.generator_locked(): + # logger.info(f"{payload.celery_task_id} is currently executing.") + + # if we get here, we didn't find any direct indication that the associated celery tasks exist, + # but they still might be there due to gaps in our ability to check states during transitions + # Checking the active signal safeguards us against these transition periods + # (which has a duration that allows us to bridge those gaps) + if redis_connector.delete.active(): + return + + # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up. + task_logger.warning( + "validate_connector_deletion_fence - " + "Resetting fence because no associated celery tasks were found: " + f"cc_pair={cc_pair_id} " + f"fence={fence_key}" + ) + + redis_connector.delete.reset() + return diff --git a/backend/onyx/configs/constants.py b/backend/onyx/configs/constants.py index 7c2d0d59064..94ca026b190 100644 --- a/backend/onyx/configs/constants.py +++ b/backend/onyx/configs/constants.py @@ -342,6 +342,9 @@ class OnyxRedisSignals: BLOCK_PRUNING = "signal:block_pruning" BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences" BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table" + BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES = ( + "signal:block_validate_connector_deletion_fences" + ) class OnyxRedisConstants: diff --git a/backend/onyx/redis/redis_connector_delete.py b/backend/onyx/redis/redis_connector_delete.py index 04eb459b0d5..d475c2545f7 100644 --- a/backend/onyx/redis/redis_connector_delete.py +++ b/backend/onyx/redis/redis_connector_delete.py @@ -33,6 +33,12 @@ class RedisConnectorDelete: FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence" TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset" + # used to signal the overall workflow is still active + # it's impossible to get the exact state of the system at a single point in time + # so we need a signal with a TTL to bridge gaps in our checks + ACTIVE_PREFIX = PREFIX + "_active" + ACTIVE_TTL = 3600 + def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None: self.tenant_id: str | None = tenant_id self.id = id @@ -41,6 +47,8 @@ class RedisConnectorDelete: self.fence_key: str = f"{self.FENCE_PREFIX}_{id}" self.taskset_key = f"{self.TASKSET_PREFIX}_{id}" + self.active_key = f"{self.ACTIVE_PREFIX}_{id}" + def taskset_clear(self) -> None: self.redis.delete(self.taskset_key) @@ -77,6 +85,20 @@ class RedisConnectorDelete: self.redis.set(self.fence_key, payload.model_dump_json()) self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) + def set_active(self) -> None: + """This sets a signal to keep the permissioning flow from getting cleaned up within + the expiration time. + + The slack in timing is needed to avoid race conditions where simply checking + the celery queue and task status could result in race conditions.""" + self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL) + + def active(self) -> bool: + if self.redis.exists(self.active_key): + return True + + return False + def _generate_task_id(self) -> str: # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac" # we prefix the task id so it's easier to keep track of who created the task @@ -141,6 +163,7 @@ class RedisConnectorDelete: def reset(self) -> None: self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) + self.redis.delete(self.active_key) self.redis.delete(self.taskset_key) self.redis.delete(self.fence_key) @@ -153,6 +176,9 @@ class RedisConnectorDelete: @staticmethod def reset_all(r: redis.Redis) -> None: """Deletes all redis values for all connectors""" + for key in r.scan_iter(RedisConnectorDelete.ACTIVE_PREFIX + "*"): + r.delete(key) + for key in r.scan_iter(RedisConnectorDelete.TASKSET_PREFIX + "*"): r.delete(key)