# danswer/backend/onyx/redis/redis_connector_prune.py

import time
from datetime import datetime
from typing import cast
from uuid import uuid4

import redis
from celery import Celery
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT


class RedisConnectorPrunePayload(BaseModel):
    id: str
    submitted: datetime
    started: datetime | None
    celery_task_id: str | None


class RedisConnectorPrune:
    """Manages interactions with redis for pruning tasks. Should only be accessed
    through RedisConnector."""

    PREFIX = "connectorpruning"
    FENCE_PREFIX = f"{PREFIX}_fence"

    # phase 1 - generator task and progress signals
    GENERATORTASK_PREFIX = f"{PREFIX}+generator"  # connectorpruning+generator
    GENERATOR_PROGRESS_PREFIX = (
        PREFIX + "_generator_progress"
    )  # connectorpruning_generator_progress
    GENERATOR_COMPLETE_PREFIX = (
        PREFIX + "_generator_complete"
    )  # connectorpruning_generator_complete

    TASKSET_PREFIX = f"{PREFIX}_taskset"  # connectorpruning_taskset
    SUBTASK_PREFIX = f"{PREFIX}+sub"  # connectorpruning+sub

    # used to signal that the overall workflow is still active
    # it's impossible to get the exact state of the system at a single point in time
    # so we need a signal with a TTL to bridge gaps in our checks
    ACTIVE_PREFIX = PREFIX + "_active"
    ACTIVE_TTL = CELERY_PRUNING_LOCK_TIMEOUT * 2
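
    # For illustration, a cc_pair id of 1 (hypothetical) yields these keys,
    # derived from the prefixes above and the key construction in __init__:
    #   connectorpruning_fence_1              - fence payload for the workflow
    #   connectorpruning+generator_1          - generator task id
    #   connectorpruning_generator_progress_1 - generator progress counter
    #   connectorpruning_generator_complete_1 - total task count once generated
    #   connectorpruning_taskset_1            - set of outstanding subtask ids
    #   connectorpruning_active_1             - TTL'd "still alive" signal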

    def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
        self.tenant_id: str = tenant_id
        self.id = id
        self.redis = redis

        self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
        self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}"
        self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}"
        self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}"

        self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
        self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"

        self.active_key = f"{self.ACTIVE_PREFIX}_{id}"

    def taskset_clear(self) -> None:
        self.redis.delete(self.taskset_key)

    def generator_clear(self) -> None:
        self.redis.delete(self.generator_progress_key)
        self.redis.delete(self.generator_complete_key)

    def get_remaining(self) -> int:
        # todo: move into fence
        remaining = cast(int, self.redis.scard(self.taskset_key))
        return remaining

    def get_active_task_count(self) -> int:
        """Count of active pruning tasks"""
        count = 0
        for _ in self.redis.sscan_iter(
            OnyxRedisConstants.ACTIVE_FENCES,
            RedisConnectorPrune.FENCE_PREFIX + "*",
            count=SCAN_ITER_COUNT_DEFAULT,
        ):
            count += 1

        return count

    @property
    def fenced(self) -> bool:
        if self.redis.exists(self.fence_key):
            return True

        return False

    @property
    def payload(self) -> RedisConnectorPrunePayload | None:
        # read related data and evaluate/print task progress
        fence_bytes = cast(bytes, self.redis.get(self.fence_key))
        if fence_bytes is None:
            return None

        fence_str = fence_bytes.decode("utf-8")
        payload = RedisConnectorPrunePayload.model_validate_json(fence_str)
        return payload

    def set_fence(
        self,
        payload: RedisConnectorPrunePayload | None,
    ) -> None:
        if not payload:
            self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
            self.redis.delete(self.fence_key)
            return

        self.redis.set(self.fence_key, payload.model_dump_json())
        self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
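
    # A minimal sketch of the fence lifecycle (illustrative only; the payload
    # values and the redis client "r" are hypothetical):
    #
    #     prune = RedisConnectorPrune(tenant_id="tenant_1", id=1, redis=r)
    #     prune.set_fence(
    #         RedisConnectorPrunePayload(
    #             id=str(uuid4()),
    #             submitted=datetime.now(),
    #             started=None,
    #             celery_task_id=None,
    #         )
    #     )  # workflow is now "fenced" (in progress)
    #     prune.fenced           # True
    #     prune.set_fence(None)  # clears the fence and the ACTIVE_FENCES entry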

    def set_active(self) -> None:
        """This sets a signal to keep the pruning flow from getting cleaned up within
        the expiration time.

        The slack in timing is needed because simply checking the celery queue and
        task status at a single point in time is subject to race conditions."""
        self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)

    def active(self) -> bool:
        if self.redis.exists(self.active_key):
            return True

        return False
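
    # Liveness sketch: the TTL'd active key bridges gaps between checks
    # (usage illustrative):
    #
    #     prune.set_active()  # refreshed periodically while work is ongoing
    #     prune.active()      # True until ACTIVE_TTL seconds after the last refresh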

    @property
    def generator_complete(self) -> int | None:
        """The fence payload is an int representing the starting number of
        pruning tasks to be processed, set just after the generator completes."""
        fence_bytes = self.redis.get(self.generator_complete_key)
        if fence_bytes is None:
            return None

        fence_int = int(cast(bytes, fence_bytes))
        return fence_int

    @generator_complete.setter
    def generator_complete(self, payload: int | None) -> None:
        """Set the payload to an int to set the fence; if None, the fence will
        be deleted."""
        if payload is None:
            self.redis.delete(self.generator_complete_key)
            return

        self.redis.set(self.generator_complete_key, payload)
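
    # Sketch of how the generator-complete fence is used (counts illustrative):
    #
    #     prune.generator_complete = 250   # generator finished; 250 docs to prune
    #     prune.generator_complete         # -> 250
    #     prune.get_remaining()            # shrinks toward 0 as subtasks finish
    #     prune.generator_complete = None  # deletes the fence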

    def generate_tasks(
        self,
        documents_to_prune: set[str],
        celery_app: Celery,
        db_session: Session,
        lock: RedisLock | None,
    ) -> int | None:
        last_lock_time = time.monotonic()

        async_results = []
        cc_pair = get_connector_credential_pair_from_id(
            db_session=db_session,
            cc_pair_id=int(self.id),
        )
        if not cc_pair:
            return None

        for doc_id in documents_to_prune:
            current_time = time.monotonic()
            if lock and current_time - last_lock_time >= (
                CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
            ):
                lock.reacquire()
                last_lock_time = current_time

            # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
            # the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
            # we prefix the task id so it's easier to keep track of who created the task
            # aka "connectorpruning+sub_1_dd32ded3-00aa-4884-8b21-42f8332e7fac"
            custom_task_id = f"{self.subtask_prefix}_{uuid4()}"

            # add to the tracking taskset in redis BEFORE creating the celery task.
            self.redis.sadd(self.taskset_key, custom_task_id)

            # Priority on syncs triggered by new indexing should be medium
            result = celery_app.send_task(
                OnyxCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
                kwargs=dict(
                    document_id=doc_id,
                    connector_id=cc_pair.connector_id,
                    credential_id=cc_pair.credential_id,
                    tenant_id=self.tenant_id,
                ),
                queue=OnyxCeleryQueues.CONNECTOR_DELETION,
                task_id=custom_task_id,
                priority=OnyxCeleryPriority.MEDIUM,
                ignore_result=True,
            )

            async_results.append(result)

        return len(async_results)
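
    # A minimal sketch of driving generate_tasks (illustrative; docs_to_prune,
    # celery_app, db_session, and lock are assumed to be provided by the
    # surrounding pruning task):
    #
    #     tasks_generated = prune.generate_tasks(
    #         docs_to_prune, celery_app, db_session, lock
    #     )
    #     if tasks_generated is not None:
    #         # record the total so monitoring can compare against get_remaining()
    #         prune.generator_complete = tasks_generated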

    def reset(self) -> None:
        self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
        self.redis.delete(self.active_key)
        self.redis.delete(self.generator_progress_key)
        self.redis.delete(self.generator_complete_key)
        self.redis.delete(self.taskset_key)
        self.redis.delete(self.fence_key)

    @staticmethod
    def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
        taskset_key = f"{RedisConnectorPrune.TASKSET_PREFIX}_{id}"
        r.srem(taskset_key, task_id)
        return

    @staticmethod
    def reset_all(r: redis.Redis) -> None:
        """Deletes all redis values for all connectors"""
        for key in r.scan_iter(RedisConnectorPrune.ACTIVE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPrune.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPrune.GENERATOR_COMPLETE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPrune.GENERATOR_PROGRESS_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
            r.delete(key)
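

# Monitoring sketch (illustrative): a worker can poll per-connector prune state
# from redis alone, without touching the database. "r" and the ids are
# hypothetical values.
#
#     prune = RedisConnectorPrune(tenant_id="tenant_1", id=1, redis=r)
#     if prune.fenced and prune.payload:
#         total = prune.generator_complete or 0
#         done = total - prune.get_remaining()
#         print(f"pruning {prune.payload.id}: {done}/{total} subtasks complete")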