From 2dd51230edb1397ee01596b5c451c29a74aaddf6 Mon Sep 17 00:00:00 2001
From: rkuo-danswer
Date: Mon, 16 Dec 2024 16:55:58 -0800
Subject: [PATCH] clear indexing fences with no celery tasks queued (#3482)

* allow beat tasks to expire; it isn't important that they all run

* validate fences are in a good state and cancel/fail them if not

* add function timings for important beat tasks

* optimize lookups, add lots of comments

* review changes

---------

Co-authored-by: Richard Kuo
Co-authored-by: Richard Kuo (Danswer)
---
 .../onyx/background/celery/celery_redis.py    |  24 ++
 .../background/celery/tasks/beat_schedule.py  |  43 +++-
 .../background/celery/tasks/indexing/tasks.py | 225 +++++++++++++++++-
 .../background/celery/tasks/vespa/tasks.py    |  11 +-
 backend/onyx/configs/constants.py             |   4 +
 backend/onyx/redis/redis_connector_index.py   |  26 ++
 6 files changed, 318 insertions(+), 15 deletions(-)

diff --git a/backend/onyx/background/celery/celery_redis.py b/backend/onyx/background/celery/celery_redis.py
index 9f879bbbe..d438e5957 100644
--- a/backend/onyx/background/celery/celery_redis.py
+++ b/backend/onyx/background/celery/celery_redis.py
@@ -1,4 +1,6 @@
 # These are helper objects for tracking the keys we need to write in redis
+import json
+from typing import Any
 from typing import cast
 
 from redis import Redis
@@ -23,3 +25,25 @@ def celery_get_queue_length(queue: str, r: Redis) -> int:
         total_length += cast(int, length)
 
     return total_length
+
+
+def celery_find_task(task_id: str, queue: str, r: Redis) -> bool:
+    """This is a redis-specific way to find a task for a particular queue in redis.
+    It is priority aware and knows how to look through the multiple redis lists
+    used to implement task prioritization.
+    This operation is not atomic.
+
+    This is a linear search, O(n), so be careful using it when task queues can be large.
+
+    Returns True if the task id is in the queue, False if not.
+    """
+    for priority in range(len(OnyxCeleryPriority)):
+        queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
+
+        tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
+        for task in tasks:
+            task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
+            if task_dict.get("headers", {}).get("id") == task_id:
+                return True
+
+    return False
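
For orientation, here is a minimal usage sketch of celery_find_task, as a caller like the fence validation later in this patch might use it. The Redis connection parameters and the task id are illustrative; in the real code the broker client comes from the Celery app:

    import redis

    from onyx.background.celery.celery_redis import celery_find_task
    from onyx.configs.constants import OnyxCeleryQueues

    # hypothetical broker client; the actual db number depends on the deployment
    r_celery = redis.Redis(host="localhost", port=6379, db=15)

    task_id = "0f1d2c3b-hypothetical-task-id"
    if celery_find_task(task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery):
        print("task is still queued; its fence is healthy")
    else:
        print("task not queued; check the worker's reserved/active lists next")
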
+ """ + for priority in range(len(OnyxCeleryPriority)): + queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue + + tasks = cast(list[bytes], r.lrange(queue_name, 0, -1)) + for task in tasks: + task_dict: dict[str, Any] = json.loads(task.decode("utf-8")) + if task_dict.get("headers", {}).get("id") == task_id: + return True + + return False diff --git a/backend/onyx/background/celery/tasks/beat_schedule.py b/backend/onyx/background/celery/tasks/beat_schedule.py index c3d45d616..d00ceefcf 100644 --- a/backend/onyx/background/celery/tasks/beat_schedule.py +++ b/backend/onyx/background/celery/tasks/beat_schedule.py @@ -4,55 +4,80 @@ from typing import Any from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryTask - +# we set expires because it isn't necessary to queue up these tasks +# it's only important that they run relatively regularly tasks_to_schedule = [ { "name": "check-for-vespa-sync", "task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK, "schedule": timedelta(seconds=20), - "options": {"priority": OnyxCeleryPriority.HIGH}, + "options": { + "priority": OnyxCeleryPriority.HIGH, + "expires": 60, + }, }, { "name": "check-for-connector-deletion", "task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION, "schedule": timedelta(seconds=20), - "options": {"priority": OnyxCeleryPriority.HIGH}, + "options": { + "priority": OnyxCeleryPriority.HIGH, + "expires": 60, + }, }, { "name": "check-for-indexing", "task": OnyxCeleryTask.CHECK_FOR_INDEXING, "schedule": timedelta(seconds=15), - "options": {"priority": OnyxCeleryPriority.HIGH}, + "options": { + "priority": OnyxCeleryPriority.HIGH, + "expires": 60, + }, }, { "name": "check-for-prune", "task": OnyxCeleryTask.CHECK_FOR_PRUNING, "schedule": timedelta(seconds=15), - "options": {"priority": OnyxCeleryPriority.HIGH}, + "options": { + "priority": OnyxCeleryPriority.HIGH, + "expires": 60, + }, }, { "name": "kombu-message-cleanup", "task": OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK, "schedule": timedelta(seconds=3600), - "options": {"priority": OnyxCeleryPriority.LOWEST}, + "options": { + "priority": OnyxCeleryPriority.LOWEST, + "expires": 60, + }, }, { "name": "monitor-vespa-sync", "task": OnyxCeleryTask.MONITOR_VESPA_SYNC, "schedule": timedelta(seconds=5), - "options": {"priority": OnyxCeleryPriority.HIGH}, + "options": { + "priority": OnyxCeleryPriority.HIGH, + "expires": 60, + }, }, { "name": "check-for-doc-permissions-sync", "task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC, "schedule": timedelta(seconds=30), - "options": {"priority": OnyxCeleryPriority.HIGH}, + "options": { + "priority": OnyxCeleryPriority.HIGH, + "expires": 60, + }, }, { "name": "check-for-external-group-sync", "task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC, "schedule": timedelta(seconds=20), - "options": {"priority": OnyxCeleryPriority.HIGH}, + "options": { + "priority": OnyxCeleryPriority.HIGH, + "expires": 60, + }, }, ] diff --git a/backend/onyx/background/celery/tasks/indexing/tasks.py b/backend/onyx/background/celery/tasks/indexing/tasks.py index 8530a8e76..62be3e609 100644 --- a/backend/onyx/background/celery/tasks/indexing/tasks.py +++ b/backend/onyx/background/celery/tasks/indexing/tasks.py @@ -1,7 +1,9 @@ +import time from datetime import datetime from datetime import timezone from http import HTTPStatus from time import sleep +from typing import Any import redis import sentry_sdk @@ -15,6 +17,7 @@ from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from 
diff --git a/backend/onyx/background/celery/tasks/indexing/tasks.py b/backend/onyx/background/celery/tasks/indexing/tasks.py
index 8530a8e76..62be3e609 100644
--- a/backend/onyx/background/celery/tasks/indexing/tasks.py
+++ b/backend/onyx/background/celery/tasks/indexing/tasks.py
@@ -1,7 +1,9 @@
+import time
 from datetime import datetime
 from datetime import timezone
 from http import HTTPStatus
 from time import sleep
+from typing import Any
 
 import redis
 import sentry_sdk
@@ -15,6 +17,7 @@ from redis.lock import Lock as RedisLock
 from sqlalchemy.orm import Session
 
 from onyx.background.celery.apps.app_base import task_logger
+from onyx.background.celery.celery_redis import celery_find_task
 from onyx.background.indexing.job_client import SimpleJobClient
 from onyx.background.indexing.run_indexing import run_indexing_entrypoint
 from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
@@ -26,6 +29,7 @@ from onyx.configs.constants import OnyxCeleryPriority
 from onyx.configs.constants import OnyxCeleryQueues
 from onyx.configs.constants import OnyxCeleryTask
 from onyx.configs.constants import OnyxRedisLocks
+from onyx.configs.constants import OnyxRedisSignals
 from onyx.db.connector import mark_ccpair_with_indexing_trigger
 from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
 from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -162,11 +166,19 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
     bind=True,
 )
 def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
+    """A lightweight task used to kick off indexing tasks.
+    Occasionally does some validation of existing state to clear up error conditions."""
+    time_start = time.monotonic()
+
     tasks_created = 0
     locked = False
-    r = get_redis_client(tenant_id=tenant_id)
+    redis_client = get_redis_client(tenant_id=tenant_id)
 
-    lock_beat: RedisLock = r.lock(
+    # we need to use celery's redis client to access its redis data
+    # (which lives on a different db number)
+    redis_client_celery: Redis = self.app.broker_connection().channel().client  # type: ignore
+
+    lock_beat: RedisLock = redis_client.lock(
         OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK,
         timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
     )
@@ -271,7 +283,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                         search_settings_instance,
                         reindex,
                         db_session,
-                        r,
+                        redis_client,
                         tenant_id,
                     )
                     if attempt_id:
@@ -286,7 +298,9 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
         # Fail any index attempts in the DB that don't have fences
         # This shouldn't ever happen!
         with get_session_with_tenant(tenant_id) as db_session:
-            unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
+            unfenced_attempt_ids = get_unfenced_index_attempt_ids(
+                db_session, redis_client
+            )
             for attempt_id in unfenced_attempt_ids:
                 lock_beat.reacquire()
 
@@ -304,6 +318,22 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                     mark_attempt_failed(
                         attempt.id, db_session, failure_reason=failure_reason
                     )
+
+        # we want to run this less frequently than the overall task
+        if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
+            # clear any indexing fences that don't have associated celery tasks in progress
+            # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
+            # or be currently executing
+            try:
+                task_logger.info("Validating indexing fences...")
+                validate_indexing_fences(
+                    tenant_id, self.app, redis_client, redis_client_celery, lock_beat
+                )
+            except Exception:
+                task_logger.exception("Exception while validating indexing fences")
+
+            redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
+
     except SoftTimeLimitExceeded:
         task_logger.info(
             "Soft time limit exceeded, task is being terminated gracefully."
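
The OnyxRedisSignals gate above is a simple TTL-based rate limiter: the expensive validation runs only when the signal key is absent, then re-arms the key for 60 seconds. A standalone sketch of the pattern (the client and key name are illustrative):

    import redis

    r = redis.Redis()  # stand-in; the real code uses a tenant-scoped client

    SIGNAL_KEY = "signal:validate_indexing_fences"  # mirrors OnyxRedisSignals


    def maybe_validate() -> None:
        # the key's existence means "validation ran recently"; skip until it expires
        if r.exists(SIGNAL_KEY):
            return

        try:
            ...  # the expensive fence validation would go here
        except Exception:
            pass  # log-and-continue, as the task above does

        # re-arm: at most one validation pass per 60 seconds
        r.set(SIGNAL_KEY, 1, ex=60)
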
@@ -320,9 +350,190 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
             f"tenant={tenant_id}"
         )
 
+    time_elapsed = time.monotonic() - time_start
+    task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
     return tasks_created
 
 
+def validate_indexing_fences(
+    tenant_id: str | None,
+    celery_app: Celery,
+    r: Redis,
+    r_celery: Redis,
+    lock_beat: RedisLock,
+) -> None:
+    reserved_indexing_tasks: set[str] = set()
+    active_indexing_tasks: set[str] = set()
+    indexing_worker_names: list[str] = []
+
+    # filter for and create an indexing specific inspect object
+    inspect = celery_app.control.inspect()
+    workers: dict[str, Any] = inspect.ping()  # type: ignore
+    if not workers:
+        raise ValueError("No workers found!")
+
+    for worker_name in list(workers.keys()):
+        if "indexing" in worker_name:
+            indexing_worker_names.append(worker_name)
+
+    if len(indexing_worker_names) == 0:
+        raise ValueError("No indexing workers found!")
+
+    inspect_indexing = celery_app.control.inspect(destination=indexing_worker_names)
+
+    # NOTE: each dict entry is a map of worker name to a list of tasks
+    # we want sets of reserved and active task ids to optimize
+    # subsequent validation lookups
+
+    # get the list of reserved tasks
+    reserved_tasks: dict[str, list] | None = inspect_indexing.reserved()  # type: ignore
+    if reserved_tasks is None:
+        raise ValueError("inspect_indexing.reserved() returned None!")
+
+    for _, task_list in reserved_tasks.items():
+        for task in task_list:
+            reserved_indexing_tasks.add(task["id"])
+
+    # get the list of active tasks
+    active_tasks: dict[str, list] | None = inspect_indexing.active()  # type: ignore
+    if active_tasks is None:
+        raise ValueError("inspect_indexing.active() returned None!")
+
+    for _, task_list in active_tasks.items():
+        for task in task_list:
+            active_indexing_tasks.add(task["id"])
+
+    # validate all existing indexing jobs
+    for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
+        lock_beat.reacquire()
+        with get_session_with_tenant(tenant_id) as db_session:
+            validate_indexing_fence(
+                tenant_id,
+                key_bytes,
+                reserved_indexing_tasks,
+                active_indexing_tasks,
+                r_celery,
+                db_session,
+            )
+
+    return
+
+
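validate_indexing_fences builds its reserved/active sets from Celery's worker inspection API. A compact, self-contained sketch of those calls (the app name and broker URL are illustrative):

    from celery import Celery

    celery_app = Celery("onyx", broker="redis://localhost:6379/15")  # illustrative

    inspect = celery_app.control.inspect()
    workers = inspect.ping() or {}  # e.g. {"indexing@host": {"ok": "pong"}}

    indexing_workers = [name for name in workers if "indexing" in name]

    # scope further queries to the indexing workers only; this keeps the
    # broadcast/reply round trip cheap when many workers are online
    inspect_indexing = celery_app.control.inspect(destination=indexing_workers)

    reserved = inspect_indexing.reserved() or {}  # prefetched, not yet executing
    active = inspect_indexing.active() or {}  # currently executing

    known_task_ids = {
        task["id"]
        for task_list in list(reserved.values()) + list(active.values())
        for task in task_list
    }
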
+ """ + # if the fence doesn't exist, there's nothing to do + fence_key = key_bytes.decode("utf-8") + composite_id = RedisConnector.get_id_from_fence_key(fence_key) + if composite_id is None: + task_logger.warning( + f"validate_indexing_fence - could not parse composite_id from {fence_key}" + ) + return + + # parse out metadata and initialize the helper class with it + parts = composite_id.split("/") + if len(parts) != 2: + return + + cc_pair_id = int(parts[0]) + search_settings_id = int(parts[1]) + + redis_connector = RedisConnector(tenant_id, cc_pair_id) + redis_connector_index = redis_connector.new_index(search_settings_id) + if not redis_connector_index.fenced: + return + + payload = redis_connector_index.payload + if not payload: + return + + # OK, there's actually something for us to validate + + if payload.celery_task_id is None: + # the fence is just barely set up. + if redis_connector_index.active(): + return + + # it would be odd to get here as there isn't that much that can go wrong during + # initial fence setup, but it's still worth making sure we can recover + logger.info( + f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}" + ) + redis_connector_index.reset() + return + + found = celery_find_task( + payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery + ) + if found: + # the celery task exists in the redis queue + redis_connector_index.set_active() + return + + if payload.celery_task_id in reserved_tasks: + # the celery task was prefetched and is reserved within the indexing worker + redis_connector_index.set_active() + return + + if payload.celery_task_id in active_tasks: + # the celery task is active (aka currently executing) + redis_connector_index.set_active() + return + + # we may want to enable this check if using the active task list somehow isn't good enough + # if redis_connector_index.generator_locked(): + # logger.info(f"{payload.celery_task_id} is currently executing.") + + # we didn't find any direct indication that associated celery tasks exist, but they still might be there + # due to gaps in our ability to check states during transitions + # Rely on the active signal (which has a duration that allows us to bridge those gaps) + if redis_connector_index.active(): + return + + # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up. + logger.warning( + f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: fence={fence_key}" + ) + if payload.index_attempt_id: + try: + mark_attempt_failed( + payload.index_attempt_id, + db_session, + "validate_indexing_fence - Canceling index attempt due to missing celery tasks", + ) + except Exception: + logger.exception( + "validate_indexing_fence - Exception while marking index attempt as failed." 
 def _should_index(
     cc_pair: ConnectorCredentialPair,
     last_index: IndexAttempt | None,
@@ -469,6 +680,7 @@ def try_creating_indexing_task(
             celery_task_id=None,
         )
 
+        redis_connector_index.set_active()
         redis_connector_index.set_fence(payload)
 
         # create the index attempt for tracking purposes
@@ -502,6 +714,8 @@ def try_creating_indexing_task(
             raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
 
         # now fill out the fence with the rest of the data
+        redis_connector_index.set_active()
+
         payload.index_attempt_id = index_attempt_id
         payload.celery_task_id = result.id
         redis_connector_index.set_fence(payload)
@@ -642,7 +856,7 @@ def connector_indexing_proxy_task(
         if job.process:
             exit_code = job.process.exitcode
 
-        # seeing non-deterministic behavior where spawned tasks occasionally return exit code 1
+        # seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
         # even though logging clearly indicates that they completed successfully
         # to work around this, we ignore the job error state if the completion signal is OK
         status_int = redis_connector_index.get_completion()
@@ -872,6 +1086,7 @@ def connector_indexing_task(
             f"search_settings={search_settings_id}"
         )
 
+        # This is where the heavy/real work happens
        run_indexing_entrypoint(
             index_attempt_id,
             tenant_id,
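
The vespa/tasks.py diff below brackets each beat task with time.monotonic() calls to produce the new function timings. For illustration, a reusable decorator version of the same measurement; time.monotonic() is used so the numbers are immune to wall-clock adjustments:

    import time
    from collections.abc import Callable
    from functools import wraps
    from typing import Any, TypeVar

    R = TypeVar("R")


    def log_elapsed(func: Callable[..., R]) -> Callable[..., R]:
        """Illustrative decorator form of the inline bracketing used in this PR."""

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> R:
            time_start = time.monotonic()
            try:
                return func(*args, **kwargs)
            finally:
                time_elapsed = time.monotonic() - time_start
                print(f"{func.__name__} finished: elapsed={time_elapsed:.2f}")

        return wrapper


    @log_elapsed
    def check_something() -> None:
        time.sleep(0.1)
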
diff --git a/backend/onyx/background/celery/tasks/vespa/tasks.py b/backend/onyx/background/celery/tasks/vespa/tasks.py
index 704378a74..ba59ff4b1 100644
--- a/backend/onyx/background/celery/tasks/vespa/tasks.py
+++ b/backend/onyx/background/celery/tasks/vespa/tasks.py
@@ -1,3 +1,4 @@
+import time
 import traceback
 from datetime import datetime
 from datetime import timezone
@@ -89,10 +90,11 @@ logger = setup_logger()
 def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
     """Runs periodically to check if any document needs syncing.
     Generates sets of tasks for Celery if syncing is needed."""
+    time_start = time.monotonic()
 
     r = get_redis_client(tenant_id=tenant_id)
 
-    lock_beat = r.lock(
+    lock_beat: RedisLock = r.lock(
         OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
         timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
     )
@@ -161,6 +163,10 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
         if lock_beat.owned():
             lock_beat.release()
 
+    time_elapsed = time.monotonic() - time_start
+    task_logger.info(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
+    return
+
 
 def try_generate_stale_document_sync_tasks(
     celery_app: Celery,
@@ -730,6 +736,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
 
     Returns True if the task actually did work, False if it exited early to prevent overlap
     """
+    time_start = time.monotonic()
     r = get_redis_client(tenant_id=tenant_id)
 
     lock_beat: RedisLock = r.lock(
@@ -824,6 +831,8 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
         if lock_beat.owned():
             lock_beat.release()
 
+    time_elapsed = time.monotonic() - time_start
+    task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
     return True
 
 
diff --git a/backend/onyx/configs/constants.py b/backend/onyx/configs/constants.py
index 3899ce860..72db3c7ca 100644
--- a/backend/onyx/configs/constants.py
+++ b/backend/onyx/configs/constants.py
@@ -274,6 +274,10 @@ class OnyxRedisLocks:
     SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
 
 
+class OnyxRedisSignals:
+    VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
+
+
 class OnyxCeleryPriority(int, Enum):
     HIGHEST = 0
     HIGH = auto()
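
The redis_connector_index.py diff below adds the active-signal helpers that the validation logic above leans on. A standalone sketch of the mechanism (the key name is hypothetical; the 300 second TTL mirrors the PR):

    import redis

    r = redis.Redis()
    active_key = "connectorindexing_active_1/2"  # hypothetical cc_pair 1 / search settings 2


    def set_active() -> None:
        # the value is irrelevant; only the key's existence and its TTL matter
        r.set(active_key, 0, ex=300)


    def active() -> bool:
        return bool(r.exists(active_key))


    set_active()
    assert active()  # True while the 300s TTL is live and being renewed
    # if the workflow crashes and renewals stop, the key simply expires,
    # which is what finally allows validation to reset the fence
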
diff --git a/backend/onyx/redis/redis_connector_index.py b/backend/onyx/redis/redis_connector_index.py
index 40b194af0..7314da3c3 100644
--- a/backend/onyx/redis/redis_connector_index.py
+++ b/backend/onyx/redis/redis_connector_index.py
@@ -31,6 +31,10 @@ class RedisConnectorIndex:
 
     TERMINATE_PREFIX = PREFIX + "_terminate"  # connectorindexing_terminate
 
+    # used to signal that the overall workflow is still active
+    # it's difficult to check celery task state exactly during transitions, so this bridges the gaps
+    ACTIVE_PREFIX = PREFIX + "_active"
+
     def __init__(
         self,
         tenant_id: str | None,
@@ -54,6 +58,7 @@ class RedisConnectorIndex:
             f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
         )
         self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
+        self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}"
 
     @classmethod
     def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@@ -107,6 +112,26 @@ class RedisConnectorIndex:
         # 10 minute TTL is good.
         self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)
 
+    def set_active(self) -> None:
+        """This sets a signal to keep the indexing flow from getting cleaned up within
+        the expiration time.
+
+        The slack in timing is needed because simply checking the celery queue and
+        task status is subject to race conditions during state transitions."""
+        self.redis.set(self.active_key, 0, ex=300)
+
+    def active(self) -> bool:
+        if self.redis.exists(self.active_key):
+            return True
+
+        return False
+
+    def generator_locked(self) -> bool:
+        if self.redis.exists(self.generator_lock_key):
+            return True
+
+        return False
+
     def set_generator_complete(self, payload: int | None) -> None:
         if not payload:
             self.redis.delete(self.generator_complete_key)
@@ -138,6 +163,7 @@ class RedisConnectorIndex:
         return status
 
     def reset(self) -> None:
+        self.redis.delete(self.active_key)
         self.redis.delete(self.generator_lock_key)
         self.redis.delete(self.generator_progress_key)
         self.redis.delete(self.generator_complete_key)
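
Finally, a round-trip sketch of the key scheme these helpers depend on: keys are built as prefix + "_{cc_pair_id}/{search_settings_id}" and parsed back during fence validation. The "connectorindexing" prefix is inferred from the comments in this file; treat it as illustrative:

    PREFIX = "connectorindexing"
    ACTIVE_PREFIX = PREFIX + "_active"


    def active_key(cc_pair_id: int, search_settings_id: int) -> str:
        # mirrors the key construction in RedisConnectorIndex.__init__ above
        return f"{ACTIVE_PREFIX}_{cc_pair_id}/{search_settings_id}"


    def parse_composite_id(key: str) -> tuple[int, int] | None:
        # inverse of the builder: recover the ids from "..._{id}/{search_settings_id}"
        composite = key.rpartition("_")[2]
        parts = composite.split("/")
        if len(parts) != 2:
            return None
        return int(parts[0]), int(parts[1])


    key = active_key(1, 2)  # "connectorindexing_active_1/2"
    assert parse_composite_id(key) == (1, 2)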