clear indexing fences with no celery tasks queued (#3482)

* allow beat tasks to expire. it isn't important that they all run

* validate fences are in a good state and cancel/fail them if not

* add function timings for important beat tasks

* optimize lookups, add lots of comments

* review changes

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
rkuo-danswer 2024-12-16 16:55:58 -08:00 committed by GitHub
parent 8b249cbe63
commit 2dd51230ed
6 changed files with 318 additions and 15 deletions

View File

@@ -1,4 +1,6 @@
# These are helper objects for tracking the keys we need to write in redis
import json
from typing import Any
from typing import cast
from redis import Redis
@@ -23,3 +25,25 @@ def celery_get_queue_length(queue: str, r: Redis) -> int:
total_length += cast(int, length)
return total_length
def celery_find_task(task_id: str, queue: str, r: Redis) -> bool:
"""This is a redis specific way to find a task for a particular queue in redis.
It is priority aware and knows how to look through the multiple redis lists
used to implement task prioritization.
This operation is not atomic.
This is a linear search O(n) ... so be careful using it when the task queues can be larger.
Returns true if the id is in the queue, False if not.
"""
for priority in range(len(OnyxCeleryPriority)):
queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
for task in tasks:
task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
if task_dict.get("headers", {}).get("id") == task_id:
return True
return False
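
For context, a minimal usage sketch of this helper (not part of the commit; the Celery app handle and task id are illustrative assumptions):

from celery import Celery
from onyx.configs.constants import OnyxCeleryQueues

celery_app = Celery("onyx")  # illustrative; the real app is configured elsewhere
# celery's broker data lives on its own redis db, so use celery's own client
r_celery = celery_app.broker_connection().channel().client
if not celery_find_task("some-task-id", OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery):
    ...  # the task is not waiting in any priority sub-queue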

View File

@@ -4,55 +4,80 @@ from typing import Any
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
# we set expires because it isn't necessary to queue up these tasks
# it's only important that they run relatively regularly
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
},
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
},
{
"name": "check-for-prune",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(seconds=15),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
},
{
"name": "kombu-message-cleanup",
"task": OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
"schedule": timedelta(seconds=3600),
"options": {"priority": OnyxCeleryPriority.LOWEST},
"options": {
"priority": OnyxCeleryPriority.LOWEST,
"expires": 60,
},
},
{
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
},
{
"name": "check-for-external-group-sync",
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {"priority": OnyxCeleryPriority.HIGH},
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
},
},
]
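
A minimal sketch of how entries like these map onto Celery's standard beat_schedule config (the exact wiring in Onyx may differ from this assumption):

from celery import Celery

celery_app = Celery("onyx")  # illustrative
celery_app.conf.beat_schedule = {
    entry["name"]: {
        "task": entry["task"],
        "schedule": entry["schedule"],
        # options carry the priority plus the new 60s "expires"; a message
        # that sits in the queue past its expiry is discarded by the worker
        "options": entry["options"],
    }
    for entry in tasks_to_schedule
}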

View File

@@ -1,7 +1,9 @@
import time
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import Any
import redis
import sentry_sdk
@@ -15,6 +17,7 @@ from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
@@ -26,6 +29,7 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -162,11 +166,19 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
bind=True,
)
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
"""a lightweight task used to kick off indexing tasks.
Occcasionally does some validation of existing state to clear up error conditions"""
time_start = time.monotonic()
tasks_created = 0
locked = False
-r = get_redis_client(tenant_id=tenant_id)
redis_client = get_redis_client(tenant_id=tenant_id)
-lock_beat: RedisLock = r.lock(
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -271,7 +283,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
search_settings_instance,
reindex,
db_session,
-r,
redis_client,
tenant_id,
)
if attempt_id:
@@ -286,7 +298,9 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
-unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
unfenced_attempt_ids = get_unfenced_index_attempt_ids(
db_session, redis_client
)
for attempt_id in unfenced_attempt_ids:
lock_beat.reacquire()
@@ -304,6 +318,22 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
# we want to run this less frequently than the overall task
if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
# clear any indexing fences that don't have associated celery tasks in progress
# tasks can be in the redis queue, reserved (prefetched) by a worker,
# or currently executing
try:
task_logger.info("Validating indexing fences...")
validate_indexing_fences(
tenant_id, self.app, redis_client, redis_client_celery, lock_beat
)
except Exception:
task_logger.exception("Exception while validating indexing fences")
redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -320,9 +350,190 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"tenant={tenant_id}"
)
time_elapsed = time.monotonic() - time_start
task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
return tasks_created
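
The once-per-minute fence validation above throttles itself with a plain redis key and TTL; a standalone sketch of that pattern (names are illustrative):

import redis

def run_throttled(r: redis.Redis, signal_key: str, period_secs: int, fn) -> bool:
    """Run fn at most once per period_secs across repeated beat invocations."""
    if r.exists(signal_key):
        return False  # ran recently; skip this cycle
    try:
        fn()
    finally:
        # set the signal even if fn raised, mirroring the code above
        r.set(signal_key, 1, ex=period_secs)
    return True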
def validate_indexing_fences(
tenant_id: str | None,
celery_app: Celery,
r: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
reserved_indexing_tasks: set[str] = set()
active_indexing_tasks: set[str] = set()
indexing_worker_names: list[str] = []
# filter for indexing workers and create an indexing-specific inspect object
inspect = celery_app.control.inspect()
workers: dict[str, Any] = inspect.ping() # type: ignore
if not workers:
raise ValueError("No workers found!")
for worker_name in list(workers.keys()):
if "indexing" in worker_name:
indexing_worker_names.append(worker_name)
if len(indexing_worker_names) == 0:
raise ValueError("No indexing workers found!")
inspect_indexing = celery_app.control.inspect(destination=indexing_worker_names)
# NOTE: each dict entry is a map of worker name to a list of tasks
# we want sets of reserved and active task ids to optimize
# subsequent validation lookups
# get the list of reserved tasks
reserved_tasks: dict[str, list] | None = inspect_indexing.reserved() # type: ignore
if reserved_tasks is None:
raise ValueError("inspect_indexing.reserved() returned None!")
for _, task_list in reserved_tasks.items():
for task in task_list:
reserved_indexing_tasks.add(task["id"])
# get the list of active tasks
active_tasks: dict[str, list] | None = inspect_indexing.active() # type: ignore
if active_tasks is None:
raise ValueError("inspect_indexing.active() returned None!")
for _, task_list in active_tasks.items():
for task in task_list:
active_indexing_tasks.add(task["id"])
# validate all existing indexing jobs
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
validate_indexing_fence(
tenant_id,
key_bytes,
reserved_indexing_tasks,
active_indexing_tasks,
r_celery,
db_session,
)
return
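
For reference, a stripped-down sketch of the Celery inspect calls this function builds on (ping/reserved/active are standard control APIs; worker filtering and error handling are trimmed, and the app handle is assumed):

from celery import Celery

celery_app = Celery("onyx")  # illustrative
inspect = celery_app.control.inspect()

reserved = inspect.reserved() or {}  # worker name -> list of prefetched tasks
active = inspect.active() or {}      # worker name -> list of executing tasks
known_ids = {
    task["id"]
    for task_list in list(reserved.values()) + list(active.values())
    for task in task_list
}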
def validate_indexing_fence(
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
active_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
provides that help.
How this works:
1. The active signal is renewed with a 5 minute TTL:
1.1. when the fence is created
1.2. when the task is seen in the redis queue
1.3. when the task is seen in the reserved or active list for a worker
2. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"validate_indexing_fence - could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
if not redis_connector_index.fenced:
return
payload = redis_connector_index.payload
if not payload:
return
# OK, there's actually something for us to validate
if payload.celery_task_id is None:
# the fence is just barely set up.
if redis_connector_index.active():
return
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector_index.reset()
return
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
if found:
# the celery task exists in the redis queue
redis_connector_index.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within the indexing worker
redis_connector_index.set_active()
return
if payload.celery_task_id in active_tasks:
# the celery task is active (aka currently executing)
redis_connector_index.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# we didn't find any direct indication that associated celery tasks exist, but they still might be there
# due to gaps in our ability to check states during transitions
# Rely on the active signal (which has a duration that allows us to bridge those gaps)
if redis_connector_index.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: fence={fence_key}"
)
if payload.index_attempt_id:
try:
mark_attempt_failed(
payload.index_attempt_id,
db_session,
"validate_indexing_fence - Canceling index attempt due to missing celery tasks",
)
except Exception:
logger.exception(
"validate_indexing_fence - Exception while marking index attempt as failed."
)
redis_connector_index.reset()
return
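
For orientation, the fence key round trip that the parsing above relies on; the exact prefix value is an assumption inferred from the key comments in RedisConnectorIndex below:

# assumed layout: f"{FENCE_PREFIX}_{cc_pair_id}/{search_settings_id}"
fence_key = "connectorindexing_fence_1/2"
composite_id = fence_key.removeprefix("connectorindexing_fence_")  # "1/2"
cc_pair_id, search_settings_id = (int(part) for part in composite_id.split("/"))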
def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
@@ -469,6 +680,7 @@ def try_creating_indexing_task(
celery_task_id=None,
)
redis_connector_index.set_active()
redis_connector_index.set_fence(payload)
# create the index attempt for tracking purposes
@@ -502,6 +714,8 @@ def try_creating_indexing_task(
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
# now fill out the fence with the rest of the data
redis_connector_index.set_active()
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
@@ -642,7 +856,7 @@ def connector_indexing_proxy_task(
if job.process:
exit_code = job.process.exitcode
-# seeing non-deterministic behavior where spawned tasks occasionally return exit code 1
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
# even though logging clearly indicates that they completed successfully
# to work around this, we ignore the job error state if the completion signal is OK
status_int = redis_connector_index.get_completion()
@@ -872,6 +1086,7 @@ def connector_indexing_task(
f"search_settings={search_settings_id}"
)
# This is where the heavy/real work happens
run_indexing_entrypoint(
index_attempt_id,
tenant_id,

View File

@@ -1,3 +1,4 @@
import time
import traceback
from datetime import datetime
from datetime import timezone
@@ -89,10 +90,11 @@ logger = setup_logger()
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
time_start = time.monotonic()
r = get_redis_client(tenant_id=tenant_id)
-lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -161,6 +163,10 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
if lock_beat.owned():
lock_beat.release()
time_elapsed = time.monotonic() - time_start
task_logger.info(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
return
def try_generate_stale_document_sync_tasks(
celery_app: Celery,
@@ -730,6 +736,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
Returns True if the task actually did work, False if it exited early to prevent overlap
"""
time_start = time.monotonic()
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
@@ -824,6 +831,8 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
if lock_beat.owned():
lock_beat.release()
time_elapsed = time.monotonic() - time_start
task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
return True

View File

@@ -274,6 +274,10 @@ class OnyxRedisLocks:
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
class OnyxRedisSignals:
VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
class OnyxCeleryPriority(int, Enum):
HIGHEST = 0
HIGH = auto()

View File

@@ -31,6 +31,10 @@ class RedisConnectorIndex:
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
# used to signal the overall workflow is still active
# (checking celery queue/task state directly is racy, so this signal bridges the gaps)
ACTIVE_PREFIX = PREFIX + "_active"
def __init__(
self,
tenant_id: str | None,
@@ -54,6 +58,7 @@
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
)
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}"
@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@@ -107,6 +112,26 @@ class RedisConnectorIndex:
# 10 minute TTL is good.
self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)
def set_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
the expiration time.
The slack in timing is needed because simply checking the celery queue
and task status is prone to race conditions during state transitions."""
self.redis.set(self.active_key, 0, ex=300)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
def generator_locked(self) -> bool:
if self.redis.exists(self.generator_lock_key):
return True
return False
def set_generator_complete(self, payload: int | None) -> None:
if not payload:
self.redis.delete(self.generator_complete_key)
@@ -138,6 +163,7 @@ class RedisConnectorIndex:
return status
def reset(self) -> None:
self.redis.delete(self.active_key)
self.redis.delete(self.generator_lock_key)
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
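
To illustrate the active-signal lifecycle this class implements, a minimal sketch against a local redis (the key name mirrors the layout above; the ids are made up):

import redis

r = redis.Redis()
active_key = "connectorindexing_active_1/2"  # cc_pair_id=1, search_settings_id=2

r.set(active_key, 0, ex=300)  # set_active(): a renewable 5 minute signal
assert r.exists(active_key)   # active() is True while any renewal holds
# if nothing renews the key (e.g. the indexing worker hard-crashes), it
# silently expires, and validate_indexing_fence can then reset the fence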