import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from time import sleep
from uuid import uuid4

from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.document import upsert_document_external_perms
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
from ee.onyx.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
from ee.onyx.external_permissions.sync_params import (
    DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from onyx.access.models import DocExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_doc_perm_sync import (
    RedisConnectorPermissionSyncPayload,
)
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import setup_logger

logger = setup_logger()

DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3

# 5 seconds more than RetryDocumentIndex STOP_AFTER + MAX_WAIT
LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15


"""Jobs / utils for kicking off doc permissions sync tasks."""


def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
    """Returns boolean indicating if external doc permissions sync is due."""

    if cc_pair.access_type != AccessType.SYNC:
        return False

    # skip doc permissions sync if not active
    if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
        return False

    if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
        return False

    # If the last sync is None, it has never been run so we run the sync
    last_perm_sync = cc_pair.last_time_perm_sync
    if last_perm_sync is None:
        return True

    source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)

    # If DOC_PERMISSION_SYNC_PERIODS has no entry for the source, we always run the sync
    if not source_sync_period:
        return True

    # If the last sync is older than the source's sync period, we run the sync
    next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
    if datetime.now(timezone.utc) >= next_sync:
        return True

    return False
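

# Illustrative walkthrough of the check above (descriptive comment only, not executed),
# assuming a hypothetical source configured with a 3600-second entry in
# DOC_PERMISSION_SYNC_PERIODS:
#
#     last_perm_sync = datetime(2025, 1, 1, 12, 0, tzinfo=timezone.utc)
#     next_sync = last_perm_sync + timedelta(seconds=3600)   # 13:00 UTC
#     datetime.now(timezone.utc) >= next_sync                # True once 13:00 UTC has passed
#
# cc_pairs that have never been permission-synced, and sources with no entry in
# DOC_PERMISSION_SYNC_PERIODS, are always treated as due.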


@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
)
def check_for_doc_permissions_sync(
    self: Task, *, tenant_id: str | None
) -> bool | None:
    r = get_redis_client(tenant_id=tenant_id)

    lock_beat: RedisLock = r.lock(
        OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )

    # these tasks should never overlap
    if not lock_beat.acquire(blocking=False):
        return None

    try:
        # get all cc pairs that need to be synced
        cc_pair_ids_to_sync: list[int] = []
        with get_session_with_tenant(tenant_id) as db_session:
            cc_pairs = get_all_auto_sync_cc_pairs(db_session)

            for cc_pair in cc_pairs:
                if _is_external_doc_permissions_sync_due(cc_pair):
                    cc_pair_ids_to_sync.append(cc_pair.id)

        for cc_pair_id in cc_pair_ids_to_sync:
            tasks_created = try_creating_permissions_sync_task(
                self.app, cc_pair_id, r, tenant_id
            )
            if not tasks_created:
                continue

            task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
    finally:
        if lock_beat.owned():
            lock_beat.release()

    return True


def try_creating_permissions_sync_task(
    app: Celery,
    cc_pair_id: int,
    r: Redis,
    tenant_id: str | None,
) -> int | None:
    """Returns an int if syncing is needed. The int represents the number of
    sync tasks generated. Returns None if no syncing is required."""
    redis_connector = RedisConnector(tenant_id, cc_pair_id)

    LOCK_TIMEOUT = 30

    lock: RedisLock = r.lock(
        DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
        timeout=LOCK_TIMEOUT,
    )

    acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
    if not acquired:
        return None

    try:
        if redis_connector.permissions.fenced:
            return None

        if redis_connector.delete.fenced:
            return None

        if redis_connector.prune.fenced:
            return None

        redis_connector.permissions.generator_clear()
        redis_connector.permissions.taskset_clear()

        custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"

        # create before setting fence to avoid race condition where the monitoring
        # task updates the sync record before it is created
        with get_session_with_tenant(tenant_id) as db_session:
            insert_sync_record(
                db_session=db_session,
                entity_id=cc_pair_id,
                sync_type=SyncType.EXTERNAL_PERMISSIONS,
            )

        # set a basic fence to start
        payload = RedisConnectorPermissionSyncPayload(started=None, celery_task_id=None)
        redis_connector.permissions.set_fence(payload)

        result = app.send_task(
            OnyxCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
            kwargs=dict(
                cc_pair_id=cc_pair_id,
                tenant_id=tenant_id,
            ),
            queue=OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
            task_id=custom_task_id,
            priority=OnyxCeleryPriority.HIGH,
        )

        # fill in the celery task id
        payload.celery_task_id = result.id
        redis_connector.permissions.set_fence(payload)
    except Exception:
        task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
        return None
    finally:
        if lock.owned():
            lock.release()

    return 1
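

# Fence handshake sketch, based only on the two tasks in this module:
# try_creating_permissions_sync_task (above) writes a fence payload with
# celery_task_id=None, enqueues the generator task, then rewrites the fence with the
# real celery task id. connector_permission_sync_generator_task (below) spins for up to
# CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT seconds until celery_task_id is populated, which
# closes the race where the task starts running before the fence is finalized.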


@shared_task(
    name=OnyxCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
    acks_late=False,
    soft_time_limit=JOB_TIMEOUT,
    track_started=True,
    trail=False,
    bind=True,
)
def connector_permission_sync_generator_task(
    self: Task,
    cc_pair_id: int,
    tenant_id: str | None,
) -> None:
    """
    Permission sync task that handles document permission syncing for a given
    connector credential pair.

    This task assumes that it has already been properly fenced.
    """
    doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
    doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
    doc_permission_sync_ctx_dict["request_id"] = self.request.id
    doc_permission_sync_ctx.set(doc_permission_sync_ctx_dict)

    redis_connector = RedisConnector(tenant_id, cc_pair_id)

    r = get_redis_client(tenant_id=tenant_id)

    # this wait is needed to avoid a race condition where
    # the primary worker sends the task and it is immediately executed
    # before the primary worker can finalize the fence
    start = time.monotonic()
    while True:
        if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
            raise ValueError(
                f"connector_permission_sync_generator_task - timed out waiting for fence to be ready: "
                f"fence={redis_connector.permissions.fence_key}"
            )

        if not redis_connector.permissions.fenced:
            # The fence must exist
            raise ValueError(
                f"connector_permission_sync_generator_task - fence not found: "
                f"fence={redis_connector.permissions.fence_key}"
            )

        payload = redis_connector.permissions.payload
        # The payload must exist
        if not payload:
            raise ValueError(
                "connector_permission_sync_generator_task: payload invalid or not found"
            )

        if payload.celery_task_id is None:
            logger.info(
                f"connector_permission_sync_generator_task - Waiting for fence: "
                f"fence={redis_connector.permissions.fence_key}"
            )
            sleep(1)
            continue

        logger.info(
            f"connector_permission_sync_generator_task - Fence found, continuing...: "
            f"fence={redis_connector.permissions.fence_key}"
        )
        break

    lock: RedisLock = r.lock(
        OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
        + f"_{redis_connector.id}",
        timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
    )

    acquired = lock.acquire(blocking=False)
    if not acquired:
        task_logger.warning(
            f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}"
        )
        return None

    try:
        with get_session_with_tenant(tenant_id) as db_session:
            cc_pair = get_connector_credential_pair_from_id(
                db_session=db_session,
                cc_pair_id=cc_pair_id,
            )
            if cc_pair is None:
                raise ValueError(
                    f"No connector credential pair found for id: {cc_pair_id}"
                )

            source_type = cc_pair.connector.source

            doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
            if doc_sync_func is None:
                # sources that rely on chunk-level censoring have no doc-level sync to run
                if source_type in DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION:
                    return None
                raise ValueError(
                    f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
                )

            logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")

            payload = redis_connector.permissions.payload
            if not payload:
                raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")

            new_payload = RedisConnectorPermissionSyncPayload(
                started=datetime.now(timezone.utc),
                celery_task_id=payload.celery_task_id,
            )
            redis_connector.permissions.set_fence(new_payload)

            document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)

            task_logger.info(
                f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
            )
            tasks_generated = redis_connector.permissions.generate_tasks(
                celery_app=self.app,
                lock=lock,
                new_permissions=document_external_accesses,
                source_string=source_type,
                connector_id=cc_pair.connector.id,
                credential_id=cc_pair.credential.id,
            )
            if tasks_generated is None:
                return None

            task_logger.info(
                f"RedisConnector.permissions.generate_tasks finished. "
                f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
            )

            redis_connector.permissions.generator_complete = tasks_generated
    except Exception as e:
        task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}")

        redis_connector.permissions.generator_clear()
        redis_connector.permissions.taskset_clear()
        redis_connector.permissions.set_fence(None)
        raise e
    finally:
        if lock.owned():
            lock.release()


@shared_task(
    name=OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
    soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
    time_limit=LIGHT_TIME_LIMIT,
    max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
    bind=True,
)
def update_external_document_permissions_task(
    self: Task,
    tenant_id: str | None,
    serialized_doc_external_access: dict,
    source_string: str,
    connector_id: int,
    credential_id: int,
) -> bool:
    """Upserts external permissions for a single document into Postgres.
    Returns True on success, False on failure."""
    document_external_access = DocExternalAccess.from_dict(
        serialized_doc_external_access
    )
    doc_id = document_external_access.doc_id
    external_access = document_external_access.external_access

    try:
        with get_session_with_tenant(tenant_id) as db_session:
            # Add the users to the DB if they don't exist
            batch_add_ext_perm_user_if_not_exists(
                db_session=db_session,
                emails=list(external_access.external_user_emails),
            )
            # Then we upsert the document's external permissions in postgres
            created_new_doc = upsert_document_external_perms(
                db_session=db_session,
                doc_id=doc_id,
                external_access=external_access,
                source_type=DocumentSource(source_string),
            )

            if created_new_doc:
                # If a new document was created, we associate it with the cc_pair
                upsert_document_by_connector_credential_pair(
                    db_session=db_session,
                    connector_id=connector_id,
                    credential_id=credential_id,
                    document_ids=[doc_id],
                )

            logger.debug(
                f"Successfully synced postgres document permissions for {doc_id}"
            )
        return True
    except Exception:
        logger.exception(
            f"Error syncing document permissions: connector_id={connector_id} doc_id={doc_id}"
        )
        return False


"""Monitoring CCPair permissions utils, called in monitor_vespa_sync"""


def monitor_ccpair_permissions_taskset(
    tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
    fence_key = key_bytes.decode("utf-8")
    cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
    if cc_pair_id_str is None:
        task_logger.warning(
            f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
        )
        return

    cc_pair_id = int(cc_pair_id_str)

    redis_connector = RedisConnector(tenant_id, cc_pair_id)
    if not redis_connector.permissions.fenced:
        return

    initial = redis_connector.permissions.generator_complete
    if initial is None:
        return

    remaining = redis_connector.permissions.get_remaining()

    task_logger.info(
        f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
    )
    if remaining > 0:
        return

    payload: RedisConnectorPermissionSyncPayload | None = (
        redis_connector.permissions.payload
    )
    start_time: datetime | None = payload.started if payload else None

    mark_cc_pair_as_permissions_synced(db_session, cc_pair_id, start_time)
    task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")

    update_sync_record_status(
        db_session=db_session,
        entity_id=cc_pair_id,
        sync_type=SyncType.EXTERNAL_PERMISSIONS,
        sync_status=SyncStatus.SUCCESS,
        num_docs_synced=initial,
    )

    redis_connector.permissions.reset()