Bugfix/limit permission size (#4695)

* add utility function

* add utility functions to DocExternalAccess

* refactor db access out of individual celery tasks and put it directly into the heavy task

* code review and remove leftovers

* fix circular imports

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
This commit is contained in:
rkuo-danswer
2025-05-12 17:46:31 -07:00
committed by GitHub
parent 551a05aef0
commit 392b87fb4f
9 changed files with 260 additions and 125 deletions

View File

@@ -16,6 +16,10 @@ from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from tenacity import retry
from tenacity import retry_if_exception
from tenacity import stop_after_delay
from tenacity import wait_random_exponential
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
@@ -31,7 +35,6 @@ from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
@@ -50,6 +53,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
@@ -58,6 +62,7 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.utils import is_retryable_sqlalchemy_error
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
@@ -74,11 +79,12 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
logger = setup_logger()
DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3
DOCUMENT_PERMISSIONS_UPDATE_STOP_AFTER = 10 * 60
DOCUMENT_PERMISSIONS_UPDATE_MAX_WAIT = 60
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
@@ -472,13 +478,13 @@ def connector_permission_sync_generator_task(
tasks_generated = 0
for doc_external_access in document_external_accesses:
redis_connector.permissions.generate_tasks(
celery_app=self.app,
redis_connector.permissions.update_db(
lock=lock,
new_permissions=[doc_external_access],
source_string=source_type,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
task_logger=task_logger,
)
tasks_generated += 1
@@ -491,6 +497,7 @@ def connector_permission_sync_generator_task(
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id} {error_msg}"
)
@@ -511,33 +518,28 @@ def connector_permission_sync_generator_task(
)
@shared_task(
name=OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
bind=True,
# NOTE(rkuo): this should probably move to the db layer
@retry(
retry=retry_if_exception(is_retryable_sqlalchemy_error),
wait=wait_random_exponential(
multiplier=1, max=DOCUMENT_PERMISSIONS_UPDATE_MAX_WAIT
),
stop=stop_after_delay(DOCUMENT_PERMISSIONS_UPDATE_STOP_AFTER),
)
def update_external_document_permissions_task(
self: Task,
def document_update_permissions(
tenant_id: str,
serialized_doc_external_access: dict,
source_string: str,
permissions: DocExternalAccess,
source_type_str: str,
connector_id: int,
credential_id: int,
) -> bool:
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
)
doc_id = document_external_access.doc_id
external_access = document_external_access.external_access
doc_id = permissions.doc_id
external_access = permissions.external_access
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
# Add the users to the DB if they don't exist
batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
@@ -549,7 +551,7 @@ def update_external_document_permissions_task(
db_session=db_session,
doc_id=doc_id,
external_access=external_access,
source_type=DocumentSource(source_string),
source_type=DocumentSource(source_type_str),
)
if created_new_doc:
@@ -568,32 +570,105 @@ def update_external_document_permissions_task(
f"action=update_permissions "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
)
task_logger.exception(
f"update_external_document_permissions_task exceptioned: "
f"document_update_permissions exceptioned: "
f"connector_id={connector_id} doc_id={doc_id}"
)
completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
raise e
finally:
task_logger.info(
f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
f"document_update_permissions completed: connector_id={connector_id} doc={doc_id}"
)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
return False
task_logger.info(
f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
)
return True
# NOTE(rkuo): Deprecating this due to degenerate behavior in Redis from sending
# large permissions through celery (over 1MB in size)
# @shared_task(
# name=OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
# soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
# time_limit=LIGHT_TIME_LIMIT,
# max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
# bind=True,
# )
# def update_external_document_permissions_task(
# self: Task,
# tenant_id: str,
# serialized_doc_external_access: dict,
# source_string: str,
# connector_id: int,
# credential_id: int,
# ) -> bool:
# start = time.monotonic()
# completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
# document_external_access = DocExternalAccess.from_dict(
# serialized_doc_external_access
# )
# doc_id = document_external_access.doc_id
# external_access = document_external_access.external_access
# try:
# with get_session_with_current_tenant() as db_session:
# # Add the users to the DB if they don't exist
# batch_add_ext_perm_user_if_not_exists(
# db_session=db_session,
# emails=list(external_access.external_user_emails),
# continue_on_error=True,
# )
# # Then upsert the document's external permissions
# created_new_doc = upsert_document_external_perms(
# db_session=db_session,
# doc_id=doc_id,
# external_access=external_access,
# source_type=DocumentSource(source_string),
# )
# if created_new_doc:
# # If a new document was created, we associate it with the cc_pair
# upsert_document_by_connector_credential_pair(
# db_session=db_session,
# connector_id=connector_id,
# credential_id=credential_id,
# document_ids=[doc_id],
# )
# elapsed = time.monotonic() - start
# task_logger.info(
# f"connector_id={connector_id} "
# f"doc={doc_id} "
# f"action=update_permissions "
# f"elapsed={elapsed:.2f}"
# )
# completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
# except Exception as e:
# error_msg = format_error_for_logging(e)
# task_logger.warning(
# f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
# )
# task_logger.exception(
# f"update_external_document_permissions_task exceptioned: "
# f"connector_id={connector_id} doc_id={doc_id}"
# )
# completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
# finally:
# task_logger.info(
# f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
# )
# if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
# return False
# task_logger.info(
# f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
# )
# return True
def validate_permission_sync_fences(
tenant_id: str,
r: Redis,

View File

@@ -8,6 +8,11 @@ from onyx.configs.constants import PUBLIC_DOC_PAT
@dataclass(frozen=True)
class ExternalAccess:
# arbitrary limit to prevent excessively large permission sets
# not internally enforced; the caller should check this before using the instance
MAX_NUM_ENTRIES = 1000
# Emails of external users with access to the doc externally
external_user_emails: set[str]
# Names or external IDs of groups with access to the doc
@@ -31,6 +36,10 @@ class ExternalAccess:
f"is_public={self.is_public})"
)
@property
def num_entries(self) -> int:
    # Total size of the permission set: external user emails plus external
    # group ids. Callers compare this against MAX_NUM_ENTRIES to skip
    # excessively large permission sets (the limit is not enforced here).
    return len(self.external_user_emails) + len(self.external_user_group_ids)
@dataclass(frozen=True)
class DocExternalAccess:

View File

@@ -4,7 +4,6 @@ from typing import Any
from typing import cast
import httpx
from sqlalchemy.orm import Session
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
@@ -16,72 +15,14 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import Document
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import TaskStatus
from onyx.db.models import TaskQueueState
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.server.documents.models import DeletionAttemptSnapshot
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _get_deletion_status(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str,
) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
"""
cc_pair = get_connector_credential_pair(
connector_id=connector_id, credential_id=credential_id, db_session=db_session
)
if not cc_pair:
return None
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if redis_connector.delete.fenced:
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.STARTED,
)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.PENDING,
)
return None
def get_deletion_attempt_snapshot(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str,
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(
connector_id, credential_id, db_session, tenant_id
)
if not deletion_task:
return None
return DeletionAttemptSnapshot(
connector_id=connector_id,
credential_id=credential_id,
status=deletion_task.status,
)
def document_batch_to_ids(
doc_batch: list[Document],
) -> set[str]:

View File

@@ -255,6 +255,9 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
def _get_users_by_emails(
db_session: Session, lower_emails: list[str]
) -> tuple[list[User], list[str]]:
"""given a list of lowercase emails,
returns a list[User] of Users whose emails match and a list[str]
the missing emails that had no User"""
stmt = select(User).filter(func.lower(User.email).in_(lower_emails)) # type: ignore
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list

View File

@@ -1,5 +1,7 @@
from typing import Any
from psycopg2 import errorcodes
from psycopg2 import OperationalError
from sqlalchemy import inspect
from onyx.db.models import Base
@@ -7,3 +9,21 @@ from onyx.db.models import Base
def model_to_dict(model: Base) -> dict[str, Any]:
    """Serialize a SQLAlchemy model instance into a plain dict of its column values."""
    mapper = inspect(model).mapper  # type: ignore
    result: dict[str, Any] = {}
    for attr in mapper.column_attrs:
        result[attr.key] = getattr(model, attr.key)
    return result
# Postgres error codes that indicate a transient failure worth retrying.
RETRYABLE_PG_CODES = {
    errorcodes.SERIALIZATION_FAILURE,  # '40001'
    errorcodes.DEADLOCK_DETECTED,  # '40P01'
    errorcodes.CONNECTION_EXCEPTION,  # '08000'
    errorcodes.CONNECTION_DOES_NOT_EXIST,  # '08003'
    errorcodes.CONNECTION_FAILURE,  # '08006'
    errorcodes.TRANSACTION_ROLLBACK,  # '40000'
}


def is_retryable_sqlalchemy_error(exc: BaseException) -> bool:
    """Helper function for use with tenacity's retry_if_exception as the callback.

    Returns True when ``exc`` is an OperationalError whose Postgres error code
    marks a transient failure (serialization, deadlock, connection loss).
    """
    if not isinstance(exc, OperationalError):
        return False
    # A DBAPI wrapper (e.g. SQLAlchemy's DBAPIError) stores the driver error
    # in `.orig`, while a raw psycopg2 OperationalError carries `pgcode`
    # directly on itself. The original only checked `.orig`, so raw driver
    # errors were never treated as retryable — check both locations.
    pgcode = getattr(getattr(exc, "orig", None), "pgcode", None)
    if pgcode is None:
        pgcode = getattr(exc, "pgcode", None)
    # NOTE(review): `OperationalError` here is psycopg2's class (see imports).
    # sqlalchemy.exc.OperationalError is NOT a subclass of it, so a
    # SQLAlchemy-wrapped error would fail the isinstance gate above — confirm
    # which exception type actually reaches this tenacity callback.
    return pgcode in RETRYABLE_PG_CODES

View File

@@ -1,22 +1,19 @@
import time
from datetime import datetime
from logging import Logger
from typing import Any
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.utils.variable_functionality import fetch_versioned_implementation
class RedisConnectorPermissionSyncPayload(BaseModel):
@@ -160,47 +157,64 @@ class RedisConnectorPermissionSync:
self.redis.set(self.generator_complete_key, payload)
def generate_tasks(
def update_db(
self,
celery_app: Celery,
lock: RedisLock | None,
new_permissions: list[DocExternalAccess],
source_string: str,
connector_id: int,
credential_id: int,
task_logger: Logger | None = None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
document_update_permissions_fn = fetch_versioned_implementation(
"onyx.background.celery.tasks.doc_permission_syncing.tasks",
"document_update_permissions",
)
num_permissions = 0
# Create a task for each document permission sync
for doc_perm in new_permissions:
for permissions in new_permissions:
current_time = time.monotonic()
if lock and current_time - last_lock_time >= (
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# Add task for document permissions sync
custom_task_id = f"{self.subtask_prefix}_{uuid4()}"
self.redis.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
kwargs=dict(
tenant_id=self.tenant_id,
serialized_doc_external_access=doc_perm.to_dict(),
source_string=source_string,
connector_id=connector_id,
credential_id=credential_id,
),
queue=OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
ignore_result=True,
if (
permissions.external_access.num_entries
> permissions.external_access.MAX_NUM_ENTRIES
):
if task_logger:
num_users = len(permissions.external_access.external_user_emails)
num_groups = len(
permissions.external_access.external_user_group_ids
)
task_logger.warning(
f"Permissions length exceeded, skipping...: "
f"{permissions.doc_id} "
f"{num_users=} {num_groups=} "
f"{permissions.external_access.MAX_NUM_ENTRIES=}"
)
continue
# NOTE(rkuo): this used to fire a task instead of directly writing to the DB,
# but the permissions can be excessively large if sent over the wire.
# On the other hand, the downside of doing db updates here is that we can
# block and fail if we can't make the calls to the DB ... but that's probably
# a rare enough case to be acceptable.
# This can internally exception due to db issues but still continue;
# we may want to change this.
document_update_permissions_fn(
self.tenant_id, permissions, source_string, connector_id, credential_id
)
async_results.append(result)
return len(async_results)
num_permissions += 1
return num_permissions
def reset(self) -> None:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)

View File

@@ -0,0 +1,60 @@
from sqlalchemy.orm import Session
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import TaskStatus
from onyx.db.models import TaskQueueState
from onyx.redis.redis_connector import RedisConnector
from onyx.server.documents.models import DeletionAttemptSnapshot
def _get_deletion_status(
    connector_id: int,
    credential_id: int,
    db_session: Session,
    tenant_id: str,
) -> TaskQueueState | None:
    """Build a TaskQueueState for an in-flight deletion, or None if no
    deletion is underway.

    We no longer store TaskQueueState in the DB for a deletion attempt;
    this reconstructs it from the Redis deletion fence and the cc_pair status.
    """
    cc_pair = get_connector_credential_pair(
        connector_id=connector_id, credential_id=credential_id, db_session=db_session
    )
    if cc_pair is None:
        return None

    redis_connector = RedisConnector(tenant_id, cc_pair.id)

    status: TaskStatus | None = None
    if redis_connector.delete.fenced:
        # A deletion fence in redis means the deletion task has started.
        status = TaskStatus.STARTED
    elif cc_pair.status == ConnectorCredentialPairStatus.DELETING:
        # Marked for deletion but no fence yet: still pending pickup.
        status = TaskStatus.PENDING

    if status is None:
        return None

    return TaskQueueState(
        task_id="",
        task_name=redis_connector.delete.fence_key,
        status=status,
    )
def get_deletion_attempt_snapshot(
    connector_id: int,
    credential_id: int,
    db_session: Session,
    tenant_id: str,
) -> DeletionAttemptSnapshot | None:
    """Return a snapshot describing the current deletion attempt for the
    given connector/credential pair, or None if nothing is being deleted."""
    deletion_task = _get_deletion_status(
        connector_id, credential_id, db_session, tenant_id
    )
    if deletion_task is None:
        return None

    return DeletionAttemptSnapshot(
        connector_id=connector_id,
        credential_id=credential_id,
        status=deletion_task.status,
    )

View File

@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.celery_utils import get_deletion_attempt_snapshot
from onyx.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
)
@@ -45,6 +44,7 @@ from onyx.db.models import User
from onyx.db.search_settings import get_active_search_settings_list
from onyx.db.search_settings import get_current_search_settings
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_utils import get_deletion_attempt_snapshot
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CCPairFullInfo
from onyx.server.documents.models import CCPropertyUpdateRequest

View File

@@ -52,9 +52,11 @@ class OnyxRedisCommand(Enum):
purge_usergroup_taskset = "purge_usergroup_taskset"
purge_locks_blocking_deletion = "purge_locks_blocking_deletion"
purge_vespa_syncing = "purge_vespa_syncing"
purge_pidbox = "purge_pidbox"
get_user_token = "get_user_token"
delete_user_token = "delete_user_token"
add_invited_user = "add_invited_user"
get_list_element = "get_list_element"
def __str__(self) -> str:
return self.value
@@ -145,6 +147,17 @@ def onyx_redis(
return purge_by_match_and_type(
"*connectorsync:vespa_syncing*", "string", batch, dry_run, r
)
elif command == OnyxRedisCommand.purge_pidbox:
return purge_by_match_and_type(
"*reply.celery.pidbox", "list", batch, dry_run, r
)
elif command == OnyxRedisCommand.get_list_element:
# just hardcoded for now
result = r.lrange(
"0097a564-d343-3c1f-9fd1-af8cce038115.reply.celery.pidbox", 0, 0
)
print(f"{result}")
return 0
elif command == OnyxRedisCommand.get_user_token:
if not user_email:
logger.error("You must specify --user-email with get_user_token")