mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-12 14:12:53 +02:00
validate connector deletion (#4108)
* validate connector deletion * fixes --------- Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
This commit is contained in:
@ -92,7 +92,8 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
|
||||
|
||||
|
||||
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
|
||||
"""This is a redis specific way to build a list of tasks in a queue.
|
||||
"""This is a redis specific way to build a list of tasks in a queue and return them
|
||||
as a set.
|
||||
|
||||
This helps us read the queue once and then efficiently look for missing tasks
|
||||
in the queue.
|
||||
|
@ -8,16 +8,21 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from pydantic import ValidationError
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.connector_credential_pair import add_deletion_failure_message
|
||||
from onyx.db.connector_credential_pair import (
|
||||
@ -109,6 +114,7 @@ def check_for_connector_deletion_task(
|
||||
) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
@ -120,6 +126,21 @@ def check_for_connector_deletion_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
# we want to run this less frequently than the overall task
|
||||
lock_beat.reacquire()
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
|
||||
# clear fences that don't have associated celery tasks in progress
|
||||
try:
|
||||
validate_connector_deletion_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Exception while validating connector deletion fences"
|
||||
)
|
||||
|
||||
r.set(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES, 1, ex=300)
|
||||
|
||||
# collect cc_pair_ids
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@ -243,6 +264,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
return None
|
||||
|
||||
# set a basic fence to start
|
||||
redis_connector.delete.set_active()
|
||||
fence_payload = RedisConnectorDeletePayload(
|
||||
num_tasks=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
@ -475,3 +497,171 @@ def monitor_connector_deletion_taskset(
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
|
||||
|
||||
def validate_connector_deletion_fences(
|
||||
tenant_id: str | None,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
# building lookup table can be expensive, so we won't bother
|
||||
# validating until the queue is small
|
||||
CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN = 1024
|
||||
|
||||
queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
|
||||
if queue_len > CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN:
|
||||
return
|
||||
|
||||
queued_upsert_tasks = celery_get_queued_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
|
||||
# validate all existing connector deletion jobs
|
||||
lock_beat.reacquire()
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if not key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
|
||||
continue
|
||||
|
||||
validate_connector_deletion_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
queued_upsert_tasks,
|
||||
r,
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
return
|
||||
|
||||
|
||||
def validate_connector_deletion_fence(
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
r: Redis,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
This can happen if the indexing worker hard crashes or is terminated.
|
||||
Being in this bad state means the fence will never clear without help, so this function
|
||||
gives the help.
|
||||
|
||||
How this works:
|
||||
1. This function renews the active signal with a 5 minute TTL under the following conditions
|
||||
1.2. When the task is seen in the redis queue
|
||||
1.3. When the task is seen in the reserved / prefetched list
|
||||
|
||||
2. Externally, the active signal is renewed when:
|
||||
2.1. The fence is created
|
||||
2.2. The indexing watchdog checks the spawned task.
|
||||
|
||||
3. The TTL allows us to get through the transitions on fence startup
|
||||
and when the task starts executing.
|
||||
|
||||
More TTL clarification: it is seemingly impossible to exactly query Celery for
|
||||
whether a task is in the queue or currently executing.
|
||||
1. An unknown task id is always returned as state PENDING.
|
||||
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
|
||||
and the time it actually starts on the worker.
|
||||
|
||||
queued_tasks: the celery queue of lightweight permission sync tasks
|
||||
reserved_tasks: prefetched tasks for sync task generator
|
||||
"""
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"validate_connector_deletion_fence - could not parse id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
# parse out metadata and initialize the helper class with it
|
||||
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
|
||||
|
||||
# check to see if the fence/payload exists
|
||||
if not redis_connector.delete.fenced:
|
||||
return
|
||||
|
||||
# in the cloud, the payload format may have changed ...
|
||||
# it's a little sloppy, but just reset the fence for now if that happens
|
||||
# TODO: add intentional cleanup/abort logic
|
||||
try:
|
||||
payload = redis_connector.delete.payload
|
||||
except ValidationError:
|
||||
task_logger.exception(
|
||||
"validate_connector_deletion_fence - "
|
||||
"Resetting fence because fence schema is out of date: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
return
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
|
||||
# look up every task in the current taskset in the celery queue
|
||||
# every entry in the taskset should have an associated entry in the celery task queue
|
||||
# because we get the celery tasks first, the entries in our own permissions taskset
|
||||
# should be roughly a subset of the tasks in celery
|
||||
|
||||
# this check isn't very exact, but should be sufficient over a period of time
|
||||
# A single successful check over some number of attempts is sufficient.
|
||||
|
||||
# TODO: if the number of tasks in celery is much lower than than the taskset length
|
||||
# we might be able to shortcut the lookup since by definition some of the tasks
|
||||
# must not exist in celery.
|
||||
|
||||
tasks_scanned = 0
|
||||
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
|
||||
|
||||
for member in r.sscan_iter(redis_connector.delete.taskset_key):
|
||||
tasks_scanned += 1
|
||||
|
||||
member_bytes = cast(bytes, member)
|
||||
member_str = member_bytes.decode("utf-8")
|
||||
if member_str in queued_tasks:
|
||||
continue
|
||||
|
||||
tasks_not_in_celery += 1
|
||||
|
||||
task_logger.info(
|
||||
"validate_connector_deletion_fence task check: "
|
||||
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
|
||||
)
|
||||
|
||||
# we're active if there are still tasks to run and those tasks all exist in celery
|
||||
if tasks_scanned > 0 and tasks_not_in_celery == 0:
|
||||
redis_connector.delete.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
if redis_connector.delete.active():
|
||||
return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
task_logger.warning(
|
||||
"validate_connector_deletion_fence - "
|
||||
"Resetting fence because no associated celery tasks were found: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
return
|
||||
|
@ -342,6 +342,9 @@ class OnyxRedisSignals:
|
||||
BLOCK_PRUNING = "signal:block_pruning"
|
||||
BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
|
||||
BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
|
||||
BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES = (
|
||||
"signal:block_validate_connector_deletion_fences"
|
||||
)
|
||||
|
||||
|
||||
class OnyxRedisConstants:
|
||||
|
@ -33,6 +33,12 @@ class RedisConnectorDelete:
|
||||
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
|
||||
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
# it's impossible to get the exact state of the system at a single point in time
|
||||
# so we need a signal with a TTL to bridge gaps in our checks
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = 3600
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.id = id
|
||||
@ -41,6 +47,8 @@ class RedisConnectorDelete:
|
||||
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
|
||||
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
|
||||
|
||||
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
|
||||
|
||||
def taskset_clear(self) -> None:
|
||||
self.redis.delete(self.taskset_key)
|
||||
|
||||
@ -77,6 +85,20 @@ class RedisConnectorDelete:
|
||||
self.redis.set(self.fence_key, payload.model_dump_json())
|
||||
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
|
||||
def set_active(self) -> None:
|
||||
"""This sets a signal to keep the permissioning flow from getting cleaned up within
|
||||
the expiration time.
|
||||
|
||||
The slack in timing is needed to avoid race conditions where simply checking
|
||||
the celery queue and task status could result in race conditions."""
|
||||
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
|
||||
|
||||
def active(self) -> bool:
|
||||
if self.redis.exists(self.active_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _generate_task_id(self) -> str:
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
@ -141,6 +163,7 @@ class RedisConnectorDelete:
|
||||
|
||||
def reset(self) -> None:
|
||||
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
self.redis.delete(self.active_key)
|
||||
self.redis.delete(self.taskset_key)
|
||||
self.redis.delete(self.fence_key)
|
||||
|
||||
@ -153,6 +176,9 @@ class RedisConnectorDelete:
|
||||
@staticmethod
|
||||
def reset_all(r: redis.Redis) -> None:
|
||||
"""Deletes all redis values for all connectors"""
|
||||
for key in r.scan_iter(RedisConnectorDelete.ACTIVE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorDelete.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
|
Reference in New Issue
Block a user