Multiple cloud/indexing fixes (#3609)

* more debugging

* test reacquire outside of loop

* more logging

* move the lock_beat acquisition outside the try/except so the cleanup path never releases a lock we never took (see the sketch below)

* use a larger scan_iter value for performance

* cap stale document sync task generation per pass

* add debug logging for a particular timeout issue
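The lock-handling bullets above boil down to one pattern. A minimal sketch of the idea, with illustrative names only ("beat_lock_key" and the loop body are placeholders, not identifiers from this commit): acquire the beat lock before entering the try/finally so the finally can only ever release a lock this task actually holds, and reacquire it periodically so long loops don't let it expire.

from redis import Redis

r = Redis()
lock_beat = r.lock("beat_lock_key", timeout=120)  # placeholder key name

# acquire OUTSIDE the try/finally: if another run holds the lock we bail
# out here, and the finally below can never release a lock we never took
if lock_beat.acquire(blocking=False):
    try:
        for item in range(1000):  # stand-in for the real per-cc_pair loop
            lock_beat.reacquire()  # reset the TTL so the lock survives the loop
            ...  # do one unit of work
    finally:
        if lock_beat.owned():
            lock_beat.release()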

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
rkuo-danswer 2025-01-07 17:30:29 -08:00 committed by GitHub
parent eb916df139
commit 02f72a5c86
16 changed files with 202 additions and 73 deletions

View File

@@ -44,11 +44,11 @@ def check_for_connector_deletion_task(
         timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
     )
-    try:
-        # these tasks should never overlap
-        if not lock_beat.acquire(blocking=False):
-            return None
+    # these tasks should never overlap
+    if not lock_beat.acquire(blocking=False):
+        return None
 
+    try:
         # collect cc_pair_ids
         cc_pair_ids: list[int] = []
         with get_session_with_tenant(tenant_id) as db_session:

View File

@@ -102,11 +102,11 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
         timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
     )
-    try:
-        # these tasks should never overlap
-        if not lock_beat.acquire(blocking=False):
-            return None
+    # these tasks should never overlap
+    if not lock_beat.acquire(blocking=False):
+        return None
 
+    try:
         # get all cc pairs that need to be synced
         cc_pair_ids_to_sync: list[int] = []
         with get_session_with_tenant(tenant_id) as db_session:

View File

@@ -102,11 +102,11 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
         timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
     )
-    try:
-        # these tasks should never overlap
-        if not lock_beat.acquire(blocking=False):
-            return None
+    # these tasks should never overlap
+    if not lock_beat.acquire(blocking=False):
+        return None
 
+    try:
         cc_pair_ids_to_sync: list[int] = []
         with get_session_with_tenant(tenant_id) as db_session:
             cc_pairs = get_all_auto_sync_cc_pairs(db_session)

View File

@@ -63,6 +63,7 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
 from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
 from onyx.redis.redis_pool import get_redis_client
 from onyx.redis.redis_pool import redis_lock_dump
+from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
 from onyx.utils.logger import setup_logger
 from onyx.utils.variable_functionality import global_version
 from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
@@ -204,6 +205,10 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
 def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
     """a lightweight task used to kick off indexing tasks.
     Occasionally does some validation of existing state to clear up error conditions"""
+    debug_tenants = {
+        "tenant_i-043470d740845ec56",
+        "tenant_82b497ce-88aa-4fbd-841a-92cae43529c8",
+    }
     time_start = time.monotonic()
     tasks_created = 0
@@ -219,11 +224,11 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
         timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
     )
-    try:
-        # these tasks should never overlap
-        if not lock_beat.acquire(blocking=False):
-            return None
+    # these tasks should never overlap
+    if not lock_beat.acquire(blocking=False):
+        return None
 
+    try:
         locked = True
 
         # check for search settings swap
@@ -246,15 +251,25 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
             )
 
         # gather cc_pair_ids
+        lock_beat.reacquire()
         cc_pair_ids: list[int] = []
         with get_session_with_tenant(tenant_id) as db_session:
-            lock_beat.reacquire()
             cc_pairs = fetch_connector_credential_pairs(db_session)
             for cc_pair_entry in cc_pairs:
                 cc_pair_ids.append(cc_pair_entry.id)
 
         # kick off index attempts
         for cc_pair_id in cc_pair_ids:
+            # debugging logic - remove after we're done
+            if tenant_id in debug_tenants:
+                ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
+                task_logger.info(
+                    f"check_for_indexing cc_pair lock: "
+                    f"tenant={tenant_id} "
+                    f"cc_pair={cc_pair_id} "
+                    f"ttl={ttl}"
+                )
+
             lock_beat.reacquire()
             redis_connector = RedisConnector(tenant_id, cc_pair_id)
@@ -331,14 +346,33 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                 )
                 tasks_created += 1
 
+        # debugging logic - remove after we're done
+        if tenant_id in debug_tenants:
+            ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
+            task_logger.info(
+                f"check_for_indexing unfenced lock: "
+                f"tenant={tenant_id} "
+                f"ttl={ttl}"
+            )
+
+        lock_beat.reacquire()
+
         # Fail any index attempts in the DB that don't have fences
         # This shouldn't ever happen!
         with get_session_with_tenant(tenant_id) as db_session:
-            lock_beat.reacquire()
             unfenced_attempt_ids = get_unfenced_index_attempt_ids(
                 db_session, redis_client
             )
+
             for attempt_id in unfenced_attempt_ids:
+                # debugging logic - remove after we're done
+                if tenant_id in debug_tenants:
+                    ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
+                    task_logger.info(
+                        f"check_for_indexing unfenced attempt id lock: "
+                        f"tenant={tenant_id} "
+                        f"ttl={ttl}"
+                    )
+
                 lock_beat.reacquire()
                 attempt = get_index_attempt(db_session, attempt_id)
@@ -356,9 +390,18 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                         attempt.id, db_session, failure_reason=failure_reason
                     )
 
+        # debugging logic - remove after we're done
+        if tenant_id in debug_tenants:
+            ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
+            task_logger.info(
+                f"check_for_indexing validate fences lock: "
+                f"tenant={tenant_id} "
+                f"ttl={ttl}"
+            )
+
+        lock_beat.reacquire()
         # we want to run this less frequently than the overall task
         if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
-            lock_beat.reacquire()
             # clear any indexing fences that don't have associated celery tasks in progress
             # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
             # or be currently executing
@@ -370,7 +413,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                 task_logger.exception("Exception while validating indexing fences")
 
             redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
-
     except SoftTimeLimitExceeded:
         task_logger.info(
             "Soft time limit exceeded, task is being terminated gracefully."
@@ -405,7 +447,9 @@ def validate_indexing_fences(
     )
 
     # validate all existing indexing jobs
-    for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
+    for key_bytes in r.scan_iter(
+        RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
+    ):
         lock_beat.reacquire()
         with get_session_with_tenant(tenant_id) as db_session:
             validate_indexing_fence(
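The debug blocks added in this file all log the beat lock's remaining TTL just before a reacquire. A standalone sketch of that diagnostic (the key name and the log wording are illustrative; the diff uses OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK):

import logging
from redis import Redis

logger = logging.getLogger(__name__)

def log_beat_lock_ttl(r: Redis, lock_key: str, where: str) -> None:
    # ttl() returns -2 if the key no longer exists and -1 if it has no
    # expiry; a -2 right before a reacquire() pinpoints the loop section
    # that ran longer than the lock timeout
    ttl = r.ttl(lock_key)
    logger.info(f"{where}: beat lock ttl={ttl}")

Sprinkling this before each lock_beat.reacquire() narrows down which phase of the task lets the lock lapse.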

View File

@@ -89,11 +89,11 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
         timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
     )
-    try:
-        # these tasks should never overlap
-        if not lock_beat.acquire(blocking=False):
-            return None
+    # these tasks should never overlap
+    if not lock_beat.acquire(blocking=False):
+        return None
 
+    try:
         cc_pair_ids: list[int] = []
         with get_session_with_tenant(tenant_id) as db_session:
             cc_pairs = get_connector_credential_pairs(db_session)

View File

@@ -4,6 +4,7 @@ import traceback
 from datetime import datetime
 from datetime import timezone
 from http import HTTPStatus
+from typing import Any
 from typing import cast
 
 import httpx
@@ -26,6 +27,7 @@ from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocument
 from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
 from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
 from onyx.configs.app_configs import JOB_TIMEOUT
+from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
 from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import OnyxCeleryQueues
 from onyx.configs.constants import OnyxCeleryTask
@@ -70,6 +72,7 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrune
 from onyx.redis.redis_document_set import RedisDocumentSet
 from onyx.redis.redis_pool import get_redis_client
 from onyx.redis.redis_pool import redis_lock_dump
+from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
 from onyx.redis.redis_usergroup import RedisUserGroup
 from onyx.utils.logger import setup_logger
 from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -103,14 +106,14 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
         timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
     )
-    try:
-        # these tasks should never overlap
-        if not lock_beat.acquire(blocking=False):
-            return None
+    # these tasks should never overlap
+    if not lock_beat.acquire(blocking=False):
+        return None
 
+    try:
         with get_session_with_tenant(tenant_id) as db_session:
             try_generate_stale_document_sync_tasks(
-                self.app, db_session, r, lock_beat, tenant_id
+                self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id
             )
 
         # region document set scan
@@ -185,6 +188,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
 def try_generate_stale_document_sync_tasks(
     celery_app: Celery,
+    max_tasks: int,
     db_session: Session,
     r: Redis,
     lock_beat: RedisLock,
@@ -215,11 +219,16 @@ def try_generate_stale_document_sync_tasks(
     # rkuo: we could technically sync all stale docs in one big pass.
     # but I feel it's more understandable to group the docs by cc_pair
     total_tasks_generated = 0
+    tasks_remaining = max_tasks
     cc_pairs = get_connector_credential_pairs(db_session)
     for cc_pair in cc_pairs:
+        lock_beat.reacquire()
+
         rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
         rc.set_skip_docs(docs_to_skip)
-        result = rc.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
+        result = rc.generate_tasks(
+            tasks_remaining, celery_app, db_session, r, lock_beat, tenant_id
+        )
 
         if result is None:
             continue
@@ -233,10 +242,19 @@ def try_generate_stale_document_sync_tasks(
         )
 
         total_tasks_generated += result[0]
+        tasks_remaining -= result[0]
+        if tasks_remaining <= 0:
+            break
 
-    task_logger.info(
-        f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
-    )
+    if tasks_remaining <= 0:
+        task_logger.info(
+            f"RedisConnector.generate_tasks reached the task generation limit: "
+            f"total_tasks_generated={total_tasks_generated} max_tasks={max_tasks}"
+        )
+    else:
+        task_logger.info(
+            f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
+        )
 
     r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
     return total_tasks_generated
@@ -275,7 +293,9 @@ def try_generate_document_set_sync_tasks(
     )
 
     # Add all documents that need to be updated into the queue
-    result = rds.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
+    result = rds.generate_tasks(
+        VESPA_SYNC_MAX_TASKS, celery_app, db_session, r, lock_beat, tenant_id
+    )
     if result is None:
         return None
@@ -330,7 +350,9 @@ def try_generate_user_group_sync_tasks(
     task_logger.info(
         f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
     )
-    result = rug.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
+    result = rug.generate_tasks(
+        VESPA_SYNC_MAX_TASKS, celery_app, db_session, r, lock_beat, tenant_id
+    )
     if result is None:
         return None
@@ -752,7 +774,7 @@ def monitor_ccpair_indexing_taskset(
 
 @shared_task(name=OnyxCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
-def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
+def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
     """This is a celery beat task that monitors and finalizes metadata sync tasksets.
     It scans for fence values and then gets the counts of any associated tasksets.
     If the count is 0, that means all tasks finished and we should clean up.
@@ -766,7 +788,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
     time_start = time.monotonic()
 
-    timings: dict[str, float] = {}
+    timings: dict[str, Any] = {}
     timings["start"] = time_start
 
     r = get_redis_client(tenant_id=tenant_id)
@@ -776,16 +798,15 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
         timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
     )
-    try:
-        # prevent overlapping tasks
-        if not lock_beat.acquire(blocking=False):
-            task_logger.info("monitor_vespa_sync exiting due to overlap")
-            return False
+    # prevent overlapping tasks
+    if not lock_beat.acquire(blocking=False):
+        return None
 
+    try:
         # print current queue lengths
         phase_start = time.monotonic()
         # we don't need every tenant polling redis for this info.
-        if not MULTI_TENANT or random.randint(1, 100) == 100:
+        if not MULTI_TENANT or random.randint(1, 10) == 10:
             r_celery = self.app.broker_connection().channel().client  # type: ignore
             n_celery = celery_get_queue_length("celery", r_celery)
             n_indexing = celery_get_queue_length(
@@ -826,6 +847,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
                 f"permissions_upsert={n_permissions_upsert} "
             )
         timings["queues"] = time.monotonic() - phase_start
+        timings["queues_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
 
         # scan and monitor activity to completion
         phase_start = time.monotonic()
@@ -833,24 +855,37 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
         if r.exists(RedisConnectorCredentialPair.get_fence_key()):
             monitor_connector_taskset(r)
         timings["connector"] = time.monotonic() - phase_start
+        timings["connector_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
 
         phase_start = time.monotonic()
-        for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
-            lock_beat.reacquire()
+        lock_beat.reacquire()
+        for key_bytes in r.scan_iter(
+            RedisConnectorDelete.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
+        ):
             monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
+            lock_beat.reacquire()
+
         timings["connector_deletion"] = time.monotonic() - phase_start
+        timings["connector_deletion_ttl"] = r.ttl(
+            OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK
+        )
 
         phase_start = time.monotonic()
-        for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
-            lock_beat.reacquire()
+        lock_beat.reacquire()
+        for key_bytes in r.scan_iter(
+            RedisDocumentSet.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
+        ):
             with get_session_with_tenant(tenant_id) as db_session:
                 monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
-        timings["document_set"] = time.monotonic() - phase_start
+            lock_beat.reacquire()
+
+        timings["documentset"] = time.monotonic() - phase_start
+        timings["documentset_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
 
         phase_start = time.monotonic()
-        for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
-            lock_beat.reacquire()
+        lock_beat.reacquire()
+        for key_bytes in r.scan_iter(
+            RedisUserGroup.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
+        ):
             monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
                 "onyx.background.celery.tasks.vespa.tasks",
                 "monitor_usergroup_taskset",
@@ -858,29 +893,45 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
             )
             with get_session_with_tenant(tenant_id) as db_session:
                 monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
+            lock_beat.reacquire()
+
         timings["usergroup"] = time.monotonic() - phase_start
+        timings["usergroup_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
 
         phase_start = time.monotonic()
-        for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
-            lock_beat.reacquire()
+        lock_beat.reacquire()
+        for key_bytes in r.scan_iter(
+            RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
+        ):
             with get_session_with_tenant(tenant_id) as db_session:
                 monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
+            lock_beat.reacquire()
+
         timings["pruning"] = time.monotonic() - phase_start
+        timings["pruning_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
 
         phase_start = time.monotonic()
-        for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
-            lock_beat.reacquire()
+        lock_beat.reacquire()
+        for key_bytes in r.scan_iter(
+            RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
+        ):
             with get_session_with_tenant(tenant_id) as db_session:
                 monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
+            lock_beat.reacquire()
+
         timings["indexing"] = time.monotonic() - phase_start
+        timings["indexing_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
 
         phase_start = time.monotonic()
-        for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
-            lock_beat.reacquire()
+        lock_beat.reacquire()
+        for key_bytes in r.scan_iter(
+            RedisConnectorPermissionSync.FENCE_PREFIX + "*",
+            count=SCAN_ITER_COUNT_DEFAULT,
+        ):
             with get_session_with_tenant(tenant_id) as db_session:
                 monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
+            lock_beat.reacquire()
+
         timings["permissions"] = time.monotonic() - phase_start
+        timings["permissions_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
     except SoftTimeLimitExceeded:
         task_logger.info(
             "Soft time limit exceeded, task is being terminated gracefully."
@@ -889,18 +940,10 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
         if lock_beat.owned():
             lock_beat.release()
         else:
-            t = timings
             task_logger.error(
                 "monitor_vespa_sync - Lock not owned on completion: "
                 f"tenant={tenant_id} "
-                f"queues={t.get('queues')} "
-                f"connector={t.get('connector')} "
-                f"connector_deletion={t.get('connector_deletion')} "
-                f"document_set={t.get('document_set')} "
-                f"usergroup={t.get('usergroup')} "
-                f"pruning={t.get('pruning')} "
-                f"indexing={t.get('indexing')} "
-                f"permissions={t.get('permissions')}"
+                f"timings={timings}"
             )
             redis_lock_dump(lock_beat, r)
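The max_tasks plumbing in this file amounts to a per-pass budget. A sketch of the shape of that logic, using hypothetical generator callables (the real code threads tasks_remaining through RedisConnectorCredentialPair.generate_tasks):

from typing import Callable

def generate_with_budget(
    generators: list[Callable[[int], int]], max_tasks: int
) -> int:
    # each callable takes a task budget and returns how many tasks it created
    tasks_remaining = max_tasks
    total_tasks_generated = 0
    for generate in generators:
        created = generate(tasks_remaining)
        total_tasks_generated += created
        tasks_remaining -= created
        if tasks_remaining <= 0:
            break  # budget spent; remaining work waits for the next pass
    return total_tasks_generated

Because a document's dirty state lives in the DB, capping a pass only defers the remaining syncs to the next beat tick; nothing is lost.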

View File

@@ -280,6 +280,11 @@ try:
 except ValueError:
     CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT
 
+# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
+VESPA_SYNC_MAX_TASKS = 1024
+
+DB_YIELD_PER_DEFAULT = 64
+
 #####
 # Connector Configs
 #####

View File

@@ -7,6 +7,7 @@ from redis import Redis
 from redis.lock import Lock as RedisLock
 from sqlalchemy.orm import Session
 
+from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT
 from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import OnyxCeleryPriority
 from onyx.configs.constants import OnyxCeleryQueues
@@ -65,12 +66,20 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
     def generate_tasks(
         self,
+        max_tasks: int,
         celery_app: Celery,
         db_session: Session,
         redis_client: Redis,
         lock: RedisLock,
         tenant_id: str | None,
     ) -> tuple[int, int] | None:
+        """We can limit the number of tasks generated here, which is useful to prevent
+        one tenant from overwhelming the sync queue.
+
+        This works because the dirty state of a document is in the DB, so more docs
+        get picked up after the limited set of tasks is complete.
+        """
         last_lock_time = time.monotonic()
 
         async_results = []
@@ -84,7 +93,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
 
         num_docs = 0
 
-        for doc in db_session.scalars(stmt).yield_per(1):
+        for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
             doc = cast(Document, doc)
             current_time = time.monotonic()
             if current_time - last_lock_time >= (
@@ -132,4 +141,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
             async_results.append(result)
             self.skip_docs.add(doc.id)
 
+            if len(async_results) >= max_tasks:
+                break
+
         return len(async_results), num_docs
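The yield_per change is worth a note: yield_per(1) fetches rows one at a time, so every document costs a database round trip. A minimal sketch of the batched form under SQLAlchemy 2.0, with a stand-in mapped class (Document here is a placeholder, not the real model):

from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

DB_YIELD_PER_DEFAULT = 64  # mirrors the value added to app_configs above

class Base(DeclarativeBase):
    pass

class Document(Base):  # minimal stand-in for the real mapped class
    __tablename__ = "document"
    id: Mapped[str] = mapped_column(primary_key=True)

def count_documents(db_session: Session) -> int:
    # stream results in batches of 64: memory stays bounded as with
    # yield_per(1), but round trips drop by roughly the batch size
    num_docs = 0
    stmt = select(Document)
    for _doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
        num_docs += 1
    return num_docs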

View File

@@ -9,6 +9,7 @@ from pydantic import BaseModel
 from redis.lock import Lock as RedisLock
 from sqlalchemy.orm import Session
 
+from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT
 from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import OnyxCeleryPriority
 from onyx.configs.constants import OnyxCeleryQueues
@@ -98,7 +99,7 @@ class RedisConnectorDelete:
         stmt = construct_document_select_for_connector_credential_pair(
             cc_pair.connector_id, cc_pair.credential_id
         )
-        for doc_temp in db_session.scalars(stmt).yield_per(1):
+        for doc_temp in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
             doc: DbDocument = doc_temp
 
             current_time = time.monotonic()
             if current_time - last_lock_time >= (

View File

@@ -13,6 +13,7 @@ from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import OnyxCeleryPriority
 from onyx.configs.constants import OnyxCeleryQueues
 from onyx.configs.constants import OnyxCeleryTask
+from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
 
 
 class RedisConnectorPermissionSyncPayload(BaseModel):
@@ -68,7 +69,10 @@ class RedisConnectorPermissionSync:
     def get_active_task_count(self) -> int:
         """Count of active permission sync tasks"""
         count = 0
-        for _ in self.redis.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
+        for _ in self.redis.scan_iter(
+            RedisConnectorPermissionSync.FENCE_PREFIX + "*",
+            count=SCAN_ITER_COUNT_DEFAULT,
+        ):
             count += 1
         return count

View File

@@ -7,6 +7,8 @@ from pydantic import BaseModel
 from redis.lock import Lock as RedisLock
 from sqlalchemy.orm import Session
 
+from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
+
 
 class RedisConnectorExternalGroupSyncPayload(BaseModel):
     started: datetime | None
@@ -63,7 +65,8 @@ class RedisConnectorExternalGroupSync:
         """Count of active external group syncing tasks"""
         count = 0
         for _ in self.redis.scan_iter(
-            RedisConnectorExternalGroupSync.FENCE_PREFIX + "*"
+            RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
+            count=SCAN_ITER_COUNT_DEFAULT,
         ):
             count += 1
         return count

View File

@@ -12,6 +12,7 @@ from onyx.configs.constants import OnyxCeleryPriority
 from onyx.configs.constants import OnyxCeleryQueues
 from onyx.configs.constants import OnyxCeleryTask
 from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
+from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
 
 
 class RedisConnectorPrune:
@@ -63,7 +64,9 @@ class RedisConnectorPrune:
     def get_active_task_count(self) -> int:
         """Count of active pruning tasks"""
         count = 0
-        for key in self.redis.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
+        for key in self.redis.scan_iter(
+            RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
+        ):
             count += 1
         return count
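SCAN's COUNT argument is only a hint for how many keyspace slots each underlying SCAN call inspects; redis-py leaves it at the server default of 10 unless told otherwise. With a sparse match pattern over a large keyspace, that means thousands of round trips to find a handful of keys, which is what the larger value addresses. A sketch (the key prefix is a placeholder):

from redis import Redis

SCAN_ITER_COUNT_DEFAULT = 4096  # same value the commit adds to redis_pool below

def count_fences(r: Redis, prefix: str) -> int:
    count = 0
    # the count= hint makes each underlying SCAN call examine ~4096 slots,
    # collapsing thousands of small round trips into a few larger ones
    for _key in r.scan_iter(prefix + "*", count=SCAN_ITER_COUNT_DEFAULT):
        count += 1
    return count

# usage sketch:
# r = Redis()
# count_fences(r, "connectorprune_fence_")  # placeholder prefix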

View File

@@ -8,6 +8,7 @@ from redis import Redis
 from redis.lock import Lock as RedisLock
 from sqlalchemy.orm import Session
 
+from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT
 from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import OnyxCeleryPriority
 from onyx.configs.constants import OnyxCeleryQueues
@@ -50,17 +51,21 @@ class RedisDocumentSet(RedisObjectHelper):
     def generate_tasks(
         self,
+        max_tasks: int,
         celery_app: Celery,
         db_session: Session,
         redis_client: Redis,
         lock: RedisLock,
         tenant_id: str | None,
     ) -> tuple[int, int] | None:
+        """Max tasks is ignored for now until we can build the logic to mark the
+        document set up to date over multiple batches.
+        """
         last_lock_time = time.monotonic()
 
         async_results = []
 
         stmt = construct_document_select_by_docset(int(self._id), current_only=False)
-        for doc in db_session.scalars(stmt).yield_per(1):
+        for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
             doc = cast(Document, doc)
             current_time = time.monotonic()
             if current_time - last_lock_time >= (

View File

@@ -82,6 +82,7 @@ class RedisObjectHelper(ABC):
     @abstractmethod
     def generate_tasks(
         self,
+        max_tasks: int,
         celery_app: Celery,
         db_session: Session,
         redis_client: Redis,

View File

@@ -29,6 +29,8 @@ from onyx.utils.logger import setup_logger
 
 logger = setup_logger()
 
+SCAN_ITER_COUNT_DEFAULT = 4096
+
 
 class TenantRedis(redis.Redis):
     def __init__(self, tenant_id: str, *args: Any, **kwargs: Any) -> None:
@@ -116,6 +118,7 @@ class TenantRedis(redis.Redis):
             "hexists",
             "hset",
             "hdel",
+            "ttl",
         ]  # Regular methods that need simple prefixing
 
         if item == "scan_iter":

View File

@@ -8,6 +8,7 @@ from redis import Redis
 from redis.lock import Lock as RedisLock
 from sqlalchemy.orm import Session
 
+from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT
 from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import OnyxCeleryPriority
 from onyx.configs.constants import OnyxCeleryQueues
@@ -51,12 +52,16 @@ class RedisUserGroup(RedisObjectHelper):
     def generate_tasks(
         self,
+        max_tasks: int,
         celery_app: Celery,
         db_session: Session,
         redis_client: Redis,
         lock: RedisLock,
         tenant_id: str | None,
     ) -> tuple[int, int] | None:
+        """Max tasks is ignored for now until we can build the logic to mark the
+        user group up to date over multiple batches.
+        """
         last_lock_time = time.monotonic()
 
         async_results = []
@@ -73,7 +78,7 @@ class RedisUserGroup(RedisObjectHelper):
             return 0, 0
 
         stmt = construct_document_select_by_usergroup(int(self._id))
-        for doc in db_session.scalars(stmt).yield_per(1):
+        for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
             doc = cast(Document, doc)
             current_time = time.monotonic()
             if current_time - last_lock_time >= (