Bugfix/limit permission size (#4695)

* add utility functions to DocExternalAccess (see the sketch below)

* refactor db access out of individual celery tasks and put it directly into the heavy task

* code review and remove leftovers

* fix circular imports

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
Author:    rkuo-danswer
Date:      2025-05-12 17:46:31 -07:00
Committer: GitHub
Commit:    392b87fb4f (parent 551a05aef0)

9 changed files with 260 additions and 125 deletions
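A reading aid before the diffs: the "utility functions" added to DocExternalAccess are visible below only through permissions.external_access.num_entries and permissions.external_access.MAX_NUM_ENTRIES, so the new helpers actually hang off the nested ExternalAccess object. Here is a minimal sketch of what onyx.access.models plausibly gained, inferred from that usage; the dataclass shape and the cap value are assumptions, not shown in this commit:

from dataclasses import dataclass


@dataclass(frozen=True)
class ExternalAccess:
    # user emails and group ids granted external access to a document;
    # both fields are read by the oversize warning in update_db below
    external_user_emails: set[str]
    external_user_group_ids: set[str]
    is_public: bool

    # assumed cap value; the diff only shows that such a constant exists
    MAX_NUM_ENTRIES = 10_000

    @property
    def num_entries(self) -> int:
        # total user + group entries; update_db skips any document whose
        # permission set exceeds MAX_NUM_ENTRIES
        return len(self.external_user_emails) + len(self.external_user_group_ids)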


@@ -1,22 +1,19 @@
 import time
 from datetime import datetime
+from logging import Logger
 from typing import Any
 from typing import cast
-from uuid import uuid4
 
 import redis
-from celery import Celery
 from pydantic import BaseModel
 from redis.lock import Lock as RedisLock
 
 from onyx.access.models import DocExternalAccess
 from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
-from onyx.configs.constants import OnyxCeleryPriority
-from onyx.configs.constants import OnyxCeleryQueues
-from onyx.configs.constants import OnyxCeleryTask
 from onyx.configs.constants import OnyxRedisConstants
 from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
+from onyx.utils.variable_functionality import fetch_versioned_implementation
 
 
 class RedisConnectorPermissionSyncPayload(BaseModel):
@@ -160,47 +157,64 @@ class RedisConnectorPermissionSync:
         self.redis.set(self.generator_complete_key, payload)
 
-    def generate_tasks(
+    def update_db(
         self,
-        celery_app: Celery,
         lock: RedisLock | None,
         new_permissions: list[DocExternalAccess],
         source_string: str,
         connector_id: int,
         credential_id: int,
+        task_logger: Logger | None = None,
     ) -> int | None:
         last_lock_time = time.monotonic()
-        async_results = []
+
+        document_update_permissions_fn = fetch_versioned_implementation(
+            "onyx.background.celery.tasks.doc_permission_syncing.tasks",
+            "document_update_permissions",
+        )
+
+        num_permissions = 0
 
         # Create a task for each document permission sync
-        for doc_perm in new_permissions:
+        for permissions in new_permissions:
             current_time = time.monotonic()
             if lock and current_time - last_lock_time >= (
                 CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
             ):
                 lock.reacquire()
                 last_lock_time = current_time
 
-            # Add task for document permissions sync
-            custom_task_id = f"{self.subtask_prefix}_{uuid4()}"
-            self.redis.sadd(self.taskset_key, custom_task_id)
-
-            result = celery_app.send_task(
-                OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
-                kwargs=dict(
-                    tenant_id=self.tenant_id,
-                    serialized_doc_external_access=doc_perm.to_dict(),
-                    source_string=source_string,
-                    connector_id=connector_id,
-                    credential_id=credential_id,
-                ),
-                queue=OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
-                task_id=custom_task_id,
-                priority=OnyxCeleryPriority.MEDIUM,
-                ignore_result=True,
-            )
-            async_results.append(result)
-
-        return len(async_results)
+            if (
+                permissions.external_access.num_entries
+                > permissions.external_access.MAX_NUM_ENTRIES
+            ):
+                if task_logger:
+                    num_users = len(permissions.external_access.external_user_emails)
+                    num_groups = len(
+                        permissions.external_access.external_user_group_ids
+                    )
+                    task_logger.warning(
+                        f"Permissions length exceeded, skipping...: "
+                        f"{permissions.doc_id} "
+                        f"{num_users=} {num_groups=} "
+                        f"{permissions.external_access.MAX_NUM_ENTRIES=}"
+                    )
+                continue
+
+            # NOTE(rkuo): this used to fire a task instead of directly writing to the DB,
+            # but the permissions can be excessively large if sent over the wire.
+            # On the other hand, the downside of doing db updates here is that we can
+            # block and fail if we can't make the calls to the DB ... but that's probably
+            # a rare enough case to be acceptable.
+
+            # This can internally exception due to db issues but still continue
+            # we may want to change this
+            document_update_permissions_fn(
+                self.tenant_id, permissions, source_string, connector_id, credential_id
+            )
+
+            num_permissions += 1
+
+        return num_permissions
 
     def reset(self) -> None:
         self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
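Since update_db now writes permissions synchronously, the heavy permission-sync task calls it directly instead of fanning out one Celery task per document. A hedged sketch of the call site; the surrounding variable names and the RedisConnector wrapper usage are assumptions, since the actual caller in the doc-permission-syncing task is not part of this excerpt:

# hypothetical call site inside the heavy permission-sync task
redis_connector = RedisConnector(tenant_id, cc_pair_id)
num_written = redis_connector.permissions.update_db(
    lock=lock,  # RedisLock held by the task; reacquired every ~1/4 of the timeout
    new_permissions=doc_external_accesses,  # list[DocExternalAccess] from the connector
    source_string=source_type_str,
    connector_id=connector_id,
    credential_id=credential_id,
    task_logger=task_logger,  # optional; enables the oversized-permission warnings
)

The fetch_versioned_implementation lookup resolves document_update_permissions by module path at call time rather than at import time, which is what lets this file avoid the circular import mentioned in the commit message. The second changed file, below, is brand new (note the @@ -0,0 +1,60 @@ header): the deletion-status helpers move into their own module, presumably for the same reason.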


@@ -0,0 +1,60 @@
from sqlalchemy.orm import Session
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import TaskStatus
from onyx.db.models import TaskQueueState
from onyx.redis.redis_connector import RedisConnector
from onyx.server.documents.models import DeletionAttemptSnapshot
def _get_deletion_status(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str,
) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
"""
cc_pair = get_connector_credential_pair(
connector_id=connector_id, credential_id=credential_id, db_session=db_session
)
if not cc_pair:
return None
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if redis_connector.delete.fenced:
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.STARTED,
)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.PENDING,
)
return None
def get_deletion_attempt_snapshot(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str,
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(
connector_id, credential_id, db_session, tenant_id
)
if not deletion_task:
return None
return DeletionAttemptSnapshot(
connector_id=connector_id,
credential_id=credential_id,
status=deletion_task.status,
)
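A short usage sketch of the new helper; the wrapper function here is hypothetical, and the real callers would be API routes that previously imported these helpers from a module involved in the import cycle:

from sqlalchemy.orm import Session

def deletion_status_payload(
    connector_id: int,
    credential_id: int,
    db_session: Session,
    tenant_id: str,
) -> dict | None:
    # returns None unless a deletion is fenced in Redis or the cc-pair
    # is already in the DELETING state
    snapshot = get_deletion_attempt_snapshot(
        connector_id=connector_id,
        credential_id=credential_id,
        db_session=db_session,
        tenant_id=tenant_id,
    )
    return snapshot.model_dump() if snapshot else None

DeletionAttemptSnapshot is assumed here to be a Pydantic v2 model, hence model_dump(); note that despite the docstring, _get_deletion_status consults both Redis (the deletion fence) and the DB (the cc-pair status).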