From 478dd1c4bb5f37bfd48b5f4a2047c9cb2c2b0f1a Mon Sep 17 00:00:00 2001
From: pablodanswer
Date: Thu, 26 Sep 2024 13:14:48 -0700
Subject: [PATCH] functional multi tenant connector deletion

---
 .../danswer/background/celery/celery_app.py   | 120 ++++++++++++------
 .../danswer/background/connector_deletion.py  |   6 +-
 backend/danswer/background/task_utils.py      |   4 +-
 3 files changed, 90 insertions(+), 40 deletions(-)

diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py
index ffd805c2986..d1a9324b49f 100644
--- a/backend/danswer/background/celery/celery_app.py
+++ b/backend/danswer/background/celery/celery_app.py
@@ -67,11 +67,12 @@ _SYNC_BATCH_SIZE = 100
 def cleanup_connector_credential_pair_task(
     connector_id: int,
     credential_id: int,
+    tenant_id: str
 ) -> int:
     """Connector deletion task. This is run as an async task because it is a somewhat
     slow job. Needs to potentially update a large number of Postgres and Vespa docs,
     including deleting them or updating the ACL"""
-    engine = get_sqlalchemy_engine()
+    engine = get_sqlalchemy_engine(schema=tenant_id)
     with Session(engine) as db_session:
         # validate that the connector / credential pair is deletable
         cc_pair = get_connector_credential_pair(
@@ -101,6 +102,7 @@ def cleanup_connector_credential_pair_task(
                 db_session=db_session,
                 document_index=document_index,
                 cc_pair=cc_pair,
+                tenant_id=tenant_id,
             )
         except Exception as e:
             logger.exception(f"Failed to run connector_deletion due to {e}")
@@ -109,7 +111,7 @@
 
 @build_celery_task_wrapper(name_cc_prune_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
-def prune_documents_task(connector_id: int, credential_id: int) -> None:
+def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str) -> None:
     """connector pruning task. For a cc pair, this task pulls all document IDs from the source
     and compares those IDs to locally stored documents and deletes all locally stored IDs missing
     from the most recently pulled document ID list"""
@@ -167,6 +169,7 @@
                 connector_id=connector_id,
                 credential_id=credential_id,
                 document_index=document_index,
+                tenant_id=tenant_id,
             )
         except Exception as e:
             logger.exception(
@@ -177,7 +180,7 @@
 
 @build_celery_task_wrapper(name_document_set_sync_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
-def sync_document_set_task(document_set_id: int) -> None:
+def sync_document_set_task(document_set_id: int, tenant_id: str) -> None:
     """For document sets marked as not up to date, sync the state from postgres
     into the datastore. Also handles deletions."""
 
@@ -210,7 +213,7 @@ def sync_document_set_task(document_set_id: int) -> None:
         ]
         document_index.update(update_requests=update_requests)
 
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
         try:
             cursor = None
             while True:
@@ -261,10 +264,10 @@
     name="check_for_document_sets_sync_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_for_document_sets_sync_task() -> None:
+def check_for_document_sets_sync_task(tenant_id: str) -> None:
     """Runs periodically to check if any sync tasks should be run and adds them
     to the queue"""
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
         # check if any document sets are not synced
         document_set_info = fetch_document_sets(
             user_id=None, db_session=db_session, include_outdated=True
@@ -281,9 +284,10 @@
     name="check_for_cc_pair_deletion_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_for_cc_pair_deletion_task() -> None:
+def check_for_cc_pair_deletion_task(tenant_id: str) -> None:
+    # scheduled per-tenant by schedule_tenant_tasks() below
     """Runs periodically to check if any deletion tasks should be run"""
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
         # check if any document sets are not synced
         cc_pairs = get_connector_credential_pairs(db_session)
         for cc_pair in cc_pairs:
@@ -293,6 +297,7 @@
                 kwargs=dict(
                     connector_id=cc_pair.connector.id,
                     credential_id=cc_pair.credential.id,
+                    tenant_id=tenant_id
                 ),
             )
 
@@ -303,7 +308,7 @@
     bind=True,
     base=AbortableTask,
 )
-def kombu_message_cleanup_task(self: Any) -> int:
+def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int:
     """Runs periodically to clean up the kombu_message table"""
 
     # we will select messages older than this amount to clean up
@@ -315,7 +320,7 @@
     ctx["deleted"] = 0
     ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
     ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
         # Exit the task if we can't take the advisory lock
         result = db_session.execute(
             text("SELECT pg_try_advisory_lock(:id)"),
@@ -416,11 +421,11 @@ def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
     name="check_for_prune_task",
     soft_time_limit=JOB_TIMEOUT,
 )
-def check_for_prune_task() -> None:
+def check_for_prune_task(tenant_id: str) -> None:
     """Runs periodically to check if any prune tasks should be run and adds them
     to the queue"""
 
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
         all_cc_pairs = get_connector_credential_pairs(db_session)
 
         for cc_pair in all_cc_pairs:
@@ -435,38 +440,79 @@ def check_for_prune_task() -> None:
                 kwargs=dict(
                     connector_id=cc_pair.connector.id,
                     credential_id=cc_pair.credential.id,
-                )
+                    tenant_id=tenant_id,
+                )
             )
+from danswer.configs.app_configs import MULTI_TENANT
+from danswer.background.update import get_all_tenant_ids
 
 
 #####
 # Celery Beat (Periodic Tasks) Settings
 #####
-celery_app.conf.beat_schedule = {
-    "check-for-document-set-sync": {
-        "task": "check_for_document_sets_sync_task",
-        "schedule": timedelta(seconds=5),
-    },
-    "check-for-cc-pair-deletion": {
-        "task": "check_for_cc_pair_deletion_task",
-        # don't need to check too often, since we kick off a deletion initially
-        # during the API call that actually marks the CC pair for deletion
-        "schedule": timedelta(minutes=1),
-    },
-}
-celery_app.conf.beat_schedule.update(
-    {
-        "check-for-prune": {
+
+def schedule_tenant_tasks() -> None:
+    if MULTI_TENANT:
+        tenants = get_all_tenant_ids()
+    else:
+        tenants = ['public']  # Default tenant in single-tenancy mode
+
+    # Filter out any invalid tenants if necessary
+    valid_tenants = [tenant for tenant in tenants if not tenant.startswith('pg_')]
+    logger.info(f"Scheduling tasks for tenants: {valid_tenants}")
+
+    for tenant_id in valid_tenants:
+        logger.debug(f"Scheduling tasks for tenant: {tenant_id}")
+        # Schedule tasks specific to each tenant
+        celery_app.conf.beat_schedule[f"check-for-document-set-sync-{tenant_id}"] = {
+            "task": "check_for_document_sets_sync_task",
+            "schedule": timedelta(seconds=5),
+            "args": (tenant_id,),
+        }
+        celery_app.conf.beat_schedule[f"check-for-cc-pair-deletion-{tenant_id}"] = {
+            "task": "check_for_cc_pair_deletion_task",
+            "schedule": timedelta(seconds=5),
+            "args": (tenant_id,),
+        }
+        celery_app.conf.beat_schedule[f"check-for-prune-{tenant_id}"] = {
             "task": "check_for_prune_task",
             "schedule": timedelta(seconds=5),
-        },
-    }
-)
-celery_app.conf.beat_schedule.update(
-    {
-        "kombu-message-cleanup": {
+            "args": (tenant_id,),
+        }
+
+        # Schedule tasks that are not tenant-specific
+        celery_app.conf.beat_schedule["kombu-message-cleanup"] = {
             "task": "kombu_message_cleanup_task",
             "schedule": timedelta(seconds=3600),
-        },
-    }
-)
+            "args": (tenant_id,),
+        }
+
+schedule_tenant_tasks()
+# celery_app.conf.beat_schedule = {
+#     "check-for-document-set-sync": {
+#         "task": "check_for_document_sets_sync_task",
+#         "schedule": timedelta(seconds=5),
+#     },
+#     "check-for-cc-pair-deletion": {
+#         "task": "check_for_cc_pair_deletion_task",
+#         # don't need to check too often, since we kick off a deletion initially
+#         # during the API call that actually marks the CC pair for deletion
+#         "schedule": timedelta(minutes=1),
+#     },
+# }
+# celery_app.conf.beat_schedule.update(
+#     {
+#         "check-for-prune": {
+#             "task": "check_for_prune_task",
+#             "schedule": timedelta(seconds=5),
+#         },
+#     }
+# )
+# celery_app.conf.beat_schedule.update(
+#     {
+#         "kombu-message-cleanup": {
+#             "task": "kombu_message_cleanup_task",
+#             "schedule": timedelta(seconds=3600),
+#         },
+#     }
+# )
diff --git a/backend/danswer/background/connector_deletion.py b/backend/danswer/background/connector_deletion.py
index 90883564910..c0f4b3911ed 100644
--- a/backend/danswer/background/connector_deletion.py
+++ b/backend/danswer/background/connector_deletion.py
@@ -35,6 +35,7 @@ from danswer.utils.variable_functionality import (
     fetch_versioned_implementation_with_fallback,
 )
 from danswer.utils.variable_functionality import noop_fallback
+from danswer.configs.app_configs import DEFAULT_SCHEMA
 
 logger = setup_logger()
 
@@ -46,12 +47,13 @@ def delete_connector_credential_pair_batch(
     connector_id: int,
     credential_id: int,
     document_index: DocumentIndex,
+    tenant_id: str = DEFAULT_SCHEMA,
 ) -> None:
     """
     Removes a batch of documents ids from a cc-pair. If no other cc-pair uses a
     document anymore it gets permanently deleted.
     """
-    with Session(get_sqlalchemy_engine()) as db_session:
+    with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
         # acquire lock for all documents in this batch so that indexing can't
         # override the deletion
         with prepare_to_modify_documents(
@@ -124,6 +126,7 @@ def delete_connector_credential_pair(
     db_session: Session,
     document_index: DocumentIndex,
     cc_pair: ConnectorCredentialPair,
+    tenant_id: str = DEFAULT_SCHEMA,
 ) -> int:
     connector_id = cc_pair.connector_id
     credential_id = cc_pair.credential_id
@@ -144,6 +147,7 @@
             connector_id=connector_id,
             credential_id=credential_id,
             document_index=document_index,
+            tenant_id=tenant_id,
         )
         num_docs_deleted += len(documents)
 
diff --git a/backend/danswer/background/task_utils.py b/backend/danswer/background/task_utils.py
index 6e122678813..10f7f95d043 100644
--- a/backend/danswer/background/task_utils.py
+++ b/backend/danswer/background/task_utils.py
@@ -14,8 +14,8 @@ from danswer.db.tasks import mark_task_start
 from danswer.db.tasks import register_task
 
 
-def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
-    return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
+def name_cc_cleanup_task(connector_id: int, credential_id: int, tenant_id: str = "") -> str:
+    return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}_{tenant_id}"
 
 
 def name_document_set_sync_task(document_set_id: int) -> str: