Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-07-25 12:23:49 +02:00)
Validate permission locks (#3799)

* WIP for external group sync lock fixes
* prototyping permissions validation
* validate permission sync tasks in celery
* mypy
* cleanup and wire off external group sync checks for now
* add active key to reset
* improve logging
* reset on payload format change
* return False on exception
* missed a return
* add count of tasks scanned
* add comment
* better logging
* add return
* more return
* catch payload exceptions
* code review fixes
* push to restart test

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
@@ -13,6 +13,7 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
 from onyx.connectors.confluence.utils import get_user_email_from_username__server
 from onyx.connectors.models import SlimDocument
 from onyx.db.models import ConnectorCredentialPair
+from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -257,6 +258,7 @@ def _fetch_all_page_restrictions(
     slim_docs: list[SlimDocument],
     space_permissions_by_space_key: dict[str, ExternalAccess],
     is_cloud: bool,
+    callback: IndexingHeartbeatInterface | None,
 ) -> list[DocExternalAccess]:
     """
     For all pages, if a page has restrictions, then use those restrictions.
@@ -265,6 +267,12 @@ def _fetch_all_page_restrictions(
     document_restrictions: list[DocExternalAccess] = []

     for slim_doc in slim_docs:
+        if callback:
+            if callback.should_stop():
+                raise RuntimeError("confluence_doc_sync: Stop signal detected")
+
+            callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)
+
         if slim_doc.perm_sync_data is None:
             raise ValueError(
                 f"No permission sync data found for document {slim_doc.id}"
@@ -334,7 +342,7 @@ def _fetch_all_page_restrictions(


 def confluence_doc_sync(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
 ) -> list[DocExternalAccess]:
     """
     Adds the external permissions to the documents in postgres
@@ -359,6 +367,12 @@ def confluence_doc_sync(
     logger.debug("Fetching all slim documents from confluence")
     for doc_batch in confluence_connector.retrieve_all_slim_documents():
         logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
+        if callback:
+            if callback.should_stop():
+                raise RuntimeError("confluence_doc_sync: Stop signal detected")
+
+            callback.progress("confluence_doc_sync", 1)
+
         slim_docs.extend(doc_batch)

     logger.debug("Fetching all page restrictions for space")
@@ -367,4 +381,5 @@ def confluence_doc_sync(
         slim_docs=slim_docs,
         space_permissions_by_space_key=space_permissions_by_space_key,
         is_cloud=is_cloud,
+        callback=callback,
     )
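The pattern added here, and repeated in the connectors below, is uniform: once per batch or per document loop, poll an optional heartbeat callback, abort on should_stop(), and report liveness via progress(). A minimal sketch of the contract these call sites assume; the real IndexingHeartbeatInterface lives in onyx.indexing.indexing_heartbeat and may define more than shown here:

```python
# Hedged sketch only: the two hooks the call sites above rely on.
from abc import ABC, abstractmethod


class IndexingHeartbeatInterface(ABC):
    @abstractmethod
    def should_stop(self) -> bool:
        """Return True when the sync should abort (e.g. a stop fence is set)."""

    @abstractmethod
    def progress(self, tag: str, amount: int) -> None:
        """Record liveness; tag identifies the call site, amount the work done."""
```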
@@ -6,6 +6,7 @@ from onyx.access.models import ExternalAccess
 from onyx.connectors.gmail.connector import GmailConnector
 from onyx.connectors.interfaces import GenerateSlimDocumentOutput
 from onyx.db.models import ConnectorCredentialPair
+from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -28,7 +29,7 @@ def _get_slim_doc_generator(


 def gmail_doc_sync(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
 ) -> list[DocExternalAccess]:
     """
     Adds the external permissions to the documents in postgres
@@ -44,6 +45,12 @@ def gmail_doc_sync(
     document_external_access: list[DocExternalAccess] = []
     for slim_doc_batch in slim_doc_generator:
         for slim_doc in slim_doc_batch:
+            if callback:
+                if callback.should_stop():
+                    raise RuntimeError("gmail_doc_sync: Stop signal detected")
+
+                callback.progress("gmail_doc_sync", 1)
+
             if slim_doc.perm_sync_data is None:
                 logger.warning(f"No permissions found for document {slim_doc.id}")
                 continue
@@ -10,6 +10,7 @@ from onyx.connectors.google_utils.resources import get_drive_service
 from onyx.connectors.interfaces import GenerateSlimDocumentOutput
 from onyx.connectors.models import SlimDocument
 from onyx.db.models import ConnectorCredentialPair
+from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -128,7 +129,7 @@ def _get_permissions_from_slim_doc(


 def gdrive_doc_sync(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
 ) -> list[DocExternalAccess]:
     """
     Adds the external permissions to the documents in postgres
@@ -146,6 +147,12 @@ def gdrive_doc_sync(
     document_external_accesses = []
     for slim_doc_batch in slim_doc_generator:
         for slim_doc in slim_doc_batch:
+            if callback:
+                if callback.should_stop():
+                    raise RuntimeError("gdrive_doc_sync: Stop signal detected")
+
+                callback.progress("gdrive_doc_sync", 1)
+
             ext_access = _get_permissions_from_slim_doc(
                 google_drive_connector=google_drive_connector,
                 slim_doc=slim_doc,
@@ -7,6 +7,7 @@ from onyx.connectors.slack.connector import get_channels
 from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
 from onyx.connectors.slack.connector import SlackPollConnector
 from onyx.db.models import ConnectorCredentialPair
+from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.utils.logger import setup_logger


@@ -14,7 +15,7 @@ logger = setup_logger()


 def _get_slack_document_ids_and_channels(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
 ) -> dict[str, list[str]]:
     slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
     slack_connector.load_credentials(cc_pair.credential.credential_json)
@@ -24,6 +25,14 @@ def _get_slack_document_ids_and_channels(
     channel_doc_map: dict[str, list[str]] = {}
     for doc_metadata_batch in slim_doc_generator:
         for doc_metadata in doc_metadata_batch:
+            if callback:
+                if callback.should_stop():
+                    raise RuntimeError(
+                        "_get_slack_document_ids_and_channels: Stop signal detected"
+                    )
+
+                callback.progress("_get_slack_document_ids_and_channels", 1)
+
             if doc_metadata.perm_sync_data is None:
                 continue
             channel_id = doc_metadata.perm_sync_data["channel_id"]
@@ -114,7 +123,7 @@ def _fetch_channel_permissions(


 def slack_doc_sync(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
 ) -> list[DocExternalAccess]:
     """
     Adds the external permissions to the documents in postgres
@@ -127,7 +136,7 @@ def slack_doc_sync(
     )
     user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
     channel_doc_map = _get_slack_document_ids_and_channels(
-        cc_pair=cc_pair,
+        cc_pair=cc_pair, callback=callback
     )
     workspace_permissions = _fetch_workspace_permissions(
         user_id_to_email_map=user_id_to_email_map,
@@ -15,11 +15,13 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
 from onyx.access.models import DocExternalAccess
 from onyx.configs.constants import DocumentSource
 from onyx.db.models import ConnectorCredentialPair
+from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface

 # Defining the input/output types for the sync functions
 DocSyncFuncType = Callable[
     [
         ConnectorCredentialPair,
+        IndexingHeartbeatInterface | None,
     ],
     list[DocExternalAccess],
 ]
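Any function registered as a DocSyncFuncType must now accept the optional callback as its second positional argument. A trivial conforming implementation, for signature reference only:

```python
# Illustrative only: a function conforming to the updated DocSyncFuncType.
from onyx.access.models import DocExternalAccess
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface


def noop_doc_sync(
    cc_pair: ConnectorCredentialPair,
    callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
    return []
```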
@@ -198,7 +198,8 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:

 def wait_for_redis(sender: Any, **kwargs: Any) -> None:
     """Waits for redis to become ready subject to a hardcoded timeout.
-    Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
+    Will raise WorkerShutdown to kill the celery worker if the timeout
+    is reached."""

     r = get_redis_client(tenant_id=None)

@@ -91,6 +91,28 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
     return False


+def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
+    """This is a redis specific way to build a list of tasks in a queue.
+
+    This helps us read the queue once and then efficiently look for missing tasks
+    in the queue.
+    """
+
+    task_set: set[str] = set()
+
+    for priority in range(len(OnyxCeleryPriority)):
+        queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
+
+        tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
+        for task in tasks:
+            task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
+            task_id = task_dict.get("headers", {}).get("id")
+            if task_id:
+                task_set.add(task_id)
+
+    return task_set
+
+
 def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
     """Returns a list of current workers containing name_filter, or all workers if
     name_filter is None.
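This helper works because Celery's Redis broker keeps each pending message as a JSON envelope in one list per (queue, priority) pair, with the task id stored under the message headers. Reading the whole queue once into a set makes the later per-task membership checks O(1). A standalone illustration of the parsing step, assuming that envelope shape:

```python
# Standalone illustration of the envelope shape parsed above; the real Celery
# Redis-transport message carries more fields, only headers["id"] matters here.
import json

raw = b'{"body": "...", "headers": {"id": "abc123", "task": "some.task.name"}}'
task_id = json.loads(raw.decode("utf-8")).get("headers", {}).get("id")
print(task_id)  # -> abc123
```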
@@ -3,13 +3,16 @@ from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
 from time import sleep
+from typing import cast
 from uuid import uuid4

 from celery import Celery
 from celery import shared_task
 from celery import Task
 from celery.exceptions import SoftTimeLimitExceeded
+from pydantic import ValidationError
 from redis import Redis
+from redis.exceptions import LockError
 from redis.lock import Lock as RedisLock
 from sqlalchemy.orm import Session

@@ -22,6 +25,10 @@ from ee.onyx.external_permissions.sync_params import (
 )
 from onyx.access.models import DocExternalAccess
 from onyx.background.celery.apps.app_base import task_logger
+from onyx.background.celery.celery_redis import celery_find_task
+from onyx.background.celery.celery_redis import celery_get_queue_length
+from onyx.background.celery.celery_redis import celery_get_queued_task_ids
+from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
 from onyx.configs.app_configs import JOB_TIMEOUT
 from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
@@ -32,6 +39,7 @@ from onyx.configs.constants import OnyxCeleryPriority
 from onyx.configs.constants import OnyxCeleryQueues
 from onyx.configs.constants import OnyxCeleryTask
 from onyx.configs.constants import OnyxRedisLocks
+from onyx.configs.constants import OnyxRedisSignals
 from onyx.db.connector import mark_cc_pair_as_permissions_synced
 from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
 from onyx.db.document import upsert_document_by_connector_credential_pair
@@ -44,14 +52,19 @@ from onyx.db.models import ConnectorCredentialPair
 from onyx.db.sync_record import insert_sync_record
 from onyx.db.sync_record import update_sync_record_status
 from onyx.db.users import batch_add_ext_perm_user_if_not_exists
+from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.redis.redis_connector import RedisConnector
-from onyx.redis.redis_connector_doc_perm_sync import (
-    RedisConnectorPermissionSyncPayload,
-)
+from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
+from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyncPayload
 from onyx.redis.redis_pool import get_redis_client
+from onyx.redis.redis_pool import redis_lock_dump
+from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
+from onyx.server.utils import make_short_id
 from onyx.utils.logger import doc_permission_sync_ctx
+from onyx.utils.logger import LoggerContextVars
 from onyx.utils.logger import setup_logger


 logger = setup_logger()


@@ -105,7 +118,12 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
     bind=True,
 )
 def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
+    # TODO(rkuo): merge into check function after lookup table for fences is added
+
+    # we need to use celery's redis client to access its redis data
+    # (which lives on a different db number)
     r = get_redis_client(tenant_id=tenant_id)
+    r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

     lock_beat: RedisLock = r.lock(
         OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
@@ -126,14 +144,32 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
             if _is_external_doc_permissions_sync_due(cc_pair):
                 cc_pair_ids_to_sync.append(cc_pair.id)

+        lock_beat.reacquire()
         for cc_pair_id in cc_pair_ids_to_sync:
-            tasks_created = try_creating_permissions_sync_task(
+            payload_id = try_creating_permissions_sync_task(
                 self.app, cc_pair_id, r, tenant_id
             )
-            if not tasks_created:
+            if not payload_id:
                 continue

-            task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
+            task_logger.info(
+                f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}"
+            )
+
+        # we want to run this less frequently than the overall task
+        lock_beat.reacquire()
+        if not r.exists(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES):
+            # clear any permission fences that don't have associated celery tasks in progress
+            # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
+            # or be currently executing
+            try:
+                validate_permission_sync_fences(tenant_id, r, r_celery, lock_beat)
+            except Exception:
+                task_logger.exception(
+                    "Exception while validating permission sync fences"
+                )
+
+            r.set(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=60)
     except SoftTimeLimitExceeded:
         task_logger.info(
             "Soft time limit exceeded, task is being terminated gracefully."
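The fence validation above is gated behind an expiring Redis key, so it runs at most once per TTL window (60 seconds here) no matter how often the beat task fires. The idiom in isolation, as a hedged sketch with hypothetical names:

```python
# Hedged sketch of the throttle used above: an expiring key marks "recently
# ran"; the expensive work is skipped until the key's TTL lapses.
from typing import Callable

from redis import Redis


def run_throttled(r: Redis, signal_key: str, ttl_secs: int, fn: Callable[[], None]) -> None:
    if r.exists(signal_key):
        return  # ran within the last ttl_secs
    try:
        fn()
    except Exception:
        pass  # the real code logs via task_logger.exception
    r.set(signal_key, 1, ex=ttl_secs)  # set even on failure, matching the original
```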
@@ -152,13 +188,15 @@ def try_creating_permissions_sync_task(
     cc_pair_id: int,
     r: Redis,
     tenant_id: str | None,
-) -> int | None:
-    """Returns an int if syncing is needed. The int represents the number of sync tasks generated.
+) -> str | None:
+    """Returns a randomized payload id on success.
     Returns None if no syncing is required."""
-    redis_connector = RedisConnector(tenant_id, cc_pair_id)

     LOCK_TIMEOUT = 30

+    payload_id: str | None = None
+
+    redis_connector = RedisConnector(tenant_id, cc_pair_id)
+
     lock: RedisLock = r.lock(
         DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
         timeout=LOCK_TIMEOUT,
@@ -193,7 +231,13 @@ def try_creating_permissions_sync_task(
         )

         # set a basic fence to start
-        payload = RedisConnectorPermissionSyncPayload(started=None, celery_task_id=None)
+        redis_connector.permissions.set_active()
+        payload = RedisConnectorPermissionSyncPayload(
+            id=make_short_id(),
+            submitted=datetime.now(timezone.utc),
+            started=None,
+            celery_task_id=None,
+        )
         redis_connector.permissions.set_fence(payload)

         result = app.send_task(
@@ -208,8 +252,11 @@ def try_creating_permissions_sync_task(
         )

         # fill in the celery task id
+        redis_connector.permissions.set_active()
         payload.celery_task_id = result.id
         redis_connector.permissions.set_fence(payload)
+
+        payload_id = payload.celery_task_id
     except Exception:
         task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
         return None
@@ -217,7 +264,7 @@ def try_creating_permissions_sync_task(
         if lock.owned():
             lock.release()

-    return 1
+    return payload_id


 @shared_task(
@@ -238,6 +285,8 @@ def connector_permission_sync_generator_task(
     This task assumes that the task has already been properly fenced
     """

+    LoggerContextVars.reset()
+
     doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
     doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
     doc_permission_sync_ctx_dict["request_id"] = self.request.id
@@ -325,12 +374,17 @@ def connector_permission_sync_generator_task(
             raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")

         new_payload = RedisConnectorPermissionSyncPayload(
+            id=payload.id,
+            submitted=payload.submitted,
             started=datetime.now(timezone.utc),
             celery_task_id=payload.celery_task_id,
         )
         redis_connector.permissions.set_fence(new_payload)

-        document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
+        callback = PermissionSyncCallback(redis_connector, lock, r)
+        document_external_accesses: list[DocExternalAccess] = doc_sync_func(
+            cc_pair, callback
+        )

         task_logger.info(
             f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
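For reference, a hedged sketch of how a doc sync function on the other side of this call consumes the callback constructed above; fetch_batches is a made-up placeholder, not a real onyx helper:

```python
# Hypothetical consumer of the callback wired in above.
def fetch_batches(cc_pair):
    yield []  # placeholder: a real connector yields batches of slim documents


def example_doc_sync(cc_pair, callback) -> list:
    results: list = []
    for batch in fetch_batches(cc_pair):
        if callback:
            if callback.should_stop():
                raise RuntimeError("example_doc_sync: Stop signal detected")
            callback.progress("example_doc_sync", 1)
        results.extend(batch)
    return results
```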
@@ -380,6 +434,8 @@ def update_external_document_permissions_task(
     connector_id: int,
     credential_id: int,
 ) -> bool:
+    start = time.monotonic()
+
     document_external_access = DocExternalAccess.from_dict(
         serialized_doc_external_access
     )
@@ -409,16 +465,268 @@ def update_external_document_permissions_task(
                 document_ids=[doc_id],
             )

-        logger.debug(
-            f"Successfully synced postgres document permissions for {doc_id}"
+        elapsed = time.monotonic() - start
+        task_logger.info(
+            f"connector_id={connector_id} "
+            f"doc={doc_id} "
+            f"action=update_permissions "
+            f"elapsed={elapsed:.2f}"
         )
-        return True
     except Exception:
-        logger.exception(
-            f"Error Syncing Document Permissions: connector_id={connector_id} doc_id={doc_id}"
+        task_logger.exception(
+            f"Exception in update_external_document_permissions_task: "
+            f"connector_id={connector_id} "
+            f"doc_id={doc_id}"
         )
         return False

+    return True
+
+
+def validate_permission_sync_fences(
+    tenant_id: str | None,
+    r: Redis,
+    r_celery: Redis,
+    lock_beat: RedisLock,
+) -> None:
+    # building lookup table can be expensive, so we won't bother
+    # validating until the queue is small
+    PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN = 1024
+
+    queue_len = celery_get_queue_length(
+        OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
+    )
+    if queue_len > PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN:
+        return
+
+    queued_upsert_tasks = celery_get_queued_task_ids(
+        OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
+    )
+    reserved_generator_tasks = celery_get_unacked_task_ids(
+        OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
+    )
+
+    # validate all existing indexing jobs
+    for key_bytes in r.scan_iter(
+        RedisConnectorPermissionSync.FENCE_PREFIX + "*",
+        count=SCAN_ITER_COUNT_DEFAULT,
+    ):
+        lock_beat.reacquire()
+        validate_permission_sync_fence(
+            tenant_id,
+            key_bytes,
+            queued_upsert_tasks,
+            reserved_generator_tasks,
+            r,
+            r_celery,
+        )
+
+    return
+
+
+def validate_permission_sync_fence(
+    tenant_id: str | None,
+    key_bytes: bytes,
+    queued_tasks: set[str],
+    reserved_tasks: set[str],
+    r: Redis,
+    r_celery: Redis,
+) -> None:
+    """Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
+    This can happen if the indexing worker hard crashes or is terminated.
+    Being in this bad state means the fence will never clear without help, so this function
+    gives the help.
+
+    How this works:
+    1. This function renews the active signal with a 5 minute TTL under the following conditions
+    1.2. When the task is seen in the redis queue
+    1.3. When the task is seen in the reserved / prefetched list
+
+    2. Externally, the active signal is renewed when:
+    2.1. The fence is created
+    2.2. The indexing watchdog checks the spawned task.
+
+    3. The TTL allows us to get through the transitions on fence startup
+    and when the task starts executing.
+
+    More TTL clarification: it is seemingly impossible to exactly query Celery for
+    whether a task is in the queue or currently executing.
+    1. An unknown task id is always returned as state PENDING.
+    2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
+    and the time it actually starts on the worker.
+
+    queued_tasks: the celery queue of lightweight permission sync tasks
+    reserved_tasks: prefetched tasks for sync task generator
+    """
+    # if the fence doesn't exist, there's nothing to do
+    fence_key = key_bytes.decode("utf-8")
+    cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
+    if cc_pair_id_str is None:
+        task_logger.warning(
+            f"validate_permission_sync_fence - could not parse id from {fence_key}"
+        )
+        return
+
+    cc_pair_id = int(cc_pair_id_str)
+    # parse out metadata and initialize the helper class with it
+    redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
+
+    # check to see if the fence/payload exists
+    if not redis_connector.permissions.fenced:
+        return
+
+    # in the cloud, the payload format may have changed ...
+    # it's a little sloppy, but just reset the fence for now if that happens
+    # TODO: add intentional cleanup/abort logic
+    try:
+        payload = redis_connector.permissions.payload
+    except ValidationError:
+        task_logger.exception(
+            "validate_permission_sync_fence - "
+            "Resetting fence because fence schema is out of date: "
+            f"cc_pair={cc_pair_id} "
+            f"fence={fence_key}"
+        )
+
+        redis_connector.permissions.reset()
+        return
+
+    if not payload:
+        return
+
+    if not payload.celery_task_id:
+        return
+
+    # OK, there's actually something for us to validate
+
+    # either the generator task must be in flight or its subtasks must be
+    found = celery_find_task(
+        payload.celery_task_id,
+        OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
+        r_celery,
+    )
+    if found:
+        # the celery task exists in the redis queue
+        redis_connector.permissions.set_active()
+        return
+
+    if payload.celery_task_id in reserved_tasks:
+        # the celery task was prefetched and is reserved within a worker
+        redis_connector.permissions.set_active()
+        return
+
+    # look up every task in the current taskset in the celery queue
+    # every entry in the taskset should have an associated entry in the celery task queue
+    # because we get the celery tasks first, the entries in our own permissions taskset
+    # should be roughly a subset of the tasks in celery
+
+    # this check isn't very exact, but should be sufficient over a period of time
+    # A single successful check over some number of attempts is sufficient.
+
+    # TODO: if the number of tasks in celery is much lower than than the taskset length
+    # we might be able to shortcut the lookup since by definition some of the tasks
+    # must not exist in celery.
+
+    tasks_scanned = 0
+    tasks_not_in_celery = 0  # a non-zero number after completing our check is bad
+
+    for member in r.sscan_iter(redis_connector.permissions.taskset_key):
+        tasks_scanned += 1
+
+        member_bytes = cast(bytes, member)
+        member_str = member_bytes.decode("utf-8")
+        if member_str in queued_tasks:
+            continue
+
+        if member_str in reserved_tasks:
+            continue
+
+        tasks_not_in_celery += 1
+
+    task_logger.info(
+        "validate_permission_sync_fence task check: "
+        f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
+    )
+
+    if tasks_not_in_celery == 0:
+        redis_connector.permissions.set_active()
+        return
+
+    # we may want to enable this check if using the active task list somehow isn't good enough
+    # if redis_connector_index.generator_locked():
+    #     logger.info(f"{payload.celery_task_id} is currently executing.")
+
+    # if we get here, we didn't find any direct indication that the associated celery tasks exist,
+    # but they still might be there due to gaps in our ability to check states during transitions
+    # Checking the active signal safeguards us against these transition periods
+    # (which has a duration that allows us to bridge those gaps)
+    if redis_connector.permissions.active():
+        return
+
+    # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
+    task_logger.warning(
+        "validate_permission_sync_fence - "
+        "Resetting fence because no associated celery tasks were found: "
+        f"cc_pair={cc_pair_id} "
+        f"fence={fence_key}"
+    )
+
+    redis_connector.permissions.reset()
+    return
+
+
+class PermissionSyncCallback(IndexingHeartbeatInterface):
+    PARENT_CHECK_INTERVAL = 60
+
+    def __init__(
+        self,
+        redis_connector: RedisConnector,
+        redis_lock: RedisLock,
+        redis_client: Redis,
+    ):
+        super().__init__()
+        self.redis_connector: RedisConnector = redis_connector
+        self.redis_lock: RedisLock = redis_lock
+        self.redis_client = redis_client
+
+        self.started: datetime = datetime.now(timezone.utc)
+        self.redis_lock.reacquire()
+
+        self.last_tag: str = "PermissionSyncCallback.__init__"
+        self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
+        self.last_lock_monotonic = time.monotonic()
+
+    def should_stop(self) -> bool:
+        if self.redis_connector.stop.fenced:
+            return True
+
+        return False
+
+    def progress(self, tag: str, amount: int) -> None:
+        try:
+            self.redis_connector.permissions.set_active()
+
+            current_time = time.monotonic()
+            if current_time - self.last_lock_monotonic >= (
+                CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
+            ):
+                self.redis_lock.reacquire()
+                self.last_lock_reacquire = datetime.now(timezone.utc)
+                self.last_lock_monotonic = time.monotonic()
+
+            self.last_tag = tag
+        except LockError:
+            logger.exception(
+                f"PermissionSyncCallback - lock.reacquire exceptioned: "
+                f"lock_timeout={self.redis_lock.timeout} "
+                f"start={self.started} "
+                f"last_tag={self.last_tag} "
+                f"last_reacquired={self.last_lock_reacquire} "
+                f"now={datetime.now(timezone.utc)}"
+            )
+
+            redis_lock_dump(self.redis_lock, self.redis_client)
+            raise
+
+
 """Monitoring CCPair permissions utils, called in monitor_vespa_sync"""

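The "active signal" that the validation logic above keeps renewing is, in effect, a Redis key with a TTL that call sites refresh whenever there is evidence the sync is alive. A hedged sketch of that mechanic; the key name is illustrative, not the one RedisConnectorPermissionSync actually uses:

```python
# Hedged sketch of the active-signal mechanic; key naming is illustrative.
from redis import Redis

ACTIVE_TTL_SECS = 300  # the "5 minute TTL" referenced in the docstring above


def set_active(r: Redis, cc_pair_id: int) -> None:
    r.set(f"permissionsync_active_{cc_pair_id}", 1, ex=ACTIVE_TTL_SECS)


def active(r: Redis, cc_pair_id: int) -> bool:
    return bool(r.exists(f"permissionsync_active_{cc_pair_id}"))
```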
@@ -444,20 +752,36 @@ def monitor_ccpair_permissions_taskset(
     if initial is None:
         return

+    try:
+        payload = redis_connector.permissions.payload
+    except ValidationError:
+        task_logger.exception(
+            "Permissions sync payload failed to validate. "
+            "Schema may have been updated."
+        )
+        return
+
+    if not payload:
+        return
+
     remaining = redis_connector.permissions.get_remaining()
     task_logger.info(
-        f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
+        f"Permissions sync progress: "
+        f"cc_pair={cc_pair_id} "
+        f"id={payload.id} "
+        f"remaining={remaining} "
+        f"initial={initial}"
     )
     if remaining > 0:
         return

-    payload: RedisConnectorPermissionSyncPayload | None = (
-        redis_connector.permissions.payload
+    mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), payload.started)
+    task_logger.info(
+        f"Permissions sync finished: "
+        f"cc_pair={cc_pair_id} "
+        f"id={payload.id} "
+        f"num_synced={initial}"
     )
-    start_time: datetime | None = payload.started if payload else None
-
-    mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
-    task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")

     update_sync_record_status(
         db_session=db_session,
@@ -1,3 +1,4 @@
+import time
 from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
@@ -9,6 +10,7 @@ from celery import Task
 from celery.exceptions import SoftTimeLimitExceeded
 from redis import Redis
 from redis.lock import Lock as RedisLock
+from sqlalchemy.orm import Session

 from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
 from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
@@ -20,9 +22,12 @@ from ee.onyx.external_permissions.sync_params import (
     GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
 )
 from onyx.background.celery.apps.app_base import task_logger
+from onyx.background.celery.celery_redis import celery_find_task
+from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
 from onyx.configs.app_configs import JOB_TIMEOUT
 from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
 from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
+from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
 from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
 from onyx.configs.constants import OnyxCeleryPriority
 from onyx.configs.constants import OnyxCeleryQueues
@@ -39,10 +44,12 @@ from onyx.db.models import ConnectorCredentialPair
 from onyx.db.sync_record import insert_sync_record
 from onyx.db.sync_record import update_sync_record_status
 from onyx.redis.redis_connector import RedisConnector
+from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
 from onyx.redis.redis_connector_ext_group_sync import (
     RedisConnectorExternalGroupSyncPayload,
 )
 from onyx.redis.redis_pool import get_redis_client
+from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -102,6 +109,10 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
 def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
     r = get_redis_client(tenant_id=tenant_id)

+    # we need to use celery's redis client to access its redis data
+    # (which lives on a different db number)
+    # r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore
+
     lock_beat: RedisLock = r.lock(
         OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
         timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
@@ -136,6 +147,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
             if _is_external_group_sync_due(cc_pair):
                 cc_pair_ids_to_sync.append(cc_pair.id)

+        lock_beat.reacquire()
         for cc_pair_id in cc_pair_ids_to_sync:
             tasks_created = try_creating_external_group_sync_task(
                 self.app, cc_pair_id, r, tenant_id
@@ -144,6 +156,23 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
                 continue

             task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
+
+        # we want to run this less frequently than the overall task
+        # lock_beat.reacquire()
+        # if not r.exists(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
+        #     # clear any indexing fences that don't have associated celery tasks in progress
+        #     # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
+        #     # or be currently executing
+        #     try:
+        #         validate_external_group_sync_fences(
+        #             tenant_id, self.app, r, r_celery, lock_beat
+        #         )
+        #     except Exception:
+        #         task_logger.exception(
+        #             "Exception while validating external group sync fences"
+        #         )
+
+        #     r.set(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=60)
     except SoftTimeLimitExceeded:
         task_logger.info(
             "Soft time limit exceeded, task is being terminated gracefully."
@@ -186,6 +215,12 @@ def try_creating_external_group_sync_task(
         redis_connector.external_group_sync.generator_clear()
         redis_connector.external_group_sync.taskset_clear()

+        payload = RedisConnectorExternalGroupSyncPayload(
+            submitted=datetime.now(timezone.utc),
+            started=None,
+            celery_task_id=None,
+        )
+
         custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"

         result = app.send_task(
@@ -199,11 +234,6 @@ def try_creating_external_group_sync_task(
             priority=OnyxCeleryPriority.HIGH,
         )

-        payload = RedisConnectorExternalGroupSyncPayload(
-            started=datetime.now(timezone.utc),
-            celery_task_id=result.id,
-        )
-
         # create before setting fence to avoid race condition where the monitoring
         # task updates the sync record before it is created
         with get_session_with_tenant(tenant_id) as db_session:
@@ -213,8 +243,8 @@ def try_creating_external_group_sync_task(
                 sync_type=SyncType.EXTERNAL_GROUP,
             )

+        payload.celery_task_id = result.id
         redis_connector.external_group_sync.set_fence(payload)
-
     except Exception:
         task_logger.exception(
             f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
@@ -241,7 +271,7 @@ def connector_external_group_sync_generator_task(
     tenant_id: str | None,
 ) -> None:
     """
-    Permission sync task that handles external group syncing for a given connector credential pair
+    External group sync task for a given connector credential pair
     This task assumes that the task has already been properly fenced
     """

@@ -249,19 +279,59 @@ def connector_external_group_sync_generator_task(

     r = get_redis_client(tenant_id=tenant_id)

+    # this wait is needed to avoid a race condition where
+    # the primary worker sends the task and it is immediately executed
+    # before the primary worker can finalize the fence
+    start = time.monotonic()
+    while True:
+        if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
+            raise ValueError(
+                f"connector_external_group_sync_generator_task - timed out waiting for fence to be ready: "
+                f"fence={redis_connector.external_group_sync.fence_key}"
+            )
+
+        if not redis_connector.external_group_sync.fenced:  # The fence must exist
+            raise ValueError(
+                f"connector_external_group_sync_generator_task - fence not found: "
+                f"fence={redis_connector.external_group_sync.fence_key}"
+            )
+
+        payload = redis_connector.external_group_sync.payload  # The payload must exist
+        if not payload:
+            raise ValueError(
+                "connector_external_group_sync_generator_task: payload invalid or not found"
+            )
+
+        if payload.celery_task_id is None:
+            logger.info(
+                f"connector_external_group_sync_generator_task - Waiting for fence: "
+                f"fence={redis_connector.external_group_sync.fence_key}"
+            )
+            time.sleep(1)
+            continue
+
+        logger.info(
+            f"connector_external_group_sync_generator_task - Fence found, continuing...: "
+            f"fence={redis_connector.external_group_sync.fence_key}"
+        )
+        break
+
     lock: RedisLock = r.lock(
         OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
         + f"_{redis_connector.id}",
         timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
     )

+    acquired = lock.acquire(blocking=False)
+    if not acquired:
+        task_logger.warning(
+            f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
+        )
+        return None
+
     try:
-        acquired = lock.acquire(blocking=False)
-        if not acquired:
-            task_logger.warning(
-                f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
-            )
-            return None
+        payload.started = datetime.now(timezone.utc)
+        redis_connector.external_group_sync.set_fence(payload)

         with get_session_with_tenant(tenant_id) as db_session:
             cc_pair = get_connector_credential_pair_from_id(
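The busy-wait above reduces to a small reusable pattern: poll the fence until its payload carries a celery task id, bounded by a timeout. A hedged distillation, where sync is any object shaped like redis_connector.external_group_sync:

```python
# Hedged distillation of the wait-for-fence loop; `sync` is assumed to expose
# .fenced, .payload and .fence_key like redis_connector.external_group_sync.
import time


def wait_for_fence_ready(sync, timeout_secs: float) -> None:
    start = time.monotonic()
    while True:
        if time.monotonic() - start > timeout_secs:
            raise ValueError(f"timed out waiting for fence: fence={sync.fence_key}")
        if not sync.fenced:
            raise ValueError(f"fence not found: fence={sync.fence_key}")
        payload = sync.payload
        if not payload:
            raise ValueError("payload invalid or not found")
        if payload.celery_task_id is None:
            time.sleep(1)  # sender has not finalized the fence yet
            continue
        return
```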
@@ -330,3 +400,135 @@ def connector_external_group_sync_generator_task(
|
|||||||
redis_connector.external_group_sync.set_fence(None)
|
redis_connector.external_group_sync.set_fence(None)
|
||||||
if lock.owned():
|
if lock.owned():
|
||||||
lock.release()
|
lock.release()
|
||||||
|
|
||||||
|
|
||||||
|
def validate_external_group_sync_fences(
|
||||||
|
tenant_id: str | None,
|
||||||
|
celery_app: Celery,
|
||||||
|
r: Redis,
|
||||||
|
r_celery: Redis,
|
||||||
|
lock_beat: RedisLock,
|
||||||
|
) -> None:
|
||||||
|
reserved_sync_tasks = celery_get_unacked_task_ids(
|
||||||
|
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
|
||||||
|
)
|
||||||
|
|
||||||
|
# validate all existing indexing jobs
|
||||||
|
for key_bytes in r.scan_iter(
|
||||||
|
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
|
||||||
|
count=SCAN_ITER_COUNT_DEFAULT,
|
||||||
|
):
|
||||||
|
lock_beat.reacquire()
|
||||||
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
|
validate_external_group_sync_fence(
|
||||||
|
tenant_id,
|
||||||
|
key_bytes,
|
||||||
|
reserved_sync_tasks,
|
||||||
|
r_celery,
|
||||||
|
db_session,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
+def validate_external_group_sync_fence(
+    tenant_id: str | None,
+    key_bytes: bytes,
+    reserved_tasks: set[str],
+    r_celery: Redis,
+    db_session: Session,
+) -> None:
+    """Checks for the error condition where an external group sync fence is set but the
+    associated celery tasks don't exist.
+
+    This can happen if the worker hard crashes or is terminated.
+    Being in this bad state means the fence will never clear without help, so this
+    function provides the help.
+
+    How this works:
+    1. This function renews the active signal with a 5 minute TTL under the following conditions
+    1.1. When the task is seen in the redis queue
+    1.2. When the task is seen in the reserved / prefetched list
+
+    2. Externally, the active signal is renewed when:
+    2.1. The fence is created
+    2.2. The indexing watchdog checks the spawned task.
+
+    3. The TTL allows us to get through the transitions on fence startup
+    and when the task starts executing.
+
+    More TTL clarification: it is seemingly impossible to query Celery exactly for
+    whether a task is in the queue or currently executing:
+    1. An unknown task id is always returned as state PENDING.
+    2. Redis can be inspected for the task id, but the task id is gone between the time
+    a worker receives the task and the time it actually starts on the worker.
+    """
+    # if the fence doesn't exist, there's nothing to do
+    fence_key = key_bytes.decode("utf-8")
+    cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
+    if cc_pair_id_str is None:
+        task_logger.warning(
+            f"validate_external_group_sync_fence - could not parse id from {fence_key}"
+        )
+        return
+
+    cc_pair_id = int(cc_pair_id_str)
+
+    # parse out metadata and initialize the helper class with it
+    redis_connector = RedisConnector(tenant_id, cc_pair_id)
+
+    # check to see if the fence/payload exists
+    if not redis_connector.external_group_sync.fenced:
+        return
+
+    payload = redis_connector.external_group_sync.payload
+    if not payload:
+        return
+
+    # OK, there's actually something for us to validate
+
+    if payload.celery_task_id is None:
+        # the fence is just barely set up.
+        # if redis_connector_index.active():
+        #     return
+
+        # it would be odd to get here as there isn't that much that can go wrong during
+        # initial fence setup, but it's still worth making sure we can recover
+        logger.info(
+            "validate_external_group_sync_fence - "
+            f"Resetting fence in basic state without any activity: fence={fence_key}"
+        )
+        redis_connector.external_group_sync.reset()
+        return
+
+    found = celery_find_task(
+        payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
+    )
+    if found:
+        # the celery task exists in the redis queue
+        # redis_connector_index.set_active()
+        return
+
+    if payload.celery_task_id in reserved_tasks:
+        # the celery task was prefetched and is reserved within the worker
+        # redis_connector_index.set_active()
+        return
+
+    # we may want to enable this check if using the active task list somehow isn't good enough
+    # if redis_connector_index.generator_locked():
+    #     logger.info(f"{payload.celery_task_id} is currently executing.")
+
+    # if we get here, we didn't find any direct indication that the associated celery
+    # tasks exist, but they might still be there due to gaps in our ability to check
+    # states during transitions. Checking the active signal (whose TTL is long enough
+    # to bridge those gaps) safeguards us against these transition periods.
+    # if redis_connector_index.active():
+    #     return
+
+    # celery tasks don't exist and the active signal has expired, possibly due to a
+    # crash. Clean it up.
+    logger.warning(
+        "validate_external_group_sync_fence - "
+        "Resetting fence because no associated celery tasks were found: "
+        f"cc_pair={cc_pair_id} "
+        f"fence={fence_key}"
+    )
+
+    redis_connector.external_group_sync.reset()
+    return
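The queue check above goes through celery_find_task, whose implementation sits outside this diff. As a hedged sketch of how such a lookup can work against Celery's Redis broker, where a queue is a Redis list of JSON messages whose headers carry the task id (default task protocol assumed; this is not the onyx implementation):

    import json

    from redis import Redis


    def find_task_sketch(task_id: str, queue_name: str, r: Redis) -> bool:
        # Illustrative only: walk every queued message and compare the task id
        # stored in the message headers.
        for raw in r.lrange(queue_name, 0, -1):
            try:
                message = json.loads(raw)
            except json.JSONDecodeError:
                continue  # not a JSON payload; skip it
            if message.get("headers", {}).get("id") == task_id:
                return True
        # Not queued. The task may still be prefetched or already executing,
        # which is exactly the gap the active signal's TTL bridges.
        return False
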
@@ -39,6 +39,7 @@ from onyx.db.sync_record import insert_sync_record
 from onyx.db.sync_record import update_sync_record_status
 from onyx.redis.redis_connector import RedisConnector
 from onyx.redis.redis_pool import get_redis_client
+from onyx.utils.logger import LoggerContextVars
 from onyx.utils.logger import pruning_ctx
 from onyx.utils.logger import setup_logger
@@ -251,6 +252,8 @@ def connector_pruning_generator_task(
     and compares those IDs to locally stored documents and deletes all locally stored IDs missing
     from the most recently pulled document ID list"""
 
+    LoggerContextVars.reset()
+
     pruning_ctx_dict = pruning_ctx.get()
     pruning_ctx_dict["cc_pair_id"] = cc_pair_id
     pruning_ctx_dict["request_id"] = self.request.id
@@ -399,7 +402,7 @@ def monitor_ccpair_pruning_taskset(
     mark_ccpair_as_pruned(int(cc_pair_id), db_session)
     task_logger.info(
-        f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
+        f"Connector pruning finished: cc_pair={cc_pair_id} num_pruned={initial}"
     )
 
     update_sync_record_status(
@@ -75,6 +75,8 @@ def document_by_cc_pair_cleanup_task(
     """
     task_logger.debug(f"Task start: doc={document_id}")
 
+    start = time.monotonic()
+
     try:
         with get_session_with_tenant(tenant_id) as db_session:
             action = "skip"
@@ -154,11 +156,13 @@ def document_by_cc_pair_cleanup_task(
             db_session.commit()
 
+            elapsed = time.monotonic() - start
             task_logger.info(
                 f"doc={document_id} "
                 f"action={action} "
                 f"refcount={count} "
-                f"chunks={chunks_affected}"
+                f"chunks={chunks_affected} "
+                f"elapsed={elapsed:.2f}"
             )
     except SoftTimeLimitExceeded:
         task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
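A small but deliberate detail above: time.monotonic() rather than time.time(), because the monotonic clock cannot jump when the wall clock is adjusted (NTP, manual changes), so the logged elapsed value stays meaningful. The pattern in isolation, with a stand-in workload:

    import time

    start = time.monotonic()
    sum(range(10_000_000))  # hypothetical stand-in for the real work
    elapsed = time.monotonic() - start  # float seconds, immune to clock adjustments
    print(f"elapsed={elapsed:.2f}")
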
@@ -989,6 +989,10 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
         task_logger.info(
             "Soft time limit exceeded, task is being terminated gracefully."
         )
+        return False
+    except Exception:
+        task_logger.exception("monitor_vespa_sync exceptioned.")
+        return False
     finally:
         if lock_beat.owned():
             lock_beat.release()
@@ -1078,6 +1082,7 @@ def vespa_metadata_sync_task(
         )
     except SoftTimeLimitExceeded:
         task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
+        return False
     except Exception as ex:
         if isinstance(ex, RetryError):
             task_logger.warning(
@@ -300,6 +300,8 @@ class OnyxRedisLocks:
 
 class OnyxRedisSignals:
     VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
+    VALIDATE_EXTERNAL_GROUP_SYNC_FENCES = "signal:validate_external_group_sync_fences"
+    VALIDATE_PERMISSION_SYNC_FENCES = "signal:validate_permission_sync_fences"
 
 
 class OnyxCeleryPriority(int, Enum):
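These signal constants name plain Redis keys used as TTL'd booleans: while the key exists, the relatively expensive validation pass is skipped. A hypothetical consumption sketch (the guard, the helper name, and the 300-second window are assumptions, not part of this diff):

    # Hypothetical guard: run fence validation at most once per TTL window.
    if not r.exists(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
        validate_fences()  # hypothetical: the actual validation pass
        r.set(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=300)
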
@@ -17,6 +17,8 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
 
 
 class RedisConnectorPermissionSyncPayload(BaseModel):
+    id: str
+    submitted: datetime
     started: datetime | None
     celery_task_id: str | None
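The two new fields give each permission sync attempt a traceable identity and a submission timestamp. A hedged construction example with illustrative values:

    from datetime import datetime, timezone

    payload = RedisConnectorPermissionSyncPayload(
        id=make_short_id(),                    # new: short random id for log tracing
        submitted=datetime.now(timezone.utc),  # new: when the sync was requested
        started=None,                          # set once the task actually starts
        celery_task_id=None,                   # set once the task is dispatched
    )
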
@@ -41,6 +43,12 @@ class RedisConnectorPermissionSync:
     TASKSET_PREFIX = f"{PREFIX}_taskset"  # connectorpermissions_taskset
     SUBTASK_PREFIX = f"{PREFIX}+sub"  # connectorpermissions+sub
 
+    # used to signal the overall workflow is still active
+    # it's impossible to get the exact state of the system at a single point in time
+    # so we need a signal with a TTL to bridge gaps in our checks
+    ACTIVE_PREFIX = PREFIX + "_active"
+    ACTIVE_TTL = 3600
+
     def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
         self.tenant_id: str | None = tenant_id
         self.id = id
@@ -54,6 +62,7 @@ class RedisConnectorPermissionSync:
         self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
 
         self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
+        self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
 
     def taskset_clear(self) -> None:
         self.redis.delete(self.taskset_key)
@@ -107,6 +116,20 @@ class RedisConnectorPermissionSync:
 
         self.redis.set(self.fence_key, payload.model_dump_json())
 
+    def set_active(self) -> None:
+        """This sets a signal to keep the permissioning flow from getting cleaned up within
+        the expiration time.
+
+        The slack in timing is needed to avoid races where simply checking
+        the celery queue and task status would give a misleading answer."""
+        self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
+
+    def active(self) -> bool:
+        if self.redis.exists(self.active_key):
+            return True
+
+        return False
+
     @property
     def generator_complete(self) -> int | None:
         """the fence payload is an int representing the starting number of
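Together, set_active and active implement the lease described in the comments: as long as something renews the key within ACTIVE_TTL, cleanup is deferred. A sketch of the intended call pattern (the call sites are assumed; they are not part of this hunk):

    sync = redis_connector.permissions
    sync.set_active()      # renew whenever the workflow shows signs of life
    # ... later, in a validation pass ...
    if not sync.active():  # TTL expired with no renewal: flow presumed dead
        sync.reset()       # safe to clear the fence and associated keys
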
@@ -173,6 +196,7 @@ class RedisConnectorPermissionSync:
         return len(async_results)
 
     def reset(self) -> None:
+        self.redis.delete(self.active_key)
         self.redis.delete(self.generator_progress_key)
         self.redis.delete(self.generator_complete_key)
         self.redis.delete(self.taskset_key)
@@ -187,6 +211,9 @@ class RedisConnectorPermissionSync:
     @staticmethod
     def reset_all(r: redis.Redis) -> None:
         """Deletes all redis values for all connectors"""
+        for key in r.scan_iter(RedisConnectorPermissionSync.ACTIVE_PREFIX + "*"):
+            r.delete(key)
+
         for key in r.scan_iter(RedisConnectorPermissionSync.TASKSET_PREFIX + "*"):
             r.delete(key)
@@ -11,6 +11,7 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
 
 
 class RedisConnectorExternalGroupSyncPayload(BaseModel):
+    submitted: datetime
     started: datetime | None
     celery_task_id: str | None
@@ -135,6 +136,12 @@ class RedisConnectorExternalGroupSync:
     ) -> int | None:
         pass
 
+    def reset(self) -> None:
+        self.redis.delete(self.generator_progress_key)
+        self.redis.delete(self.generator_complete_key)
+        self.redis.delete(self.taskset_key)
+        self.redis.delete(self.fence_key)
+
     @staticmethod
     def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
         taskset_key = f"{RedisConnectorExternalGroupSync.TASKSET_PREFIX}_{id}"
@@ -33,8 +33,8 @@ class RedisConnectorIndex:
     TERMINATE_TTL = 600
 
     # used to signal the overall workflow is still active
-    # there are gaps in time between states where we need some slack
-    # to correctly transition
+    # it's impossible to get the exact state of the system at a single point in time
+    # so we need a signal with a TTL to bridge gaps in our checks
     ACTIVE_PREFIX = PREFIX + "_active"
     ACTIVE_TTL = 3600
@@ -122,7 +122,7 @@ class TenantRedis(redis.Redis):
             "ttl",
         ]  # Regular methods that need simple prefixing
 
-        if item == "scan_iter":
+        if item == "scan_iter" or item == "sscan_iter":
            return self._prefix_scan_iter(original_attr)
         elif item in methods_to_wrap and callable(original_attr):
             return self._prefix_method(original_attr)
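sscan_iter (cursor-based iteration over a set's members) now receives the same tenant-prefix treatment as scan_iter, presumably because the new validation paths iterate Redis sets as well as the plain keyspace. A simplified sketch of what such a prefixing wrapper does; the real _prefix_scan_iter and the tenant separator are not shown in this diff, so the names here are assumptions:

    def _prefix_scan_iter_sketch(self, original_fn):
        # Prepend the tenant id to the match pattern so a tenant only
        # ever iterates keys inside its own namespace.
        def wrapper(match: str = "*", **kwargs):
            return original_fn(match=f"{self.tenant_id}:{match}", **kwargs)

        return wrapper
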
@@ -422,27 +422,29 @@ def sync_cc_pair(
     if redis_connector.permissions.fenced:
         raise HTTPException(
             status_code=HTTPStatus.CONFLICT,
-            detail="Doc permissions sync task already in progress.",
+            detail="Permissions sync task already in progress.",
         )
 
     logger.info(
-        f"Doc permissions sync cc_pair={cc_pair_id} "
+        f"Permissions sync cc_pair={cc_pair_id} "
         f"connector_id={cc_pair.connector_id} "
         f"credential_id={cc_pair.credential_id} "
         f"{cc_pair.connector.name} connector."
     )
-    tasks_created = try_creating_permissions_sync_task(
+    payload_id = try_creating_permissions_sync_task(
         primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
     )
-    if not tasks_created:
+    if not payload_id:
         raise HTTPException(
             status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
-            detail="Doc permissions sync task creation failed.",
+            detail="Permissions sync task creation failed.",
         )
 
+    logger.info(f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}")
+
     return StatusResponse(
         success=True,
-        message="Successfully created the doc permissions sync task.",
+        message="Successfully created the permissions sync task.",
     )
@@ -1,4 +1,6 @@
+import base64
 import json
+import os
 from datetime import datetime
 from typing import Any
@@ -66,3 +68,10 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
         )
 
     return masked_creds
+
+
+def make_short_id() -> str:
+    """Fast way to generate a random 8 character id ... useful for tagging data
+    to trace it through a flow. This is definitely not guaranteed to be unique and is
+    targeted at the stated use case."""
+    return base64.b32encode(os.urandom(5)).decode("utf-8")[:8]  # 5 bytes → 8 chars
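The arithmetic in the trailing comment checks out: 5 random bytes is 40 bits, and base32 emits one character per 5 bits, so b32encode yields exactly 8 characters with no '=' padding, making the [:8] slice purely defensive. With only 40 bits of entropy collisions are possible, which is why the docstring restricts this to tagging data for tracing. Example call (output is random; the value shown is only for shape):

    print(make_short_id())  # e.g. "Q3KZ7A2B", 8 chars from the A-Z / 2-7 alphabet
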
@@ -26,6 +26,13 @@ doc_permission_sync_ctx: contextvars.ContextVar[
 ] = contextvars.ContextVar("doc_permission_sync_ctx", default=dict())
 
 
+class LoggerContextVars:
+    @staticmethod
+    def reset() -> None:
+        pruning_ctx.set(dict())
+        doc_permission_sync_ctx.set(dict())
+
+
 class TaskAttemptSingleton:
     """Used to tell if this process is an indexing job, and if so what is the
     unique identifier for this indexing attempt. For things like the API server,
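The reason reset exists: Celery worker processes are reused across tasks, so context variables set by one task would otherwise leak into the log lines of the next. Calling reset at task entry, as connector_pruning_generator_task now does, restores a clean slate. A minimal usage sketch (the tag value is illustrative):

    LoggerContextVars.reset()  # drop any context inherited from a prior task
    pruning_ctx.get()["request_id"] = make_short_id()  # tag this run's log lines
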
@@ -70,27 +77,32 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
     ) -> tuple[str, MutableMapping[str, Any]]:
         # If this is an indexing job, add the attempt ID to the log message
         # This helps filter the logs for this specific indexing
-        index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
-        cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()
-
-        doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
-        pruning_ctx_dict = pruning_ctx.get()
-        if len(pruning_ctx_dict) > 0:
-            if "request_id" in pruning_ctx_dict:
-                msg = f"[Prune: {pruning_ctx_dict['request_id']}] {msg}"
-
-            if "cc_pair_id" in pruning_ctx_dict:
-                msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
-        elif len(doc_permission_sync_ctx_dict) > 0:
-            if "request_id" in doc_permission_sync_ctx_dict:
-                msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
-        else:
+        while True:
+            pruning_ctx_dict = pruning_ctx.get()
+            if len(pruning_ctx_dict) > 0:
+                if "request_id" in pruning_ctx_dict:
+                    msg = f"[Prune: {pruning_ctx_dict['request_id']}] {msg}"
+
+                if "cc_pair_id" in pruning_ctx_dict:
+                    msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
+                break
+
+            doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
+            if len(doc_permission_sync_ctx_dict) > 0:
+                if "request_id" in doc_permission_sync_ctx_dict:
+                    msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
+                break
+
+            index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
+            cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()
+
             if index_attempt_id is not None:
                 msg = f"[Index Attempt: {index_attempt_id}] {msg}"
 
             if cc_pair_id is not None:
                 msg = f"[CC Pair: {cc_pair_id}] {msg}"
 
+            break
         # Add tenant information if it differs from default
         # This will always be the case for authenticated API requests
         if MULTI_TENANT:
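A note on the restructuring above: the while True: block acts as a labeled block with early exit, so the first matching context tags the message and breaks out, and the former else: arm no longer has to pre-compute every context up front. The bare shape of the pattern, with hypothetical predicates and helpers:

    while True:
        if in_prune_context():  # hypothetical predicate
            tag_prune()
            break
        if in_permission_sync_context():  # hypothetical predicate
            tag_permission_sync()
            break
        tag_defaults()  # fallthrough default
        break  # the body always runs exactly once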