diff --git a/backend/ee/onyx/external_permissions/confluence/doc_sync.py b/backend/ee/onyx/external_permissions/confluence/doc_sync.py index 9805cdad6ee..d890d81c74c 100644 --- a/backend/ee/onyx/external_permissions/confluence/doc_sync.py +++ b/backend/ee/onyx/external_permissions/confluence/doc_sync.py @@ -13,6 +13,7 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence from onyx.connectors.confluence.utils import get_user_email_from_username__server from onyx.connectors.models import SlimDocument from onyx.db.models import ConnectorCredentialPair +from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger logger = setup_logger() @@ -257,6 +258,7 @@ def _fetch_all_page_restrictions( slim_docs: list[SlimDocument], space_permissions_by_space_key: dict[str, ExternalAccess], is_cloud: bool, + callback: IndexingHeartbeatInterface | None, ) -> list[DocExternalAccess]: """ For all pages, if a page has restrictions, then use those restrictions. @@ -265,6 +267,12 @@ def _fetch_all_page_restrictions( document_restrictions: list[DocExternalAccess] = [] for slim_doc in slim_docs: + if callback: + if callback.should_stop(): + raise RuntimeError("confluence_doc_sync: Stop signal detected") + + callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1) + if slim_doc.perm_sync_data is None: raise ValueError( f"No permission sync data found for document {slim_doc.id}" @@ -334,7 +342,7 @@ def _fetch_all_page_restrictions( def confluence_doc_sync( - cc_pair: ConnectorCredentialPair, + cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None ) -> list[DocExternalAccess]: """ Adds the external permissions to the documents in postgres @@ -359,6 +367,12 @@ def confluence_doc_sync( logger.debug("Fetching all slim documents from confluence") for doc_batch in confluence_connector.retrieve_all_slim_documents(): logger.debug(f"Got {len(doc_batch)} slim documents from confluence") + if callback: + if callback.should_stop(): + raise RuntimeError("confluence_doc_sync: Stop signal detected") + + callback.progress("confluence_doc_sync", 1) + slim_docs.extend(doc_batch) logger.debug("Fetching all page restrictions for space") @@ -367,4 +381,5 @@ def confluence_doc_sync( slim_docs=slim_docs, space_permissions_by_space_key=space_permissions_by_space_key, is_cloud=is_cloud, + callback=callback, ) diff --git a/backend/ee/onyx/external_permissions/gmail/doc_sync.py b/backend/ee/onyx/external_permissions/gmail/doc_sync.py index 5860f401815..e93189f81d8 100644 --- a/backend/ee/onyx/external_permissions/gmail/doc_sync.py +++ b/backend/ee/onyx/external_permissions/gmail/doc_sync.py @@ -6,6 +6,7 @@ from onyx.access.models import ExternalAccess from onyx.connectors.gmail.connector import GmailConnector from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.db.models import ConnectorCredentialPair +from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger logger = setup_logger() @@ -28,7 +29,7 @@ def _get_slim_doc_generator( def gmail_doc_sync( - cc_pair: ConnectorCredentialPair, + cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None ) -> list[DocExternalAccess]: """ Adds the external permissions to the documents in postgres @@ -44,6 +45,12 @@ def gmail_doc_sync( document_external_access: list[DocExternalAccess] = [] for slim_doc_batch in slim_doc_generator: for slim_doc in slim_doc_batch: + if callback: + if callback.should_stop(): + raise RuntimeError("gmail_doc_sync: Stop signal detected") + + callback.progress("gmail_doc_sync", 1) + if slim_doc.perm_sync_data is None: logger.warning(f"No permissions found for document {slim_doc.id}") continue diff --git a/backend/ee/onyx/external_permissions/google_drive/doc_sync.py b/backend/ee/onyx/external_permissions/google_drive/doc_sync.py index 098ede40e26..8347d42326d 100644 --- a/backend/ee/onyx/external_permissions/google_drive/doc_sync.py +++ b/backend/ee/onyx/external_permissions/google_drive/doc_sync.py @@ -10,6 +10,7 @@ from onyx.connectors.google_utils.resources import get_drive_service from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.models import SlimDocument from onyx.db.models import ConnectorCredentialPair +from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger logger = setup_logger() @@ -128,7 +129,7 @@ def _get_permissions_from_slim_doc( def gdrive_doc_sync( - cc_pair: ConnectorCredentialPair, + cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None ) -> list[DocExternalAccess]: """ Adds the external permissions to the documents in postgres @@ -146,6 +147,12 @@ def gdrive_doc_sync( document_external_accesses = [] for slim_doc_batch in slim_doc_generator: for slim_doc in slim_doc_batch: + if callback: + if callback.should_stop(): + raise RuntimeError("gdrive_doc_sync: Stop signal detected") + + callback.progress("gdrive_doc_sync", 1) + ext_access = _get_permissions_from_slim_doc( google_drive_connector=google_drive_connector, slim_doc=slim_doc, diff --git a/backend/ee/onyx/external_permissions/slack/doc_sync.py b/backend/ee/onyx/external_permissions/slack/doc_sync.py index ff1e237e337..63029467172 100644 --- a/backend/ee/onyx/external_permissions/slack/doc_sync.py +++ b/backend/ee/onyx/external_permissions/slack/doc_sync.py @@ -7,6 +7,7 @@ from onyx.connectors.slack.connector import get_channels from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries from onyx.connectors.slack.connector import SlackPollConnector from onyx.db.models import ConnectorCredentialPair +from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger @@ -14,7 +15,7 @@ logger = setup_logger() def _get_slack_document_ids_and_channels( - cc_pair: ConnectorCredentialPair, + cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None ) -> dict[str, list[str]]: slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config) slack_connector.load_credentials(cc_pair.credential.credential_json) @@ -24,6 +25,14 @@ def _get_slack_document_ids_and_channels( channel_doc_map: dict[str, list[str]] = {} for doc_metadata_batch in slim_doc_generator: for doc_metadata in doc_metadata_batch: + if callback: + if callback.should_stop(): + raise RuntimeError( + "_get_slack_document_ids_and_channels: Stop signal detected" + ) + + callback.progress("_get_slack_document_ids_and_channels", 1) + if doc_metadata.perm_sync_data is None: continue channel_id = doc_metadata.perm_sync_data["channel_id"] @@ -114,7 +123,7 @@ def _fetch_channel_permissions( def slack_doc_sync( - cc_pair: ConnectorCredentialPair, + cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None ) -> list[DocExternalAccess]: """ Adds the external permissions to the documents in postgres @@ -127,7 +136,7 @@ def slack_doc_sync( ) user_id_to_email_map = fetch_user_id_to_email_map(slack_client) channel_doc_map = _get_slack_document_ids_and_channels( - cc_pair=cc_pair, + cc_pair=cc_pair, callback=callback ) workspace_permissions = _fetch_workspace_permissions( user_id_to_email_map=user_id_to_email_map, diff --git a/backend/ee/onyx/external_permissions/sync_params.py b/backend/ee/onyx/external_permissions/sync_params.py index 1669dee6a05..8be6dcb2c0d 100644 --- a/backend/ee/onyx/external_permissions/sync_params.py +++ b/backend/ee/onyx/external_permissions/sync_params.py @@ -15,11 +15,13 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync from onyx.access.models import DocExternalAccess from onyx.configs.constants import DocumentSource from onyx.db.models import ConnectorCredentialPair +from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface # Defining the input/output types for the sync functions DocSyncFuncType = Callable[ [ ConnectorCredentialPair, + IndexingHeartbeatInterface | None, ], list[DocExternalAccess], ] diff --git a/backend/onyx/background/celery/apps/app_base.py b/backend/onyx/background/celery/apps/app_base.py index c0661d2f104..d291996c1e6 100644 --- a/backend/onyx/background/celery/apps/app_base.py +++ b/backend/onyx/background/celery/apps/app_base.py @@ -198,7 +198,8 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None: def wait_for_redis(sender: Any, **kwargs: Any) -> None: """Waits for redis to become ready subject to a hardcoded timeout. - Will raise WorkerShutdown to kill the celery worker if the timeout is reached.""" + Will raise WorkerShutdown to kill the celery worker if the timeout + is reached.""" r = get_redis_client(tenant_id=None) diff --git a/backend/onyx/background/celery/celery_redis.py b/backend/onyx/background/celery/celery_redis.py index 213388ac7c4..717af036675 100644 --- a/backend/onyx/background/celery/celery_redis.py +++ b/backend/onyx/background/celery/celery_redis.py @@ -91,6 +91,28 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int: return False +def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]: + """This is a redis specific way to build a list of tasks in a queue. + + This helps us read the queue once and then efficiently look for missing tasks + in the queue. + """ + + task_set: set[str] = set() + + for priority in range(len(OnyxCeleryPriority)): + queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue + + tasks = cast(list[bytes], r.lrange(queue_name, 0, -1)) + for task in tasks: + task_dict: dict[str, Any] = json.loads(task.decode("utf-8")) + task_id = task_dict.get("headers", {}).get("id") + if task_id: + task_set.add(task_id) + + return task_set + + def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]: """Returns a list of current workers containing name_filter, or all workers if name_filter is None. diff --git a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py index 81887dd9df5..1791e5585b3 100644 --- a/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/doc_permission_syncing/tasks.py @@ -3,13 +3,16 @@ from datetime import datetime from datetime import timedelta from datetime import timezone from time import sleep +from typing import cast from uuid import uuid4 from celery import Celery from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded +from pydantic import ValidationError from redis import Redis +from redis.exceptions import LockError from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session @@ -22,6 +25,10 @@ from ee.onyx.external_permissions.sync_params import ( ) from onyx.access.models import DocExternalAccess from onyx.background.celery.apps.app_base import task_logger +from onyx.background.celery.celery_redis import celery_find_task +from onyx.background.celery.celery_redis import celery_get_queue_length +from onyx.background.celery.celery_redis import celery_get_queued_task_ids +from onyx.background.celery.celery_redis import celery_get_unacked_task_ids from onyx.configs.app_configs import JOB_TIMEOUT from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT @@ -32,6 +39,7 @@ from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisLocks +from onyx.configs.constants import OnyxRedisSignals from onyx.db.connector import mark_cc_pair_as_permissions_synced from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.document import upsert_document_by_connector_credential_pair @@ -44,14 +52,19 @@ from onyx.db.models import ConnectorCredentialPair from onyx.db.sync_record import insert_sync_record from onyx.db.sync_record import update_sync_record_status from onyx.db.users import batch_add_ext_perm_user_if_not_exists +from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.redis.redis_connector import RedisConnector -from onyx.redis.redis_connector_doc_perm_sync import ( - RedisConnectorPermissionSyncPayload, -) +from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync +from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyncPayload from onyx.redis.redis_pool import get_redis_client +from onyx.redis.redis_pool import redis_lock_dump +from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT +from onyx.server.utils import make_short_id from onyx.utils.logger import doc_permission_sync_ctx +from onyx.utils.logger import LoggerContextVars from onyx.utils.logger import setup_logger + logger = setup_logger() @@ -105,7 +118,12 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b bind=True, ) def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None: + # TODO(rkuo): merge into check function after lookup table for fences is added + + # we need to use celery's redis client to access its redis data + # (which lives on a different db number) r = get_redis_client(tenant_id=tenant_id) + r_celery: Redis = self.app.broker_connection().channel().client # type: ignore lock_beat: RedisLock = r.lock( OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK, @@ -126,14 +144,32 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool if _is_external_doc_permissions_sync_due(cc_pair): cc_pair_ids_to_sync.append(cc_pair.id) + lock_beat.reacquire() for cc_pair_id in cc_pair_ids_to_sync: - tasks_created = try_creating_permissions_sync_task( + payload_id = try_creating_permissions_sync_task( self.app, cc_pair_id, r, tenant_id ) - if not tasks_created: + if not payload_id: continue - task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}") + task_logger.info( + f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}" + ) + + # we want to run this less frequently than the overall task + lock_beat.reacquire() + if not r.exists(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES): + # clear any permission fences that don't have associated celery tasks in progress + # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker), + # or be currently executing + try: + validate_permission_sync_fences(tenant_id, r, r_celery, lock_beat) + except Exception: + task_logger.exception( + "Exception while validating permission sync fences" + ) + + r.set(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=60) except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." @@ -152,13 +188,15 @@ def try_creating_permissions_sync_task( cc_pair_id: int, r: Redis, tenant_id: str | None, -) -> int | None: - """Returns an int if syncing is needed. The int represents the number of sync tasks generated. +) -> str | None: + """Returns a randomized payload id on success. Returns None if no syncing is required.""" - redis_connector = RedisConnector(tenant_id, cc_pair_id) - LOCK_TIMEOUT = 30 + payload_id: str | None = None + + redis_connector = RedisConnector(tenant_id, cc_pair_id) + lock: RedisLock = r.lock( DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks", timeout=LOCK_TIMEOUT, @@ -193,7 +231,13 @@ def try_creating_permissions_sync_task( ) # set a basic fence to start - payload = RedisConnectorPermissionSyncPayload(started=None, celery_task_id=None) + redis_connector.permissions.set_active() + payload = RedisConnectorPermissionSyncPayload( + id=make_short_id(), + submitted=datetime.now(timezone.utc), + started=None, + celery_task_id=None, + ) redis_connector.permissions.set_fence(payload) result = app.send_task( @@ -208,8 +252,11 @@ def try_creating_permissions_sync_task( ) # fill in the celery task id + redis_connector.permissions.set_active() payload.celery_task_id = result.id redis_connector.permissions.set_fence(payload) + + payload_id = payload.celery_task_id except Exception: task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}") return None @@ -217,7 +264,7 @@ def try_creating_permissions_sync_task( if lock.owned(): lock.release() - return 1 + return payload_id @shared_task( @@ -238,6 +285,8 @@ def connector_permission_sync_generator_task( This task assumes that the task has already been properly fenced """ + LoggerContextVars.reset() + doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get() doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id doc_permission_sync_ctx_dict["request_id"] = self.request.id @@ -325,12 +374,17 @@ def connector_permission_sync_generator_task( raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}") new_payload = RedisConnectorPermissionSyncPayload( + id=payload.id, + submitted=payload.submitted, started=datetime.now(timezone.utc), celery_task_id=payload.celery_task_id, ) redis_connector.permissions.set_fence(new_payload) - document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair) + callback = PermissionSyncCallback(redis_connector, lock, r) + document_external_accesses: list[DocExternalAccess] = doc_sync_func( + cc_pair, callback + ) task_logger.info( f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}" @@ -380,6 +434,8 @@ def update_external_document_permissions_task( connector_id: int, credential_id: int, ) -> bool: + start = time.monotonic() + document_external_access = DocExternalAccess.from_dict( serialized_doc_external_access ) @@ -409,16 +465,268 @@ def update_external_document_permissions_task( document_ids=[doc_id], ) - logger.debug( - f"Successfully synced postgres document permissions for {doc_id}" + elapsed = time.monotonic() - start + task_logger.info( + f"connector_id={connector_id} " + f"doc={doc_id} " + f"action=update_permissions " + f"elapsed={elapsed:.2f}" ) - return True except Exception: - logger.exception( - f"Error Syncing Document Permissions: connector_id={connector_id} doc_id={doc_id}" + task_logger.exception( + f"Exception in update_external_document_permissions_task: " + f"connector_id={connector_id} " + f"doc_id={doc_id}" ) return False + return True + + +def validate_permission_sync_fences( + tenant_id: str | None, + r: Redis, + r_celery: Redis, + lock_beat: RedisLock, +) -> None: + # building lookup table can be expensive, so we won't bother + # validating until the queue is small + PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN = 1024 + + queue_len = celery_get_queue_length( + OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery + ) + if queue_len > PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN: + return + + queued_upsert_tasks = celery_get_queued_task_ids( + OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery + ) + reserved_generator_tasks = celery_get_unacked_task_ids( + OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery + ) + + # validate all existing indexing jobs + for key_bytes in r.scan_iter( + RedisConnectorPermissionSync.FENCE_PREFIX + "*", + count=SCAN_ITER_COUNT_DEFAULT, + ): + lock_beat.reacquire() + validate_permission_sync_fence( + tenant_id, + key_bytes, + queued_upsert_tasks, + reserved_generator_tasks, + r, + r_celery, + ) + return + + +def validate_permission_sync_fence( + tenant_id: str | None, + key_bytes: bytes, + queued_tasks: set[str], + reserved_tasks: set[str], + r: Redis, + r_celery: Redis, +) -> None: + """Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist. + This can happen if the indexing worker hard crashes or is terminated. + Being in this bad state means the fence will never clear without help, so this function + gives the help. + + How this works: + 1. This function renews the active signal with a 5 minute TTL under the following conditions + 1.2. When the task is seen in the redis queue + 1.3. When the task is seen in the reserved / prefetched list + + 2. Externally, the active signal is renewed when: + 2.1. The fence is created + 2.2. The indexing watchdog checks the spawned task. + + 3. The TTL allows us to get through the transitions on fence startup + and when the task starts executing. + + More TTL clarification: it is seemingly impossible to exactly query Celery for + whether a task is in the queue or currently executing. + 1. An unknown task id is always returned as state PENDING. + 2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task + and the time it actually starts on the worker. + + queued_tasks: the celery queue of lightweight permission sync tasks + reserved_tasks: prefetched tasks for sync task generator + """ + # if the fence doesn't exist, there's nothing to do + fence_key = key_bytes.decode("utf-8") + cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key) + if cc_pair_id_str is None: + task_logger.warning( + f"validate_permission_sync_fence - could not parse id from {fence_key}" + ) + return + + cc_pair_id = int(cc_pair_id_str) + # parse out metadata and initialize the helper class with it + redis_connector = RedisConnector(tenant_id, int(cc_pair_id)) + + # check to see if the fence/payload exists + if not redis_connector.permissions.fenced: + return + + # in the cloud, the payload format may have changed ... + # it's a little sloppy, but just reset the fence for now if that happens + # TODO: add intentional cleanup/abort logic + try: + payload = redis_connector.permissions.payload + except ValidationError: + task_logger.exception( + "validate_permission_sync_fence - " + "Resetting fence because fence schema is out of date: " + f"cc_pair={cc_pair_id} " + f"fence={fence_key}" + ) + + redis_connector.permissions.reset() + return + + if not payload: + return + + if not payload.celery_task_id: + return + + # OK, there's actually something for us to validate + + # either the generator task must be in flight or its subtasks must be + found = celery_find_task( + payload.celery_task_id, + OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, + r_celery, + ) + if found: + # the celery task exists in the redis queue + redis_connector.permissions.set_active() + return + + if payload.celery_task_id in reserved_tasks: + # the celery task was prefetched and is reserved within a worker + redis_connector.permissions.set_active() + return + + # look up every task in the current taskset in the celery queue + # every entry in the taskset should have an associated entry in the celery task queue + # because we get the celery tasks first, the entries in our own permissions taskset + # should be roughly a subset of the tasks in celery + + # this check isn't very exact, but should be sufficient over a period of time + # A single successful check over some number of attempts is sufficient. + + # TODO: if the number of tasks in celery is much lower than than the taskset length + # we might be able to shortcut the lookup since by definition some of the tasks + # must not exist in celery. + + tasks_scanned = 0 + tasks_not_in_celery = 0 # a non-zero number after completing our check is bad + + for member in r.sscan_iter(redis_connector.permissions.taskset_key): + tasks_scanned += 1 + + member_bytes = cast(bytes, member) + member_str = member_bytes.decode("utf-8") + if member_str in queued_tasks: + continue + + if member_str in reserved_tasks: + continue + + tasks_not_in_celery += 1 + + task_logger.info( + "validate_permission_sync_fence task check: " + f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}" + ) + + if tasks_not_in_celery == 0: + redis_connector.permissions.set_active() + return + + # we may want to enable this check if using the active task list somehow isn't good enough + # if redis_connector_index.generator_locked(): + # logger.info(f"{payload.celery_task_id} is currently executing.") + + # if we get here, we didn't find any direct indication that the associated celery tasks exist, + # but they still might be there due to gaps in our ability to check states during transitions + # Checking the active signal safeguards us against these transition periods + # (which has a duration that allows us to bridge those gaps) + if redis_connector.permissions.active(): + return + + # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up. + task_logger.warning( + "validate_permission_sync_fence - " + "Resetting fence because no associated celery tasks were found: " + f"cc_pair={cc_pair_id} " + f"fence={fence_key}" + ) + + redis_connector.permissions.reset() + return + + +class PermissionSyncCallback(IndexingHeartbeatInterface): + PARENT_CHECK_INTERVAL = 60 + + def __init__( + self, + redis_connector: RedisConnector, + redis_lock: RedisLock, + redis_client: Redis, + ): + super().__init__() + self.redis_connector: RedisConnector = redis_connector + self.redis_lock: RedisLock = redis_lock + self.redis_client = redis_client + + self.started: datetime = datetime.now(timezone.utc) + self.redis_lock.reacquire() + + self.last_tag: str = "PermissionSyncCallback.__init__" + self.last_lock_reacquire: datetime = datetime.now(timezone.utc) + self.last_lock_monotonic = time.monotonic() + + def should_stop(self) -> bool: + if self.redis_connector.stop.fenced: + return True + + return False + + def progress(self, tag: str, amount: int) -> None: + try: + self.redis_connector.permissions.set_active() + + current_time = time.monotonic() + if current_time - self.last_lock_monotonic >= ( + CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4 + ): + self.redis_lock.reacquire() + self.last_lock_reacquire = datetime.now(timezone.utc) + self.last_lock_monotonic = time.monotonic() + + self.last_tag = tag + except LockError: + logger.exception( + f"PermissionSyncCallback - lock.reacquire exceptioned: " + f"lock_timeout={self.redis_lock.timeout} " + f"start={self.started} " + f"last_tag={self.last_tag} " + f"last_reacquired={self.last_lock_reacquire} " + f"now={datetime.now(timezone.utc)}" + ) + + redis_lock_dump(self.redis_lock, self.redis_client) + raise + """Monitoring CCPair permissions utils, called in monitor_vespa_sync""" @@ -444,20 +752,36 @@ def monitor_ccpair_permissions_taskset( if initial is None: return + try: + payload = redis_connector.permissions.payload + except ValidationError: + task_logger.exception( + "Permissions sync payload failed to validate. " + "Schema may have been updated." + ) + return + + if not payload: + return + remaining = redis_connector.permissions.get_remaining() task_logger.info( - f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}" + f"Permissions sync progress: " + f"cc_pair={cc_pair_id} " + f"id={payload.id} " + f"remaining={remaining} " + f"initial={initial}" ) if remaining > 0: return - payload: RedisConnectorPermissionSyncPayload | None = ( - redis_connector.permissions.payload + mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), payload.started) + task_logger.info( + f"Permissions sync finished: " + f"cc_pair={cc_pair_id} " + f"id={payload.id} " + f"num_synced={initial}" ) - start_time: datetime | None = payload.started if payload else None - - mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time) - task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}") update_sync_record_status( db_session=db_session, diff --git a/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py b/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py index 42b108d13e9..85465f05a7e 100644 --- a/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py @@ -1,3 +1,4 @@ +import time from datetime import datetime from datetime import timedelta from datetime import timezone @@ -9,6 +10,7 @@ from celery import Task from celery.exceptions import SoftTimeLimitExceeded from redis import Redis from redis.lock import Lock as RedisLock +from sqlalchemy.orm import Session from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source @@ -20,9 +22,12 @@ from ee.onyx.external_permissions.sync_params import ( GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC, ) from onyx.background.celery.apps.app_base import task_logger +from onyx.background.celery.celery_redis import celery_find_task +from onyx.background.celery.celery_redis import celery_get_unacked_task_ids from onyx.configs.app_configs import JOB_TIMEOUT from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT +from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues @@ -39,10 +44,12 @@ from onyx.db.models import ConnectorCredentialPair from onyx.db.sync_record import insert_sync_record from onyx.db.sync_record import update_sync_record_status from onyx.redis.redis_connector import RedisConnector +from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync from onyx.redis.redis_connector_ext_group_sync import ( RedisConnectorExternalGroupSyncPayload, ) from onyx.redis.redis_pool import get_redis_client +from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT from onyx.utils.logger import setup_logger logger = setup_logger() @@ -102,6 +109,10 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool: def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None: r = get_redis_client(tenant_id=tenant_id) + # we need to use celery's redis client to access its redis data + # (which lives on a different db number) + # r_celery: Redis = self.app.broker_connection().channel().client # type: ignore + lock_beat: RedisLock = r.lock( OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK, timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT, @@ -136,6 +147,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool if _is_external_group_sync_due(cc_pair): cc_pair_ids_to_sync.append(cc_pair.id) + lock_beat.reacquire() for cc_pair_id in cc_pair_ids_to_sync: tasks_created = try_creating_external_group_sync_task( self.app, cc_pair_id, r, tenant_id @@ -144,6 +156,23 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool continue task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}") + + # we want to run this less frequently than the overall task + # lock_beat.reacquire() + # if not r.exists(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES): + # # clear any indexing fences that don't have associated celery tasks in progress + # # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker), + # # or be currently executing + # try: + # validate_external_group_sync_fences( + # tenant_id, self.app, r, r_celery, lock_beat + # ) + # except Exception: + # task_logger.exception( + # "Exception while validating external group sync fences" + # ) + + # r.set(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=60) except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." @@ -186,6 +215,12 @@ def try_creating_external_group_sync_task( redis_connector.external_group_sync.generator_clear() redis_connector.external_group_sync.taskset_clear() + payload = RedisConnectorExternalGroupSyncPayload( + submitted=datetime.now(timezone.utc), + started=None, + celery_task_id=None, + ) + custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}" result = app.send_task( @@ -199,11 +234,6 @@ def try_creating_external_group_sync_task( priority=OnyxCeleryPriority.HIGH, ) - payload = RedisConnectorExternalGroupSyncPayload( - started=datetime.now(timezone.utc), - celery_task_id=result.id, - ) - # create before setting fence to avoid race condition where the monitoring # task updates the sync record before it is created with get_session_with_tenant(tenant_id) as db_session: @@ -213,8 +243,8 @@ def try_creating_external_group_sync_task( sync_type=SyncType.EXTERNAL_GROUP, ) + payload.celery_task_id = result.id redis_connector.external_group_sync.set_fence(payload) - except Exception: task_logger.exception( f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}" @@ -241,7 +271,7 @@ def connector_external_group_sync_generator_task( tenant_id: str | None, ) -> None: """ - Permission sync task that handles external group syncing for a given connector credential pair + External group sync task for a given connector credential pair This task assumes that the task has already been properly fenced """ @@ -249,19 +279,59 @@ def connector_external_group_sync_generator_task( r = get_redis_client(tenant_id=tenant_id) + # this wait is needed to avoid a race condition where + # the primary worker sends the task and it is immediately executed + # before the primary worker can finalize the fence + start = time.monotonic() + while True: + if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT: + raise ValueError( + f"connector_external_group_sync_generator_task - timed out waiting for fence to be ready: " + f"fence={redis_connector.external_group_sync.fence_key}" + ) + + if not redis_connector.external_group_sync.fenced: # The fence must exist + raise ValueError( + f"connector_external_group_sync_generator_task - fence not found: " + f"fence={redis_connector.external_group_sync.fence_key}" + ) + + payload = redis_connector.external_group_sync.payload # The payload must exist + if not payload: + raise ValueError( + "connector_external_group_sync_generator_task: payload invalid or not found" + ) + + if payload.celery_task_id is None: + logger.info( + f"connector_external_group_sync_generator_task - Waiting for fence: " + f"fence={redis_connector.external_group_sync.fence_key}" + ) + time.sleep(1) + continue + + logger.info( + f"connector_external_group_sync_generator_task - Fence found, continuing...: " + f"fence={redis_connector.external_group_sync.fence_key}" + ) + break + lock: RedisLock = r.lock( OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX + f"_{redis_connector.id}", timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT, ) + acquired = lock.acquire(blocking=False) + if not acquired: + task_logger.warning( + f"External group sync task already running, exiting...: cc_pair={cc_pair_id}" + ) + return None + try: - acquired = lock.acquire(blocking=False) - if not acquired: - task_logger.warning( - f"External group sync task already running, exiting...: cc_pair={cc_pair_id}" - ) - return None + payload.started = datetime.now(timezone.utc) + redis_connector.external_group_sync.set_fence(payload) with get_session_with_tenant(tenant_id) as db_session: cc_pair = get_connector_credential_pair_from_id( @@ -330,3 +400,135 @@ def connector_external_group_sync_generator_task( redis_connector.external_group_sync.set_fence(None) if lock.owned(): lock.release() + + +def validate_external_group_sync_fences( + tenant_id: str | None, + celery_app: Celery, + r: Redis, + r_celery: Redis, + lock_beat: RedisLock, +) -> None: + reserved_sync_tasks = celery_get_unacked_task_ids( + OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery + ) + + # validate all existing indexing jobs + for key_bytes in r.scan_iter( + RedisConnectorExternalGroupSync.FENCE_PREFIX + "*", + count=SCAN_ITER_COUNT_DEFAULT, + ): + lock_beat.reacquire() + with get_session_with_tenant(tenant_id) as db_session: + validate_external_group_sync_fence( + tenant_id, + key_bytes, + reserved_sync_tasks, + r_celery, + db_session, + ) + return + + +def validate_external_group_sync_fence( + tenant_id: str | None, + key_bytes: bytes, + reserved_tasks: set[str], + r_celery: Redis, + db_session: Session, +) -> None: + """Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist. + This can happen if the indexing worker hard crashes or is terminated. + Being in this bad state means the fence will never clear without help, so this function + gives the help. + + How this works: + 1. This function renews the active signal with a 5 minute TTL under the following conditions + 1.2. When the task is seen in the redis queue + 1.3. When the task is seen in the reserved / prefetched list + + 2. Externally, the active signal is renewed when: + 2.1. The fence is created + 2.2. The indexing watchdog checks the spawned task. + + 3. The TTL allows us to get through the transitions on fence startup + and when the task starts executing. + + More TTL clarification: it is seemingly impossible to exactly query Celery for + whether a task is in the queue or currently executing. + 1. An unknown task id is always returned as state PENDING. + 2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task + and the time it actually starts on the worker. + """ + # if the fence doesn't exist, there's nothing to do + fence_key = key_bytes.decode("utf-8") + cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key) + if cc_pair_id_str is None: + task_logger.warning( + f"validate_external_group_sync_fence - could not parse id from {fence_key}" + ) + return + + cc_pair_id = int(cc_pair_id_str) + + # parse out metadata and initialize the helper class with it + redis_connector = RedisConnector(tenant_id, int(cc_pair_id)) + + # check to see if the fence/payload exists + if not redis_connector.external_group_sync.fenced: + return + + payload = redis_connector.external_group_sync.payload + if not payload: + return + + # OK, there's actually something for us to validate + + if payload.celery_task_id is None: + # the fence is just barely set up. + # if redis_connector_index.active(): + # return + + # it would be odd to get here as there isn't that much that can go wrong during + # initial fence setup, but it's still worth making sure we can recover + logger.info( + "validate_external_group_sync_fence - " + f"Resetting fence in basic state without any activity: fence={fence_key}" + ) + redis_connector.external_group_sync.reset() + return + + found = celery_find_task( + payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery + ) + if found: + # the celery task exists in the redis queue + # redis_connector_index.set_active() + return + + if payload.celery_task_id in reserved_tasks: + # the celery task was prefetched and is reserved within the indexing worker + # redis_connector_index.set_active() + return + + # we may want to enable this check if using the active task list somehow isn't good enough + # if redis_connector_index.generator_locked(): + # logger.info(f"{payload.celery_task_id} is currently executing.") + + # if we get here, we didn't find any direct indication that the associated celery tasks exist, + # but they still might be there due to gaps in our ability to check states during transitions + # Checking the active signal safeguards us against these transition periods + # (which has a duration that allows us to bridge those gaps) + # if redis_connector_index.active(): + # return + + # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up. + logger.warning( + "validate_external_group_sync_fence - " + "Resetting fence because no associated celery tasks were found: " + f"cc_pair={cc_pair_id} " + f"fence={fence_key}" + ) + + redis_connector.external_group_sync.reset() + return diff --git a/backend/onyx/background/celery/tasks/pruning/tasks.py b/backend/onyx/background/celery/tasks/pruning/tasks.py index c0e51e393b4..99a37ddd017 100644 --- a/backend/onyx/background/celery/tasks/pruning/tasks.py +++ b/backend/onyx/background/celery/tasks/pruning/tasks.py @@ -39,6 +39,7 @@ from onyx.db.sync_record import insert_sync_record from onyx.db.sync_record import update_sync_record_status from onyx.redis.redis_connector import RedisConnector from onyx.redis.redis_pool import get_redis_client +from onyx.utils.logger import LoggerContextVars from onyx.utils.logger import pruning_ctx from onyx.utils.logger import setup_logger @@ -251,6 +252,8 @@ def connector_pruning_generator_task( and compares those IDs to locally stored documents and deletes all locally stored IDs missing from the most recently pulled document ID list""" + LoggerContextVars.reset() + pruning_ctx_dict = pruning_ctx.get() pruning_ctx_dict["cc_pair_id"] = cc_pair_id pruning_ctx_dict["request_id"] = self.request.id @@ -399,7 +402,7 @@ def monitor_ccpair_pruning_taskset( mark_ccpair_as_pruned(int(cc_pair_id), db_session) task_logger.info( - f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}" + f"Connector pruning finished: cc_pair={cc_pair_id} num_pruned={initial}" ) update_sync_record_status( diff --git a/backend/onyx/background/celery/tasks/shared/tasks.py b/backend/onyx/background/celery/tasks/shared/tasks.py index 48a21bc978d..5530d9eebca 100644 --- a/backend/onyx/background/celery/tasks/shared/tasks.py +++ b/backend/onyx/background/celery/tasks/shared/tasks.py @@ -75,6 +75,8 @@ def document_by_cc_pair_cleanup_task( """ task_logger.debug(f"Task start: doc={document_id}") + start = time.monotonic() + try: with get_session_with_tenant(tenant_id) as db_session: action = "skip" @@ -154,11 +156,13 @@ def document_by_cc_pair_cleanup_task( db_session.commit() + elapsed = time.monotonic() - start task_logger.info( f"doc={document_id} " f"action={action} " f"refcount={count} " - f"chunks={chunks_affected}" + f"chunks={chunks_affected} " + f"elapsed={elapsed:.2f}" ) except SoftTimeLimitExceeded: task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}") diff --git a/backend/onyx/background/celery/tasks/vespa/tasks.py b/backend/onyx/background/celery/tasks/vespa/tasks.py index 7152903bf07..c695f92a8bf 100644 --- a/backend/onyx/background/celery/tasks/vespa/tasks.py +++ b/backend/onyx/background/celery/tasks/vespa/tasks.py @@ -989,6 +989,10 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." ) + return False + except Exception: + task_logger.exception("monitor_vespa_sync exceptioned.") + return False finally: if lock_beat.owned(): lock_beat.release() @@ -1078,6 +1082,7 @@ def vespa_metadata_sync_task( ) except SoftTimeLimitExceeded: task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}") + return False except Exception as ex: if isinstance(ex, RetryError): task_logger.warning( diff --git a/backend/onyx/configs/constants.py b/backend/onyx/configs/constants.py index 09a90a9ee20..e18a5ee3e7a 100644 --- a/backend/onyx/configs/constants.py +++ b/backend/onyx/configs/constants.py @@ -300,6 +300,8 @@ class OnyxRedisLocks: class OnyxRedisSignals: VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences" + VALIDATE_EXTERNAL_GROUP_SYNC_FENCES = "signal:validate_external_group_sync_fences" + VALIDATE_PERMISSION_SYNC_FENCES = "signal:validate_permission_sync_fences" class OnyxCeleryPriority(int, Enum): diff --git a/backend/onyx/redis/redis_connector_doc_perm_sync.py b/backend/onyx/redis/redis_connector_doc_perm_sync.py index 99f891e14fd..7e587362056 100644 --- a/backend/onyx/redis/redis_connector_doc_perm_sync.py +++ b/backend/onyx/redis/redis_connector_doc_perm_sync.py @@ -17,6 +17,8 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT class RedisConnectorPermissionSyncPayload(BaseModel): + id: str + submitted: datetime started: datetime | None celery_task_id: str | None @@ -41,6 +43,12 @@ class RedisConnectorPermissionSync: TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpermissions_taskset SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpermissions+sub + # used to signal the overall workflow is still active + # it's impossible to get the exact state of the system at a single point in time + # so we need a signal with a TTL to bridge gaps in our checks + ACTIVE_PREFIX = PREFIX + "_active" + ACTIVE_TTL = 3600 + def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None: self.tenant_id: str | None = tenant_id self.id = id @@ -54,6 +62,7 @@ class RedisConnectorPermissionSync: self.taskset_key = f"{self.TASKSET_PREFIX}_{id}" self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}" + self.active_key = f"{self.ACTIVE_PREFIX}_{id}" def taskset_clear(self) -> None: self.redis.delete(self.taskset_key) @@ -107,6 +116,20 @@ class RedisConnectorPermissionSync: self.redis.set(self.fence_key, payload.model_dump_json()) + def set_active(self) -> None: + """This sets a signal to keep the permissioning flow from getting cleaned up within + the expiration time. + + The slack in timing is needed to avoid race conditions where simply checking + the celery queue and task status could result in race conditions.""" + self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL) + + def active(self) -> bool: + if self.redis.exists(self.active_key): + return True + + return False + @property def generator_complete(self) -> int | None: """the fence payload is an int representing the starting number of @@ -173,6 +196,7 @@ class RedisConnectorPermissionSync: return len(async_results) def reset(self) -> None: + self.redis.delete(self.active_key) self.redis.delete(self.generator_progress_key) self.redis.delete(self.generator_complete_key) self.redis.delete(self.taskset_key) @@ -187,6 +211,9 @@ class RedisConnectorPermissionSync: @staticmethod def reset_all(r: redis.Redis) -> None: """Deletes all redis values for all connectors""" + for key in r.scan_iter(RedisConnectorPermissionSync.ACTIVE_PREFIX + "*"): + r.delete(key) + for key in r.scan_iter(RedisConnectorPermissionSync.TASKSET_PREFIX + "*"): r.delete(key) diff --git a/backend/onyx/redis/redis_connector_ext_group_sync.py b/backend/onyx/redis/redis_connector_ext_group_sync.py index 4d29ab5956a..2f0783dbcc7 100644 --- a/backend/onyx/redis/redis_connector_ext_group_sync.py +++ b/backend/onyx/redis/redis_connector_ext_group_sync.py @@ -11,6 +11,7 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT class RedisConnectorExternalGroupSyncPayload(BaseModel): + submitted: datetime started: datetime | None celery_task_id: str | None @@ -135,6 +136,12 @@ class RedisConnectorExternalGroupSync: ) -> int | None: pass + def reset(self) -> None: + self.redis.delete(self.generator_progress_key) + self.redis.delete(self.generator_complete_key) + self.redis.delete(self.taskset_key) + self.redis.delete(self.fence_key) + @staticmethod def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None: taskset_key = f"{RedisConnectorExternalGroupSync.TASKSET_PREFIX}_{id}" diff --git a/backend/onyx/redis/redis_connector_index.py b/backend/onyx/redis/redis_connector_index.py index 5b62da7b6ba..215468f352e 100644 --- a/backend/onyx/redis/redis_connector_index.py +++ b/backend/onyx/redis/redis_connector_index.py @@ -33,8 +33,8 @@ class RedisConnectorIndex: TERMINATE_TTL = 600 # used to signal the overall workflow is still active - # there are gaps in time between states where we need some slack - # to correctly transition + # it's impossible to get the exact state of the system at a single point in time + # so we need a signal with a TTL to bridge gaps in our checks ACTIVE_PREFIX = PREFIX + "_active" ACTIVE_TTL = 3600 diff --git a/backend/onyx/redis/redis_pool.py b/backend/onyx/redis/redis_pool.py index 10a2c7655c4..278c21fb821 100644 --- a/backend/onyx/redis/redis_pool.py +++ b/backend/onyx/redis/redis_pool.py @@ -122,7 +122,7 @@ class TenantRedis(redis.Redis): "ttl", ] # Regular methods that need simple prefixing - if item == "scan_iter": + if item == "scan_iter" or item == "sscan_iter": return self._prefix_scan_iter(original_attr) elif item in methods_to_wrap and callable(original_attr): return self._prefix_method(original_attr) diff --git a/backend/onyx/server/documents/cc_pair.py b/backend/onyx/server/documents/cc_pair.py index ced98377ffa..3ba5984df38 100644 --- a/backend/onyx/server/documents/cc_pair.py +++ b/backend/onyx/server/documents/cc_pair.py @@ -422,27 +422,29 @@ def sync_cc_pair( if redis_connector.permissions.fenced: raise HTTPException( status_code=HTTPStatus.CONFLICT, - detail="Doc permissions sync task already in progress.", + detail="Permissions sync task already in progress.", ) logger.info( - f"Doc permissions sync cc_pair={cc_pair_id} " + f"Permissions sync cc_pair={cc_pair_id} " f"connector_id={cc_pair.connector_id} " f"credential_id={cc_pair.credential_id} " f"{cc_pair.connector.name} connector." ) - tasks_created = try_creating_permissions_sync_task( + payload_id = try_creating_permissions_sync_task( primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get() ) - if not tasks_created: + if not payload_id: raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, - detail="Doc permissions sync task creation failed.", + detail="Permissions sync task creation failed.", ) + logger.info(f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}") + return StatusResponse( success=True, - message="Successfully created the doc permissions sync task.", + message="Successfully created the permissions sync task.", ) diff --git a/backend/onyx/server/utils.py b/backend/onyx/server/utils.py index c6da9614309..8dc7a429b87 100644 --- a/backend/onyx/server/utils.py +++ b/backend/onyx/server/utils.py @@ -1,4 +1,6 @@ +import base64 import json +import os from datetime import datetime from typing import Any @@ -66,3 +68,10 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]: ) return masked_creds + + +def make_short_id() -> str: + """Fast way to generate a random 8 character id ... useful for tagging data + to trace it through a flow. This is definitely not guaranteed to be unique and is + targeted at the stated use case.""" + return base64.b32encode(os.urandom(5)).decode("utf-8")[:8] # 5 bytes → 8 chars diff --git a/backend/onyx/utils/logger.py b/backend/onyx/utils/logger.py index d4aa6c251b9..eb649a3bdf6 100644 --- a/backend/onyx/utils/logger.py +++ b/backend/onyx/utils/logger.py @@ -26,6 +26,13 @@ doc_permission_sync_ctx: contextvars.ContextVar[ ] = contextvars.ContextVar("doc_permission_sync_ctx", default=dict()) +class LoggerContextVars: + @staticmethod + def reset() -> None: + pruning_ctx.set(dict()) + doc_permission_sync_ctx.set(dict()) + + class TaskAttemptSingleton: """Used to tell if this process is an indexing job, and if so what is the unique identifier for this indexing attempt. For things like the API server, @@ -70,27 +77,32 @@ class OnyxLoggingAdapter(logging.LoggerAdapter): ) -> tuple[str, MutableMapping[str, Any]]: # If this is an indexing job, add the attempt ID to the log message # This helps filter the logs for this specific indexing - index_attempt_id = TaskAttemptSingleton.get_index_attempt_id() - cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id() + while True: + pruning_ctx_dict = pruning_ctx.get() + if len(pruning_ctx_dict) > 0: + if "request_id" in pruning_ctx_dict: + msg = f"[Prune: {pruning_ctx_dict['request_id']}] {msg}" - doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get() - pruning_ctx_dict = pruning_ctx.get() - if len(pruning_ctx_dict) > 0: - if "request_id" in pruning_ctx_dict: - msg = f"[Prune: {pruning_ctx_dict['request_id']}] {msg}" + if "cc_pair_id" in pruning_ctx_dict: + msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}" + break + + doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get() + if len(doc_permission_sync_ctx_dict) > 0: + if "request_id" in doc_permission_sync_ctx_dict: + msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}" + break + + index_attempt_id = TaskAttemptSingleton.get_index_attempt_id() + cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id() - if "cc_pair_id" in pruning_ctx_dict: - msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}" - elif len(doc_permission_sync_ctx_dict) > 0: - if "request_id" in doc_permission_sync_ctx_dict: - msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}" - else: if index_attempt_id is not None: msg = f"[Index Attempt: {index_attempt_id}] {msg}" if cc_pair_id is not None: msg = f"[CC Pair: {cc_pair_id}] {msg}" + break # Add tenant information if it differs from default # This will always be the case for authenticated API requests if MULTI_TENANT: