From fbf51b70d0d004b7030a57865ee8c69c753f7c1d Mon Sep 17 00:00:00 2001
From: rkuo-danswer
Date: Thu, 26 Sep 2024 17:50:55 -0700
Subject: [PATCH] Feature/celery multi (#2470)

* first cut at redis
* some new helper functions for the db
* ignore kombu tables in alembic migrations (used by celery)
* multiline commands for readability, add vespa_metadata_sync queue to worker
* typo fix
* fix returning tuple fields
* add constants
* fix _get_access_for_document
* docstrings!
* fix double function declaration and typing
* fix type hinting
* add a global redis pool
* Add get_document function
* use task_logger in various celery tasks
* add celeryconfig.py to simplify configuration. Will be used in a subsequent commit
* Add celery redis helper. used in a subsequent PR
* kombu warning getting spammy since celery is not self managing its queue in Postgres any more
* add last_modified and last_synced to documents
* fix task naming convention
* use celeryconfig.py
* the big one. adds queues and tasks, updates functions to use the queues with priorities, etc
* change vespa index log line to debug
* mypy fixes
* update alembic migration
* fix fence ordering, rename to "monitor", fix fetch_versioned_implementation call
* mypy
* switch to monotonic time
* fix startup dependencies on redis
* rebase alembic migration
* kombu cleanup - fail silently
* mypy
* add redis_host environment override
* update REDIS_HOST env var in docker-compose.dev.yml
* update the rest of the docker files
* in flight
* harden indexing-status endpoint against db changes happening in the background. Needs further improvement but OK for now.
* allow no task syncs to run because we create certain objects with no entries but initially marked as out of date
* add back writing to vespa on indexing
* actually working connector deletion
* update contributing guide
* backporting fixes from background_deletion
* renaming cache to cache_volume
* add redis password to various deployments
* try setting up pr testing for helm
* fix indent
* hopefully this release version actually exists
* fix command line option to --chart-dirs
* fetch-depth 0
* edit values.yaml
* try setting ct working directory
* bypass testing only on change for now
* move files and lint them
* update helm testing
* some issues suggest using --config works
* add vespa repo
* add postgresql repo
* increase timeout
* try amd64 runner
* fix redis password reference
* add comment to helm chart testing workflow
* rename helm testing workflow to disable it
* adding clarifying comments
* address code review
* missed a file
* remove commented warning ... just not needed
* fix imports
* refactor to use update_single
* mypy fixes
* add vespa test
* multiple celery workers
* update logs as well and set prefetch multipliers appropriate to the worker intent
* add db refresh to connector deletion
* add some preliminary locking
* organize tasks into separate files
* celery auto associates tasks created inside another task, which bloats the result metadata considerably. trail=False prevents this.
* code review fixes
* move monitor_usergroup_taskset to ee, improve logging
* add multi workers to dev_run_background_jobs.py
* update supervisord with some recommended settings for celery
* name celery workers and shorten dev script prefixing
* add configurable sql alchemy engine settings on startup (needed for various intents like API server, different celery workers and tasks, etc)
* fix comments
* autoscale sqlalchemy pool size to celery concurrency (allow override later?)
* supervisord needs the percent symbols escaped * use name as primary check, some minor refactoring and type hinting too. * addressing code review * fix import * fix prune_documents_task references --------- Co-authored-by: Richard Kuo --- .../danswer/background/celery/celery_app.py | 1181 +++-------------- .../danswer/background/celery/celery_utils.py | 28 + .../celery/tasks/connector_deletion/tasks.py | 133 ++ .../background/celery/tasks/periodic/tasks.py | 140 ++ .../background/celery/tasks/pruning/tasks.py | 120 ++ .../background/celery/tasks/vespa/tasks.py | 526 ++++++++ .../danswer/background/connector_deletion.py | 101 ++ backend/danswer/background/update.py | 2 + backend/danswer/configs/constants.py | 6 +- backend/danswer/db/engine.py | 89 +- backend/danswer/server/documents/cc_pair.py | 2 +- .../danswer/server/manage/administrative.py | 12 +- .../background/celery/tasks/vespa/tasks.py | 52 + backend/ee/danswer/background/celery_utils.py | 49 - backend/scripts/dev_run_background_jobs.py | 68 +- backend/supervisord.conf | 54 +- 16 files changed, 1501 insertions(+), 1062 deletions(-) create mode 100644 backend/danswer/background/celery/tasks/connector_deletion/tasks.py create mode 100644 backend/danswer/background/celery/tasks/periodic/tasks.py create mode 100644 backend/danswer/background/celery/tasks/pruning/tasks.py create mode 100644 backend/danswer/background/celery/tasks/vespa/tasks.py create mode 100644 backend/ee/danswer/background/celery/tasks/vespa/tasks.py diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py index aedc3fec4..0440f275c 100644 --- a/backend/danswer/background/celery/celery_app.py +++ b/backend/danswer/background/celery/celery_app.py @@ -1,91 +1,39 @@ -import json import logging -import traceback +import time from datetime import timedelta from typing import Any -from typing import cast import redis +from celery import bootsteps # type: ignore from celery import Celery from celery import current_task from celery import signals from celery import Task -from celery.contrib.abortable import AbortableTask # type: ignore -from celery.exceptions import SoftTimeLimitExceeded -from celery.exceptions import TaskRevokedError +from celery.exceptions import WorkerShutdown from celery.signals import beat_init from celery.signals import worker_init +from celery.signals import worker_ready +from celery.signals import worker_shutdown from celery.states import READY_STATES from celery.utils.log import get_task_logger -from redis import Redis -from sqlalchemy import inspect -from sqlalchemy import text -from sqlalchemy.orm import Session -from sqlalchemy.orm.exc import ObjectDeletedError -from danswer.access.access import get_access_for_document from danswer.background.celery.celery_redis import RedisConnectorCredentialPair from danswer.background.celery.celery_redis import RedisConnectorDeletion from danswer.background.celery.celery_redis import RedisDocumentSet from danswer.background.celery.celery_redis import RedisUserGroup -from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector -from danswer.background.celery.celery_utils import should_prune_cc_pair -from danswer.background.connector_deletion import delete_connector_credential_pair_batch -from danswer.background.task_utils import build_celery_task_wrapper -from danswer.background.task_utils import name_cc_prune_task -from danswer.configs.app_configs import JOB_TIMEOUT -from danswer.configs.constants import 
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT +from danswer.background.celery.celery_utils import celery_is_worker_primary +from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DanswerRedisLocks from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME -from danswer.configs.constants import POSTGRES_CELERY_WORKER_APP_NAME -from danswer.configs.constants import PostgresAdvisoryLocks -from danswer.connectors.factory import instantiate_connector -from danswer.connectors.models import InputType -from danswer.db.connector import fetch_connector_by_id -from danswer.db.connector_credential_pair import add_deletion_failure_message -from danswer.db.connector_credential_pair import ( - delete_connector_credential_pair__no_commit, -) -from danswer.db.connector_credential_pair import get_connector_credential_pair -from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id -from danswer.db.connector_credential_pair import get_connector_credential_pairs -from danswer.db.document import count_documents_by_needs_sync -from danswer.db.document import delete_document_by_connector_credential_pair__no_commit -from danswer.db.document import delete_documents_complete__no_commit -from danswer.db.document import get_document -from danswer.db.document import get_document_connector_count -from danswer.db.document import get_documents_for_connector_credential_pair -from danswer.db.document import mark_document_as_synced -from danswer.db.document_set import delete_document_set -from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit -from danswer.db.document_set import fetch_document_sets -from danswer.db.document_set import fetch_document_sets_for_document -from danswer.db.document_set import get_document_set_by_id -from danswer.db.document_set import mark_document_set_as_synced -from danswer.db.engine import get_sqlalchemy_engine -from danswer.db.engine import init_sqlalchemy_engine -from danswer.db.enums import ConnectorCredentialPairStatus -from danswer.db.enums import IndexingStatus -from danswer.db.index_attempt import delete_index_attempts -from danswer.db.index_attempt import get_last_attempt -from danswer.db.models import ConnectorCredentialPair -from danswer.db.models import DocumentSet -from danswer.db.models import UserGroup -from danswer.db.search_settings import get_current_search_settings -from danswer.document_index.document_index_utils import get_both_index_names -from danswer.document_index.factory import get_default_document_index -from danswer.document_index.interfaces import UpdateRequest +from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME +from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME +from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME +from danswer.db.engine import SqlEngine from danswer.redis.redis_pool import RedisPool -from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.utils.logger import ColoredFormatter from danswer.utils.logger import PlainFormatter from danswer.utils.logger import setup_logger -from danswer.utils.variable_functionality import fetch_versioned_implementation -from danswer.utils.variable_functionality import ( - fetch_versioned_implementation_with_fallback, -) -from danswer.utils.variable_functionality import noop_fallback logger = setup_logger() @@ -100,692 +48,6 @@ 
celery_app.config_from_object( ) # Load configuration from 'celeryconfig.py' -##### -# Tasks that need to be run in job queue, registered via APIs -# -# If imports from this module are needed, use local imports to avoid circular importing -##### - - -@build_celery_task_wrapper(name_cc_prune_task) -@celery_app.task(soft_time_limit=JOB_TIMEOUT) -def prune_documents_task(connector_id: int, credential_id: int) -> None: - """connector pruning task. For a cc pair, this task pulls all document IDs from the source - and compares those IDs to locally stored documents and deletes all locally stored IDs missing - from the most recently pulled document ID list""" - with Session(get_sqlalchemy_engine()) as db_session: - try: - cc_pair = get_connector_credential_pair( - db_session=db_session, - connector_id=connector_id, - credential_id=credential_id, - ) - - if not cc_pair: - task_logger.warning( - f"ccpair not found for {connector_id} {credential_id}" - ) - return - - runnable_connector = instantiate_connector( - db_session=db_session, - source=cc_pair.connector.source, - input_type=InputType.PRUNE, - connector_specific_config=cc_pair.connector.connector_specific_config, - credential=cc_pair.credential, - ) - - all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector( - runnable_connector - ) - - all_indexed_document_ids = { - doc.id - for doc in get_documents_for_connector_credential_pair( - db_session=db_session, - connector_id=connector_id, - credential_id=credential_id, - ) - } - - doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids) - - curr_ind_name, sec_ind_name = get_both_index_names(db_session) - document_index = get_default_document_index( - primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name - ) - - if len(doc_ids_to_remove) == 0: - task_logger.info( - f"No docs to prune from {cc_pair.connector.source} connector" - ) - return - - task_logger.info( - f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector" - ) - delete_connector_credential_pair_batch( - document_ids=doc_ids_to_remove, - connector_id=connector_id, - credential_id=credential_id, - document_index=document_index, - ) - except Exception as e: - task_logger.exception( - f"Failed to run pruning for connector id {connector_id}." - ) - raise e - - -def try_generate_stale_document_sync_tasks( - db_session: Session, r: Redis, lock_beat: redis.lock.Lock -) -> int | None: - """This picks up stale documents (typically from indexing) and queues them for sync to Vespa. - - Returns an int if syncing is needed. The int represents the number of sync tasks generated. - Returns None if no syncing is required. - """ - # the fence is up, do nothing - if r.exists(RedisConnectorCredentialPair.get_fence_key()): - return None - - r.delete(RedisConnectorCredentialPair.get_taskset_key()) # delete the taskset - - # add tasks to celery and build up the task set to monitor in redis - stale_doc_count = count_documents_by_needs_sync(db_session) - if stale_doc_count == 0: - return None - - task_logger.info( - f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair." - ) - - task_logger.info("RedisConnector.generate_tasks starting by cc_pair.") - - # rkuo: we could technically sync all stale docs in one big pass. 
- # but I feel it's more understandable to group the docs by cc_pair - total_tasks_generated = 0 - cc_pairs = get_connector_credential_pairs(db_session) - for cc_pair in cc_pairs: - rc = RedisConnectorCredentialPair(cc_pair.id) - tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat) - - if tasks_generated is None: - continue - - if tasks_generated == 0: - continue - - task_logger.info( - f"RedisConnector.generate_tasks finished. " - f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}" - ) - - total_tasks_generated += tasks_generated - - task_logger.info( - f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}" - ) - - r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated) - return total_tasks_generated - - -def try_generate_document_set_sync_tasks( - document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock -) -> int | None: - """Returns an int if syncing is needed. The int represents the number of sync tasks generated. - Note that syncing can still be required even if the number of sync tasks generated is zero. - Returns None if no syncing is required. - """ - lock_beat.reacquire() - - rds = RedisDocumentSet(document_set.id) - - # don't generate document set sync tasks if tasks are still pending - if r.exists(rds.fence_key): - return None - - # don't generate sync tasks if we're up to date - # race condition with the monitor/cleanup function if we use a cached result! - db_session.refresh(document_set) - if document_set.is_up_to_date: - return None - - # add tasks to celery and build up the task set to monitor in redis - r.delete(rds.taskset_key) - - task_logger.info( - f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}" - ) - - # Add all documents that need to be updated into the queue - tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat) - if tasks_generated is None: - return None - - # Currently we are allowing the sync to proceed with 0 tasks. - # It's possible for sets/groups to be generated initially with no entries - # and they still need to be marked as up to date. - # if tasks_generated == 0: - # return 0 - - task_logger.info( - f"RedisDocumentSet.generate_tasks finished. " - f"document_set_id={document_set.id} tasks_generated={tasks_generated}" - ) - - # set this only after all tasks have been added - r.set(rds.fence_key, tasks_generated) - return tasks_generated - - -def try_generate_user_group_sync_tasks( - usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock -) -> int | None: - """Returns an int if syncing is needed. The int represents the number of sync tasks generated. - Note that syncing can still be required even if the number of sync tasks generated is zero. - Returns None if no syncing is required. - """ - lock_beat.reacquire() - - rug = RedisUserGroup(usergroup.id) - - # don't generate sync tasks if tasks are still pending - if r.exists(rug.fence_key): - return None - - # race condition with the monitor/cleanup function if we use a cached result! - db_session.refresh(usergroup) - if usergroup.is_up_to_date: - return None - - # add tasks to celery and build up the task set to monitor in redis - r.delete(rug.taskset_key) - - # Add all documents that need to be updated into the queue - task_logger.info( - f"RedisUserGroup.generate_tasks starting. 
usergroup_id={usergroup.id}" - ) - tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat) - if tasks_generated is None: - return None - - # Currently we are allowing the sync to proceed with 0 tasks. - # It's possible for sets/groups to be generated initially with no entries - # and they still need to be marked as up to date. - # if tasks_generated == 0: - # return 0 - - task_logger.info( - f"RedisUserGroup.generate_tasks finished. " - f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}" - ) - - # set this only after all tasks have been added - r.set(rug.fence_key, tasks_generated) - return tasks_generated - - -def try_generate_document_cc_pair_cleanup_tasks( - cc_pair: ConnectorCredentialPair, - db_session: Session, - r: Redis, - lock_beat: redis.lock.Lock, -) -> int | None: - """Returns an int if syncing is needed. The int represents the number of sync tasks generated. - Note that syncing can still be required even if the number of sync tasks generated is zero. - Returns None if no syncing is required. - """ - - lock_beat.reacquire() - - rcd = RedisConnectorDeletion(cc_pair.id) - - # don't generate sync tasks if tasks are still pending - if r.exists(rcd.fence_key): - return None - - # we need to refresh the state of the object inside the fence - # to avoid a race condition with db.commit/fence deletion - # at the end of this taskset - try: - db_session.refresh(cc_pair) - except ObjectDeletedError: - return None - - if cc_pair.status != ConnectorCredentialPairStatus.DELETING: - return None - - search_settings = get_current_search_settings(db_session) - - last_indexing = get_last_attempt( - connector_id=cc_pair.connector_id, - credential_id=cc_pair.credential_id, - search_settings_id=search_settings.id, - db_session=db_session, - ) - if last_indexing: - if ( - last_indexing.status == IndexingStatus.IN_PROGRESS - or last_indexing.status == IndexingStatus.NOT_STARTED - ): - return None - - # add tasks to celery and build up the task set to monitor in redis - r.delete(rcd.taskset_key) - - # Add all documents that need to be updated into the queue - task_logger.info( - f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}" - ) - tasks_generated = rcd.generate_tasks(celery_app, db_session, r, lock_beat) - if tasks_generated is None: - return None - - # Currently we are allowing the sync to proceed with 0 tasks. - # It's possible for sets/groups to be generated initially with no entries - # and they still need to be marked as up to date. - # if tasks_generated == 0: - # return 0 - - task_logger.info( - f"RedisConnectorDeletion.generate_tasks finished. " - f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}" - ) - - # set this only after all tasks have been added - r.set(rcd.fence_key, tasks_generated) - return tasks_generated - - -##### -# Periodic Tasks -##### -@celery_app.task( - name="check_for_vespa_sync_task", - soft_time_limit=JOB_TIMEOUT, -) -def check_for_vespa_sync_task() -> None: - """Runs periodically to check if any document needs syncing. 
- Generates sets of tasks for Celery if syncing is needed.""" - - r = redis_pool.get_client() - - lock_beat = r.lock( - DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK, - timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, - ) - - try: - # these tasks should never overlap - if not lock_beat.acquire(blocking=False): - return - - with Session(get_sqlalchemy_engine()) as db_session: - try_generate_stale_document_sync_tasks(db_session, r, lock_beat) - - # check if any document sets are not synced - document_set_info = fetch_document_sets( - user_id=None, db_session=db_session, include_outdated=True - ) - for document_set, _ in document_set_info: - try_generate_document_set_sync_tasks( - document_set, db_session, r, lock_beat - ) - - # check if any user groups are not synced - try: - fetch_user_groups = fetch_versioned_implementation( - "danswer.db.user_group", "fetch_user_groups" - ) - - user_groups = fetch_user_groups( - db_session=db_session, only_up_to_date=False - ) - for usergroup in user_groups: - try_generate_user_group_sync_tasks( - usergroup, db_session, r, lock_beat - ) - except ModuleNotFoundError: - # Always exceptions on the MIT version, which is expected - pass - except SoftTimeLimitExceeded: - task_logger.info( - "Soft time limit exceeded, task is being terminated gracefully." - ) - except Exception: - task_logger.exception("Unexpected exception") - finally: - if lock_beat.owned(): - lock_beat.release() - - -@celery_app.task( - name="check_for_connector_deletion_task", - soft_time_limit=JOB_TIMEOUT, -) -def check_for_connector_deletion_task() -> None: - r = redis_pool.get_client() - - lock_beat = r.lock( - DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK, - timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, - ) - - try: - # these tasks should never overlap - if not lock_beat.acquire(blocking=False): - return - - with Session(get_sqlalchemy_engine()) as db_session: - cc_pairs = get_connector_credential_pairs(db_session) - for cc_pair in cc_pairs: - try_generate_document_cc_pair_cleanup_tasks( - cc_pair, db_session, r, lock_beat - ) - except SoftTimeLimitExceeded: - task_logger.info( - "Soft time limit exceeded, task is being terminated gracefully." - ) - except Exception: - task_logger.exception("Unexpected exception") - finally: - if lock_beat.owned(): - lock_beat.release() - - -@celery_app.task( - name="kombu_message_cleanup_task", - soft_time_limit=JOB_TIMEOUT, - bind=True, - base=AbortableTask, -) -def kombu_message_cleanup_task(self: Any) -> int: - """Runs periodically to clean up the kombu_message table""" - - # we will select messages older than this amount to clean up - KOMBU_MESSAGE_CLEANUP_AGE = 7 # days - KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000 - - ctx = {} - ctx["last_processed_id"] = 0 - ctx["deleted"] = 0 - ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE - ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT - with Session(get_sqlalchemy_engine()) as db_session: - # Exit the task if we can't take the advisory lock - result = db_session.execute( - text("SELECT pg_try_advisory_lock(:id)"), - {"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value}, - ).scalar() - if not result: - return 0 - - while True: - if self.is_aborted(): - raise TaskRevokedError("kombu_message_cleanup_task was aborted.") - - b = kombu_message_cleanup_task_helper(ctx, db_session) - if not b: - break - - db_session.commit() - - if ctx["deleted"] > 0: - task_logger.info( - f"Deleted {ctx['deleted']} orphaned messages from kombu_message." 
- ) - - return ctx["deleted"] - - -def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool: - """ - Helper function to clean up old messages from the `kombu_message` table that are no longer relevant. - - This function retrieves messages from the `kombu_message` table that are no longer visible and - older than a specified interval. It checks if the corresponding task_id exists in the - `celery_taskmeta` table. If the task_id does not exist, the message is deleted. - - Args: - ctx (dict): A context dictionary containing configuration parameters such as: - - 'cleanup_age' (int): The age in days after which messages are considered old. - - 'page_limit' (int): The maximum number of messages to process in one batch. - - 'last_processed_id' (int): The ID of the last processed message to handle pagination. - - 'deleted' (int): A counter to track the number of deleted messages. - db_session (Session): The SQLAlchemy database session for executing queries. - - Returns: - bool: Returns True if there are more rows to process, False if not. - """ - - inspector = inspect(db_session.bind) - if not inspector: - return False - - # With the move to redis as celery's broker and backend, kombu tables may not even exist. - # We can fail silently. - if not inspector.has_table("kombu_message"): - return False - - query = text( - """ - SELECT id, timestamp, payload - FROM kombu_message WHERE visible = 'false' - AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days - AND id > :last_processed_id - ORDER BY id - LIMIT :page_limit -""" - ) - kombu_messages = db_session.execute( - query, - { - "interval_days": f"{ctx['cleanup_age']} days", - "page_limit": ctx["page_limit"], - "last_processed_id": ctx["last_processed_id"], - }, - ).fetchall() - - if len(kombu_messages) == 0: - return False - - for msg in kombu_messages: - payload = json.loads(msg[2]) - task_id = payload["headers"]["id"] - - # Check if task_id exists in celery_taskmeta - task_exists = db_session.execute( - text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"), - {"task_id": task_id}, - ).fetchone() - - # If task_id does not exist, delete the message - if not task_exists: - result = db_session.execute( - text("DELETE FROM kombu_message WHERE id = :message_id"), - {"message_id": msg[0]}, - ) - if result.rowcount > 0: # type: ignore - ctx["deleted"] += 1 - - ctx["last_processed_id"] = msg[0] - - return True - - -@celery_app.task( - name="check_for_prune_task", - soft_time_limit=JOB_TIMEOUT, -) -def check_for_prune_task() -> None: - """Runs periodically to check if any prune tasks should be run and adds them - to the queue""" - - with Session(get_sqlalchemy_engine()) as db_session: - all_cc_pairs = get_connector_credential_pairs(db_session) - - for cc_pair in all_cc_pairs: - if not cc_pair.connector.prune_freq: - continue - - if should_prune_cc_pair( - connector=cc_pair.connector, - credential=cc_pair.credential, - db_session=db_session, - ): - task_logger.info(f"Pruning the {cc_pair.connector.name} connector") - - prune_documents_task.apply_async( - kwargs=dict( - connector_id=cc_pair.connector.id, - credential_id=cc_pair.credential.id, - ) - ) - - -@celery_app.task( - name="vespa_metadata_sync_task", - bind=True, - soft_time_limit=45, - time_limit=60, - max_retries=3, -) -def vespa_metadata_sync_task(self: Task, document_id: str) -> bool: - task_logger.info(f"document_id={document_id}") - - try: - with Session(get_sqlalchemy_engine()) as db_session: - curr_ind_name, sec_ind_name = get_both_index_names(db_session) - 
document_index = get_default_document_index( - primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name - ) - - doc = get_document(document_id, db_session) - if not doc: - return False - - # document set sync - doc_sets = fetch_document_sets_for_document(document_id, db_session) - update_doc_sets: set[str] = set(doc_sets) - - # User group sync - doc_access = get_access_for_document( - document_id=document_id, db_session=db_session - ) - update_request = UpdateRequest( - document_ids=[document_id], - document_sets=update_doc_sets, - access=doc_access, - boost=doc.boost, - hidden=doc.hidden, - ) - - # update Vespa. OK if doc doesn't exist. Raises exception otherwise. - document_index.update_single(update_request=update_request) - - # update db last. Worst case = we crash right before this and - # the sync might repeat again later - mark_document_as_synced(document_id, db_session) - except SoftTimeLimitExceeded: - task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}") - except Exception as e: - task_logger.exception("Unexpected exception") - - # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64 - countdown = 2 ** (self.request.retries + 4) - self.retry(exc=e, countdown=countdown) - - return True - - -@celery_app.task( - name="document_by_cc_pair_cleanup_task", - bind=True, - soft_time_limit=45, - time_limit=60, - max_retries=3, -) -def document_by_cc_pair_cleanup_task( - self: Task, document_id: str, connector_id: int, credential_id: int -) -> bool: - task_logger.info(f"document_id={document_id}") - - try: - with Session(get_sqlalchemy_engine()) as db_session: - curr_ind_name, sec_ind_name = get_both_index_names(db_session) - document_index = get_default_document_index( - primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name - ) - - count = get_document_connector_count(db_session, document_id) - if count == 1: - # count == 1 means this is the only remaining cc_pair reference to the doc - # delete it from vespa and the db - document_index.delete(doc_ids=[document_id]) - delete_documents_complete__no_commit( - db_session=db_session, - document_ids=[document_id], - ) - elif count > 1: - # count > 1 means the document still has cc_pair references - doc = get_document(document_id, db_session) - if not doc: - return False - - # the below functions do not include cc_pairs being deleted. - # i.e. they will correctly omit access for the current cc_pair - doc_access = get_access_for_document( - document_id=document_id, db_session=db_session - ) - - doc_sets = fetch_document_sets_for_document(document_id, db_session) - update_doc_sets: set[str] = set(doc_sets) - - update_request = UpdateRequest( - document_ids=[document_id], - document_sets=update_doc_sets, - access=doc_access, - boost=doc.boost, - hidden=doc.hidden, - ) - - # update Vespa. OK if doc doesn't exist. Raises exception otherwise. 
- document_index.update_single(update_request=update_request) - - # there are still other cc_pair references to the doc, so just resync to Vespa - delete_document_by_connector_credential_pair__no_commit( - db_session=db_session, - document_id=document_id, - connector_credential_pair_identifier=ConnectorCredentialPairIdentifier( - connector_id=connector_id, - credential_id=credential_id, - ), - ) - - mark_document_as_synced(document_id, db_session) - else: - pass - - # update_docs_last_modified__no_commit( - # db_session=db_session, - # document_ids=[document_id], - # ) - - db_session.commit() - except SoftTimeLimitExceeded: - task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}") - except Exception as e: - task_logger.exception("Unexpected exception") - - # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64 - countdown = 2 ** (self.request.retries + 4) - self.retry(exc=e, countdown=countdown) - - return True - - @signals.task_postrun.connect def celery_task_postrun( sender: Any | None = None, @@ -847,235 +109,113 @@ def celery_task_postrun( return -def monitor_connector_taskset(r: Redis) -> None: - fence_value = r.get(RedisConnectorCredentialPair.get_fence_key()) - if fence_value is None: - return - - try: - initial_count = int(cast(int, fence_value)) - except ValueError: - task_logger.error("The value is not an integer.") - return - - count = r.scard(RedisConnectorCredentialPair.get_taskset_key()) - task_logger.info(f"Stale documents: remaining={count} initial={initial_count}") - if count == 0: - r.delete(RedisConnectorCredentialPair.get_taskset_key()) - r.delete(RedisConnectorCredentialPair.get_fence_key()) - task_logger.info(f"Successfully synced stale documents. count={initial_count}") - - -def monitor_document_set_taskset(key_bytes: bytes, r: Redis) -> None: - fence_key = key_bytes.decode("utf-8") - document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key) - if document_set_id is None: - task_logger.warning("could not parse document set id from {key}") - return - - rds = RedisDocumentSet(document_set_id) - - fence_value = r.get(rds.fence_key) - if fence_value is None: - return - - try: - initial_count = int(cast(int, fence_value)) - except ValueError: - task_logger.error("The value is not an integer.") - return - - count = cast(int, r.scard(rds.taskset_key)) - task_logger.info( - f"Document set sync: document_set_id={document_set_id} remaining={count} initial={initial_count}" - ) - if count > 0: - return - - with Session(get_sqlalchemy_engine()) as db_session: - document_set = cast( - DocumentSet, - get_document_set_by_id( - db_session=db_session, document_set_id=document_set_id - ), - ) # casting since we "know" a document set with this ID exists - if document_set: - if not document_set.connector_credential_pairs: - # if there are no connectors, then delete the document set. - delete_document_set( - document_set_row=document_set, db_session=db_session - ) - task_logger.info( - f"Successfully deleted document set with ID: '{document_set_id}'!" - ) - else: - mark_document_set_as_synced(document_set_id, db_session) - task_logger.info( - f"Successfully synced document set with ID: '{document_set_id}'!" 
- ) - - r.delete(rds.taskset_key) - r.delete(rds.fence_key) - - -def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None: - fence_key = key_bytes.decode("utf-8") - cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key) - if cc_pair_id is None: - task_logger.warning("could not parse document set id from {key}") - return - - rcd = RedisConnectorDeletion(cc_pair_id) - - fence_value = r.get(rcd.fence_key) - if fence_value is None: - return - - try: - initial_count = int(cast(int, fence_value)) - except ValueError: - task_logger.error("The value is not an integer.") - return - - count = cast(int, r.scard(rcd.taskset_key)) - task_logger.info( - f"Connector deletion: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}" - ) - if count > 0: - return - - with Session(get_sqlalchemy_engine()) as db_session: - cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) - if not cc_pair: - return - - try: - # clean up the rest of the related Postgres entities - # index attempts - delete_index_attempts( - db_session=db_session, - cc_pair_id=cc_pair.id, - ) - - # document sets - delete_document_set_cc_pair_relationship__no_commit( - db_session=db_session, - connector_id=cc_pair.connector_id, - credential_id=cc_pair.credential_id, - ) - - # user groups - cleanup_user_groups = fetch_versioned_implementation_with_fallback( - "danswer.db.user_group", - "delete_user_group_cc_pair_relationship__no_commit", - noop_fallback, - ) - cleanup_user_groups( - cc_pair_id=cc_pair.id, - db_session=db_session, - ) - - # finally, delete the cc-pair - delete_connector_credential_pair__no_commit( - db_session=db_session, - connector_id=cc_pair.connector_id, - credential_id=cc_pair.credential_id, - ) - # if there are no credentials left, delete the connector - connector = fetch_connector_by_id( - db_session=db_session, - connector_id=cc_pair.connector_id, - ) - if not connector or not len(connector.credentials): - task_logger.info( - "Found no credentials left for connector, deleting connector" - ) - db_session.delete(connector) - db_session.commit() - except Exception as e: - stack_trace = traceback.format_exc() - error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}" - add_deletion_failure_message(db_session, cc_pair.id, error_message) - task_logger.exception( - f"Failed to run connector_deletion. " - f"connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}" - ) - raise e - - task_logger.info( - f"Successfully deleted connector_credential_pair with connector_id: '{cc_pair.connector_id}' " - f"and credential_id: '{cc_pair.credential_id}'. " - f"Deleted {initial_count} docs." - ) - - r.delete(rcd.taskset_key) - r.delete(rcd.fence_key) - - -@celery_app.task(name="monitor_vespa_sync", soft_time_limit=300) -def monitor_vespa_sync() -> None: - """This is a celery beat task that monitors and finalizes metadata sync tasksets. - It scans for fence values and then gets the counts of any associated tasksets. - If the count is 0, that means all tasks finished and we should clean up. - - This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't - do anything too expensive in this function! 
- """ - r = redis_pool.get_client() - - lock_beat = r.lock( - DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK, - timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, - ) - - try: - # prevent overlapping tasks - if not lock_beat.acquire(blocking=False): - return - - if r.exists(RedisConnectorCredentialPair.get_fence_key()): - monitor_connector_taskset(r) - - for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"): - monitor_document_set_taskset(key_bytes, r) - - for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"): - monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback( - "danswer.background.celery_utils", - "monitor_usergroup_taskset", - noop_fallback, - ) - - monitor_usergroup_taskset(key_bytes, r) - - for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"): - monitor_connector_deletion_taskset(key_bytes, r) - - # r_celery = celery_app.broker_connection().channel().client - # length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery) - # task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}") - except SoftTimeLimitExceeded: - task_logger.info( - "Soft time limit exceeded, task is being terminated gracefully." - ) - finally: - if lock_beat.owned(): - lock_beat.release() - - @beat_init.connect def on_beat_init(sender: Any, **kwargs: Any) -> None: - init_sqlalchemy_engine(POSTGRES_CELERY_BEAT_APP_NAME) + SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME) + SqlEngine.init_engine(pool_size=2, max_overflow=0) @worker_init.connect def on_worker_init(sender: Any, **kwargs: Any) -> None: - init_sqlalchemy_engine(POSTGRES_CELERY_WORKER_APP_NAME) + # decide some initial startup settings based on the celery worker's hostname + # (set at the command line) + hostname = sender.hostname + if hostname.startswith("light"): + SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME) + SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) + elif hostname.startswith("heavy"): + SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME) + SqlEngine.init_engine(pool_size=8, max_overflow=0) + else: + SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME) + SqlEngine.init_engine(pool_size=8, max_overflow=0) - # TODO(rkuo): this is singleton work that should be done on startup exactly once - # if we run multiple workers, we'll need to centralize where this cleanup happens r = redis_pool.get_client() + WAIT_INTERVAL = 5 + WAIT_LIMIT = 60 + + time_start = time.monotonic() + logger.info("Redis: Readiness check starting.") + while True: + try: + if r.ping(): + break + except Exception: + pass + + time_elapsed = time.monotonic() - time_start + logger.info( + f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}" + ) + if time_elapsed > WAIT_LIMIT: + msg = ( + f"Redis: Readiness check did not succeed within the timeout " + f"({WAIT_LIMIT} seconds). Exiting..." + ) + logger.error(msg) + raise WorkerShutdown(msg) + + time.sleep(WAIT_INTERVAL) + + logger.info("Redis: Readiness check succeeded. Continuing...") + + if not celery_is_worker_primary(sender): + logger.info("Running as a secondary celery worker.") + logger.info("Waiting for primary worker to be ready...") + time_start = time.monotonic() + while True: + if r.exists(DanswerRedisLocks.PRIMARY_WORKER): + break + + time.monotonic() + time_elapsed = time.monotonic() - time_start + logger.info( + f"Primary worker is not ready yet. 
elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}" + ) + if time_elapsed > WAIT_LIMIT: + msg = ( + f"Primary worker was not ready within the timeout. " + f"({WAIT_LIMIT} seconds). Exiting..." + ) + logger.error(msg) + raise WorkerShutdown(msg) + + time.sleep(WAIT_INTERVAL) + + logger.info("Wait for primary worker completed successfully. Continuing...") + return + + logger.info("Running as the primary celery worker.") + + # This is singleton work that should be done on startup exactly once + # by the primary worker + r = redis_pool.get_client() + + # For the moment, we're assuming that we are the only primary worker + # that should be running. + # TODO: maybe check for or clean up another zombie primary worker if we detect it + r.delete(DanswerRedisLocks.PRIMARY_WORKER) + + # this process wide lock is taken to help other workers start up in order. + # it is planned to use this lock to enforce singleton behavior on the primary + # worker, since the primary worker does redis cleanup on startup, but this isn't + # implemented yet. + lock = r.lock( + DanswerRedisLocks.PRIMARY_WORKER, + timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT, + ) + + logger.info("Primary worker lock: Acquire starting.") + acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2) + if acquired: + logger.info("Primary worker lock: Acquire succeeded.") + else: + logger.error("Primary worker lock: Acquire failed!") + raise WorkerShutdown("Primary worker lock could not be acquired!") + + sender.primary_worker_lock = lock + r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK) r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK) @@ -1101,6 +241,26 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None: r.delete(key) +@worker_ready.connect +def on_worker_ready(sender: Any, **kwargs: Any) -> None: + task_logger.info("worker_ready signal received.") + + +@worker_shutdown.connect +def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: + if not celery_is_worker_primary(sender): + return + + if not sender.primary_worker_lock: + return + + logger.info("Releasing primary worker lock.") + lock = sender.primary_worker_lock + if lock.owned(): + lock.release() + sender.primary_worker_lock = None + + class CeleryTaskPlainFormatter(PlainFormatter): def format(self, record: logging.LogRecord) -> str: task = current_task @@ -1172,6 +332,89 @@ def on_setup_logging( task_logger.propagate = False +class HubPeriodicTask(bootsteps.StartStopStep): + """Regularly reacquires the primary worker lock outside of the task queue. 
+ Use the task_logger in this class to avoid double logging.""" + + # it's unclear to me whether using the hub's timer or the bootstep timer is better + requires = {"celery.worker.components:Hub"} + + def __init__(self, worker: Any, **kwargs: Any) -> None: + self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds + self.task_tref = None + + def start(self, worker: Any) -> None: + if not celery_is_worker_primary(worker): + return + + # Access the worker's event loop (hub) + hub = worker.consumer.controller.hub + + # Schedule the periodic task + self.task_tref = hub.call_repeatedly( + self.interval, self.run_periodic_task, worker + ) + task_logger.info("Scheduled periodic task with hub.") + + def run_periodic_task(self, worker: Any) -> None: + try: + if not worker.primary_worker_lock: + return + + if not hasattr(worker, "primary_worker_lock"): + return + + r = redis_pool.get_client() + + lock: redis.lock.Lock = worker.primary_worker_lock + + task_logger.info("Reacquiring primary worker lock.") + + if lock.owned(): + task_logger.debug("Reacquiring primary worker lock.") + lock.reacquire() + else: + task_logger.warning( + "Full acquisition of primary worker lock. " + "Reasons could be computer sleep or a clock change." + ) + lock = r.lock( + DanswerRedisLocks.PRIMARY_WORKER, + timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT, + ) + + task_logger.info("Primary worker lock: Acquire starting.") + acquired = lock.acquire( + blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2 + ) + if acquired: + task_logger.info("Primary worker lock: Acquire succeeded.") + else: + task_logger.error("Primary worker lock: Acquire failed!") + raise TimeoutError("Primary worker lock could not be acquired!") + + worker.primary_worker_lock = lock + except Exception: + task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.") + + def stop(self, worker: Any) -> None: + # Cancel the scheduled task when the worker stops + if self.task_tref: + self.task_tref.cancel() + task_logger.info("Canceled periodic task with hub.") + + +celery_app.steps["worker"].add(HubPeriodicTask) + +celery_app.autodiscover_tasks( + [ + "danswer.background.celery.tasks.connector_deletion", + "danswer.background.celery.tasks.periodic", + "danswer.background.celery.tasks.pruning", + "danswer.background.celery.tasks.vespa", + ] +) + ##### # Celery Beat (Periodic Tasks) Settings ##### diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py index 8cda63b8f..9ee282e1a 100644 --- a/backend/danswer/background/celery/celery_utils.py +++ b/backend/danswer/background/celery/celery_utils.py @@ -1,5 +1,6 @@ from datetime import datetime from datetime import timezone +from typing import Any from sqlalchemy.orm import Session @@ -160,3 +161,30 @@ def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> se all_connector_doc_ids.update(doc_batch_processing_func(doc_batch)) return all_connector_doc_ids + + +def celery_is_listening_to_queue(worker: Any, name: str) -> bool: + """Checks to see if we're listening to the named queue""" + + # how to get a list of queues this worker is listening to + # https://stackoverflow.com/questions/29790523/how-to-determine-which-queues-a-celery-worker-is-consuming-at-runtime + queue_names = list(worker.app.amqp.queues.consume_from.keys()) + for queue_name in queue_names: + if queue_name == name: + return True + + return False + + +def celery_is_worker_primary(worker: Any) -> bool: + """There are multiple approaches that 
could be taken, but the way we do it is to + check the hostname set for the celery worker, either in celeryconfig.py or on the + command line.""" + hostname = worker.hostname + if hostname.startswith("light"): + return False + + if hostname.startswith("heavy"): + return False + + return True diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py new file mode 100644 index 000000000..655487f71 --- /dev/null +++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py @@ -0,0 +1,133 @@ +import redis +from celery import shared_task +from celery.exceptions import SoftTimeLimitExceeded +from celery.utils.log import get_task_logger +from redis import Redis +from sqlalchemy.orm import Session +from sqlalchemy.orm.exc import ObjectDeletedError + +from danswer.background.celery.celery_app import celery_app +from danswer.background.celery.celery_redis import RedisConnectorDeletion +from danswer.configs.app_configs import JOB_TIMEOUT +from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT +from danswer.configs.constants import DanswerRedisLocks +from danswer.db.connector_credential_pair import get_connector_credential_pairs +from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.enums import ConnectorCredentialPairStatus +from danswer.db.enums import IndexingStatus +from danswer.db.index_attempt import get_last_attempt +from danswer.db.models import ConnectorCredentialPair +from danswer.db.search_settings import get_current_search_settings +from danswer.redis.redis_pool import RedisPool + +redis_pool = RedisPool() + +# use this within celery tasks to get celery task specific logging +task_logger = get_task_logger(__name__) + + +@shared_task( + name="check_for_connector_deletion_task", + soft_time_limit=JOB_TIMEOUT, + trail=False, +) +def check_for_connector_deletion_task() -> None: + r = redis_pool.get_client() + + lock_beat = r.lock( + DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK, + timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, + ) + + try: + # these tasks should never overlap + if not lock_beat.acquire(blocking=False): + return + + with Session(get_sqlalchemy_engine()) as db_session: + cc_pairs = get_connector_credential_pairs(db_session) + for cc_pair in cc_pairs: + try_generate_document_cc_pair_cleanup_tasks( + cc_pair, db_session, r, lock_beat + ) + except SoftTimeLimitExceeded: + task_logger.info( + "Soft time limit exceeded, task is being terminated gracefully." + ) + except Exception: + task_logger.exception("Unexpected exception") + finally: + if lock_beat.owned(): + lock_beat.release() + + +def try_generate_document_cc_pair_cleanup_tasks( + cc_pair: ConnectorCredentialPair, + db_session: Session, + r: Redis, + lock_beat: redis.lock.Lock, +) -> int | None: + """Returns an int if syncing is needed. The int represents the number of sync tasks generated. + Note that syncing can still be required even if the number of sync tasks generated is zero. + Returns None if no syncing is required. 
+ """ + + lock_beat.reacquire() + + rcd = RedisConnectorDeletion(cc_pair.id) + + # don't generate sync tasks if tasks are still pending + if r.exists(rcd.fence_key): + return None + + # we need to refresh the state of the object inside the fence + # to avoid a race condition with db.commit/fence deletion + # at the end of this taskset + try: + db_session.refresh(cc_pair) + except ObjectDeletedError: + return None + + if cc_pair.status != ConnectorCredentialPairStatus.DELETING: + return None + + search_settings = get_current_search_settings(db_session) + + last_indexing = get_last_attempt( + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + search_settings_id=search_settings.id, + db_session=db_session, + ) + if last_indexing: + if ( + last_indexing.status == IndexingStatus.IN_PROGRESS + or last_indexing.status == IndexingStatus.NOT_STARTED + ): + return None + + # add tasks to celery and build up the task set to monitor in redis + r.delete(rcd.taskset_key) + + # Add all documents that need to be updated into the queue + task_logger.info( + f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}" + ) + tasks_generated = rcd.generate_tasks(celery_app, db_session, r, lock_beat) + if tasks_generated is None: + return None + + # Currently we are allowing the sync to proceed with 0 tasks. + # It's possible for sets/groups to be generated initially with no entries + # and they still need to be marked as up to date. + # if tasks_generated == 0: + # return 0 + + task_logger.info( + f"RedisConnectorDeletion.generate_tasks finished. " + f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}" + ) + + # set this only after all tasks have been added + r.set(rcd.fence_key, tasks_generated) + return tasks_generated diff --git a/backend/danswer/background/celery/tasks/periodic/tasks.py b/backend/danswer/background/celery/tasks/periodic/tasks.py new file mode 100644 index 000000000..bd3b082ae --- /dev/null +++ b/backend/danswer/background/celery/tasks/periodic/tasks.py @@ -0,0 +1,140 @@ +##### +# Periodic Tasks +##### +import json +from typing import Any + +from celery import shared_task +from celery.contrib.abortable import AbortableTask # type: ignore +from celery.exceptions import TaskRevokedError +from celery.utils.log import get_task_logger +from sqlalchemy import inspect +from sqlalchemy import text +from sqlalchemy.orm import Session + +from danswer.configs.app_configs import JOB_TIMEOUT +from danswer.configs.constants import PostgresAdvisoryLocks +from danswer.db.engine import get_sqlalchemy_engine # type: ignore + +# use this within celery tasks to get celery task specific logging +task_logger = get_task_logger(__name__) + + +@shared_task( + name="kombu_message_cleanup_task", + soft_time_limit=JOB_TIMEOUT, + bind=True, + base=AbortableTask, +) +def kombu_message_cleanup_task(self: Any) -> int: + """Runs periodically to clean up the kombu_message table""" + + # we will select messages older than this amount to clean up + KOMBU_MESSAGE_CLEANUP_AGE = 7 # days + KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000 + + ctx = {} + ctx["last_processed_id"] = 0 + ctx["deleted"] = 0 + ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE + ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT + with Session(get_sqlalchemy_engine()) as db_session: + # Exit the task if we can't take the advisory lock + result = db_session.execute( + text("SELECT pg_try_advisory_lock(:id)"), + {"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value}, + ).scalar() + if not result: + return 0 
+ + while True: + if self.is_aborted(): + raise TaskRevokedError("kombu_message_cleanup_task was aborted.") + + b = kombu_message_cleanup_task_helper(ctx, db_session) + if not b: + break + + db_session.commit() + + if ctx["deleted"] > 0: + task_logger.info( + f"Deleted {ctx['deleted']} orphaned messages from kombu_message." + ) + + return ctx["deleted"] + + +def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool: + """ + Helper function to clean up old messages from the `kombu_message` table that are no longer relevant. + + This function retrieves messages from the `kombu_message` table that are no longer visible and + older than a specified interval. It checks if the corresponding task_id exists in the + `celery_taskmeta` table. If the task_id does not exist, the message is deleted. + + Args: + ctx (dict): A context dictionary containing configuration parameters such as: + - 'cleanup_age' (int): The age in days after which messages are considered old. + - 'page_limit' (int): The maximum number of messages to process in one batch. + - 'last_processed_id' (int): The ID of the last processed message to handle pagination. + - 'deleted' (int): A counter to track the number of deleted messages. + db_session (Session): The SQLAlchemy database session for executing queries. + + Returns: + bool: Returns True if there are more rows to process, False if not. + """ + + inspector = inspect(db_session.bind) + if not inspector: + return False + + # With the move to redis as celery's broker and backend, kombu tables may not even exist. + # We can fail silently. + if not inspector.has_table("kombu_message"): + return False + + query = text( + """ + SELECT id, timestamp, payload + FROM kombu_message WHERE visible = 'false' + AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days + AND id > :last_processed_id + ORDER BY id + LIMIT :page_limit +""" + ) + kombu_messages = db_session.execute( + query, + { + "interval_days": f"{ctx['cleanup_age']} days", + "page_limit": ctx["page_limit"], + "last_processed_id": ctx["last_processed_id"], + }, + ).fetchall() + + if len(kombu_messages) == 0: + return False + + for msg in kombu_messages: + payload = json.loads(msg[2]) + task_id = payload["headers"]["id"] + + # Check if task_id exists in celery_taskmeta + task_exists = db_session.execute( + text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"), + {"task_id": task_id}, + ).fetchone() + + # If task_id does not exist, delete the message + if not task_exists: + result = db_session.execute( + text("DELETE FROM kombu_message WHERE id = :message_id"), + {"message_id": msg[0]}, + ) + if result.rowcount > 0: # type: ignore + ctx["deleted"] += 1 + + ctx["last_processed_id"] = msg[0] + + return True diff --git a/backend/danswer/background/celery/tasks/pruning/tasks.py b/backend/danswer/background/celery/tasks/pruning/tasks.py new file mode 100644 index 000000000..2f840e430 --- /dev/null +++ b/backend/danswer/background/celery/tasks/pruning/tasks.py @@ -0,0 +1,120 @@ +from celery import shared_task +from celery.utils.log import get_task_logger +from sqlalchemy.orm import Session + +from danswer.background.celery.celery_app import celery_app +from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector +from danswer.background.celery.celery_utils import should_prune_cc_pair +from danswer.background.connector_deletion import delete_connector_credential_pair_batch +from danswer.background.task_utils import build_celery_task_wrapper +from danswer.background.task_utils 
import name_cc_prune_task +from danswer.configs.app_configs import JOB_TIMEOUT +from danswer.connectors.factory import instantiate_connector +from danswer.connectors.models import InputType +from danswer.db.connector_credential_pair import get_connector_credential_pair +from danswer.db.connector_credential_pair import get_connector_credential_pairs +from danswer.db.document import get_documents_for_connector_credential_pair +from danswer.db.engine import get_sqlalchemy_engine +from danswer.document_index.document_index_utils import get_both_index_names +from danswer.document_index.factory import get_default_document_index + + +# use this within celery tasks to get celery task specific logging +task_logger = get_task_logger(__name__) + + +@shared_task( + name="check_for_prune_task", + soft_time_limit=JOB_TIMEOUT, +) +def check_for_prune_task() -> None: + """Runs periodically to check if any prune tasks should be run and adds them + to the queue""" + + with Session(get_sqlalchemy_engine()) as db_session: + all_cc_pairs = get_connector_credential_pairs(db_session) + + for cc_pair in all_cc_pairs: + if should_prune_cc_pair( + connector=cc_pair.connector, + credential=cc_pair.credential, + db_session=db_session, + ): + task_logger.info(f"Pruning the {cc_pair.connector.name} connector") + + prune_documents_task.apply_async( + kwargs=dict( + connector_id=cc_pair.connector.id, + credential_id=cc_pair.credential.id, + ) + ) + + +@build_celery_task_wrapper(name_cc_prune_task) +@celery_app.task(name="prune_documents_task", soft_time_limit=JOB_TIMEOUT) +def prune_documents_task(connector_id: int, credential_id: int) -> None: + """connector pruning task. For a cc pair, this task pulls all document IDs from the source + and compares those IDs to locally stored documents and deletes all locally stored IDs missing + from the most recently pulled document ID list""" + with Session(get_sqlalchemy_engine()) as db_session: + try: + cc_pair = get_connector_credential_pair( + db_session=db_session, + connector_id=connector_id, + credential_id=credential_id, + ) + + if not cc_pair: + task_logger.warning( + f"ccpair not found for {connector_id} {credential_id}" + ) + return + + runnable_connector = instantiate_connector( + db_session, + cc_pair.connector.source, + InputType.PRUNE, + cc_pair.connector.connector_specific_config, + cc_pair.credential, + ) + + all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector( + runnable_connector + ) + + all_indexed_document_ids = { + doc.id + for doc in get_documents_for_connector_credential_pair( + db_session=db_session, + connector_id=connector_id, + credential_id=credential_id, + ) + } + + doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids) + + curr_ind_name, sec_ind_name = get_both_index_names(db_session) + document_index = get_default_document_index( + primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name + ) + + if len(doc_ids_to_remove) == 0: + task_logger.info( + f"No docs to prune from {cc_pair.connector.source} connector" + ) + return + + task_logger.info( + f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector" + ) + delete_connector_credential_pair_batch( + document_ids=doc_ids_to_remove, + connector_id=connector_id, + credential_id=credential_id, + document_index=document_index, + ) + except Exception as e: + task_logger.exception( + f"Failed to run pruning for connector id {connector_id}." 
+ ) + raise e diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py new file mode 100644 index 000000000..d11d317d0 --- /dev/null +++ b/backend/danswer/background/celery/tasks/vespa/tasks.py @@ -0,0 +1,526 @@ +import traceback +from typing import cast + +import redis +from celery import shared_task +from celery import Task +from celery.exceptions import SoftTimeLimitExceeded +from celery.utils.log import get_task_logger +from redis import Redis +from sqlalchemy.orm import Session + +from danswer.access.access import get_access_for_document +from danswer.background.celery.celery_app import celery_app +from danswer.background.celery.celery_redis import RedisConnectorCredentialPair +from danswer.background.celery.celery_redis import RedisConnectorDeletion +from danswer.background.celery.celery_redis import RedisDocumentSet +from danswer.background.celery.celery_redis import RedisUserGroup +from danswer.configs.app_configs import JOB_TIMEOUT +from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT +from danswer.configs.constants import DanswerRedisLocks +from danswer.db.connector import fetch_connector_by_id +from danswer.db.connector_credential_pair import add_deletion_failure_message +from danswer.db.connector_credential_pair import ( + delete_connector_credential_pair__no_commit, +) +from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id +from danswer.db.connector_credential_pair import get_connector_credential_pairs +from danswer.db.document import count_documents_by_needs_sync +from danswer.db.document import get_document +from danswer.db.document import mark_document_as_synced +from danswer.db.document_set import delete_document_set +from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit +from danswer.db.document_set import fetch_document_sets +from danswer.db.document_set import fetch_document_sets_for_document +from danswer.db.document_set import get_document_set_by_id +from danswer.db.document_set import mark_document_set_as_synced +from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.index_attempt import delete_index_attempts +from danswer.db.models import DocumentSet +from danswer.db.models import UserGroup +from danswer.document_index.document_index_utils import get_both_index_names +from danswer.document_index.factory import get_default_document_index +from danswer.document_index.interfaces import UpdateRequest +from danswer.redis.redis_pool import RedisPool +from danswer.utils.variable_functionality import fetch_versioned_implementation +from danswer.utils.variable_functionality import ( + fetch_versioned_implementation_with_fallback, +) +from danswer.utils.variable_functionality import noop_fallback + +redis_pool = RedisPool() + +# use this within celery tasks to get celery task specific logging +task_logger = get_task_logger(__name__) + + +# celery auto associates tasks created inside another task, +# which bloats the result metadata considerably. trail=False prevents this. +@shared_task( + name="check_for_vespa_sync_task", + soft_time_limit=JOB_TIMEOUT, + trail=False, +) +def check_for_vespa_sync_task() -> None: + """Runs periodically to check if any document needs syncing. 
+ Generates sets of tasks for Celery if syncing is needed.""" + + r = redis_pool.get_client() + + lock_beat = r.lock( + DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK, + timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, + ) + + try: + # these tasks should never overlap + if not lock_beat.acquire(blocking=False): + return + + with Session(get_sqlalchemy_engine()) as db_session: + try_generate_stale_document_sync_tasks(db_session, r, lock_beat) + + # check if any document sets are not synced + document_set_info = fetch_document_sets( + user_id=None, db_session=db_session, include_outdated=True + ) + for document_set, _ in document_set_info: + try_generate_document_set_sync_tasks( + document_set, db_session, r, lock_beat + ) + + # check if any user groups are not synced + try: + fetch_user_groups = fetch_versioned_implementation( + "danswer.db.user_group", "fetch_user_groups" + ) + + user_groups = fetch_user_groups( + db_session=db_session, only_up_to_date=False + ) + for usergroup in user_groups: + try_generate_user_group_sync_tasks( + usergroup, db_session, r, lock_beat + ) + except ModuleNotFoundError: + # Always exceptions on the MIT version, which is expected + pass + except SoftTimeLimitExceeded: + task_logger.info( + "Soft time limit exceeded, task is being terminated gracefully." + ) + except Exception: + task_logger.exception("Unexpected exception") + finally: + if lock_beat.owned(): + lock_beat.release() + + +def try_generate_stale_document_sync_tasks( + db_session: Session, r: Redis, lock_beat: redis.lock.Lock +) -> int | None: + # the fence is up, do nothing + if r.exists(RedisConnectorCredentialPair.get_fence_key()): + return None + + r.delete(RedisConnectorCredentialPair.get_taskset_key()) # delete the taskset + + # add tasks to celery and build up the task set to monitor in redis + stale_doc_count = count_documents_by_needs_sync(db_session) + if stale_doc_count == 0: + return None + + task_logger.info( + f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair." + ) + + task_logger.info("RedisConnector.generate_tasks starting by cc_pair.") + + # rkuo: we could technically sync all stale docs in one big pass. + # but I feel it's more understandable to group the docs by cc_pair + total_tasks_generated = 0 + cc_pairs = get_connector_credential_pairs(db_session) + for cc_pair in cc_pairs: + rc = RedisConnectorCredentialPair(cc_pair.id) + tasks_generated = rc.generate_tasks(celery_app, db_session, r, lock_beat) + + if tasks_generated is None: + continue + + if tasks_generated == 0: + continue + + task_logger.info( + f"RedisConnector.generate_tasks finished for single cc_pair. " + f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}" + ) + + total_tasks_generated += tasks_generated + + task_logger.info( + f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}" + ) + + r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated) + return total_tasks_generated + + +def try_generate_document_set_sync_tasks( + document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock +) -> int | None: + lock_beat.reacquire() + + rds = RedisDocumentSet(document_set.id) + + # don't generate document set sync tasks if tasks are still pending + if r.exists(rds.fence_key): + return None + + # don't generate sync tasks if we're up to date + # race condition with the monitor/cleanup function if we use a cached result! 
+ db_session.refresh(document_set) + if document_set.is_up_to_date: + return None + + # add tasks to celery and build up the task set to monitor in redis + r.delete(rds.taskset_key) + + task_logger.info( + f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}" + ) + + # Add all documents that need to be updated into the queue + tasks_generated = rds.generate_tasks(celery_app, db_session, r, lock_beat) + if tasks_generated is None: + return None + + # Currently we are allowing the sync to proceed with 0 tasks. + # It's possible for sets/groups to be generated initially with no entries + # and they still need to be marked as up to date. + # if tasks_generated == 0: + # return 0 + + task_logger.info( + f"RedisDocumentSet.generate_tasks finished. " + f"document_set_id={document_set.id} tasks_generated={tasks_generated}" + ) + + # set this only after all tasks have been added + r.set(rds.fence_key, tasks_generated) + return tasks_generated + + +def try_generate_user_group_sync_tasks( + usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock +) -> int | None: + lock_beat.reacquire() + + rug = RedisUserGroup(usergroup.id) + + # don't generate sync tasks if tasks are still pending + if r.exists(rug.fence_key): + return None + + # race condition with the monitor/cleanup function if we use a cached result! + db_session.refresh(usergroup) + if usergroup.is_up_to_date: + return None + + # add tasks to celery and build up the task set to monitor in redis + r.delete(rug.taskset_key) + + # Add all documents that need to be updated into the queue + task_logger.info( + f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}" + ) + tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat) + if tasks_generated is None: + return None + + # Currently we are allowing the sync to proceed with 0 tasks. + # It's possible for sets/groups to be generated initially with no entries + # and they still need to be marked as up to date. + # if tasks_generated == 0: + # return 0 + + task_logger.info( + f"RedisUserGroup.generate_tasks finished. " + f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}" + ) + + # set this only after all tasks have been added + r.set(rug.fence_key, tasks_generated) + return tasks_generated + + +def monitor_connector_taskset(r: Redis) -> None: + fence_value = r.get(RedisConnectorCredentialPair.get_fence_key()) + if fence_value is None: + return + + try: + initial_count = int(cast(int, fence_value)) + except ValueError: + task_logger.error("The value is not an integer.") + return + + count = r.scard(RedisConnectorCredentialPair.get_taskset_key()) + task_logger.info( + f"Stale document sync progress: remaining={count} initial={initial_count}" + ) + if count == 0: + r.delete(RedisConnectorCredentialPair.get_taskset_key()) + r.delete(RedisConnectorCredentialPair.get_fence_key()) + task_logger.info(f"Successfully synced stale documents. 
count={initial_count}") + + +def monitor_document_set_taskset( + key_bytes: bytes, r: Redis, db_session: Session +) -> None: + fence_key = key_bytes.decode("utf-8") + document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key) + if document_set_id is None: + task_logger.warning("could not parse document set id from {key}") + return + + rds = RedisDocumentSet(document_set_id) + + fence_value = r.get(rds.fence_key) + if fence_value is None: + return + + try: + initial_count = int(cast(int, fence_value)) + except ValueError: + task_logger.error("The value is not an integer.") + return + + count = cast(int, r.scard(rds.taskset_key)) + task_logger.info( + f"Document set sync progress: document_set_id={document_set_id} remaining={count} initial={initial_count}" + ) + if count > 0: + return + + document_set = cast( + DocumentSet, + get_document_set_by_id(db_session=db_session, document_set_id=document_set_id), + ) # casting since we "know" a document set with this ID exists + if document_set: + if not document_set.connector_credential_pairs: + # if there are no connectors, then delete the document set. + delete_document_set(document_set_row=document_set, db_session=db_session) + task_logger.info( + f"Successfully deleted document set with ID: '{document_set_id}'!" + ) + else: + mark_document_set_as_synced(document_set_id, db_session) + task_logger.info( + f"Successfully synced document set with ID: '{document_set_id}'!" + ) + + r.delete(rds.taskset_key) + r.delete(rds.fence_key) + + +def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None: + fence_key = key_bytes.decode("utf-8") + cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key) + if cc_pair_id is None: + task_logger.warning("could not parse document set id from {key}") + return + + rcd = RedisConnectorDeletion(cc_pair_id) + + fence_value = r.get(rcd.fence_key) + if fence_value is None: + return + + try: + initial_count = int(cast(int, fence_value)) + except ValueError: + task_logger.error("The value is not an integer.") + return + + count = cast(int, r.scard(rcd.taskset_key)) + task_logger.info( + f"Connector deletion progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}" + ) + if count > 0: + return + + with Session(get_sqlalchemy_engine()) as db_session: + cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session) + if not cc_pair: + return + + try: + # clean up the rest of the related Postgres entities + # index attempts + delete_index_attempts( + db_session=db_session, + cc_pair_id=cc_pair.id, + ) + + # document sets + delete_document_set_cc_pair_relationship__no_commit( + db_session=db_session, + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + ) + + # user groups + cleanup_user_groups = fetch_versioned_implementation_with_fallback( + "danswer.db.user_group", + "delete_user_group_cc_pair_relationship__no_commit", + noop_fallback, + ) + cleanup_user_groups( + cc_pair_id=cc_pair.id, + db_session=db_session, + ) + + # finally, delete the cc-pair + delete_connector_credential_pair__no_commit( + db_session=db_session, + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + ) + # if there are no credentials left, delete the connector + connector = fetch_connector_by_id( + db_session=db_session, + connector_id=cc_pair.connector_id, + ) + if not connector or not len(connector.credentials): + task_logger.info( + "Found no credentials left for connector, deleting connector" + ) + db_session.delete(connector) + 
db_session.commit() + except Exception as e: + stack_trace = traceback.format_exc() + error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}" + add_deletion_failure_message(db_session, cc_pair.id, error_message) + task_logger.exception( + f"Failed to run connector_deletion. " + f"connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}" + ) + raise e + + task_logger.info( + f"Successfully deleted connector_credential_pair with connector_id: '{cc_pair.connector_id}' " + f"and credential_id: '{cc_pair.credential_id}'. " + f"Deleted {initial_count} docs." + ) + + r.delete(rcd.taskset_key) + r.delete(rcd.fence_key) + + +@shared_task(name="monitor_vespa_sync", soft_time_limit=300) +def monitor_vespa_sync() -> None: + """This is a celery beat task that monitors and finalizes metadata sync tasksets. + It scans for fence values and then gets the counts of any associated tasksets. + If the count is 0, that means all tasks finished and we should clean up. + + This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't + do anything too expensive in this function! + """ + r = redis_pool.get_client() + + lock_beat = r.lock( + DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK, + timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, + ) + + try: + # prevent overlapping tasks + if not lock_beat.acquire(blocking=False): + return + + if r.exists(RedisConnectorCredentialPair.get_fence_key()): + monitor_connector_taskset(r) + + for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"): + monitor_connector_deletion_taskset(key_bytes, r) + + with Session(get_sqlalchemy_engine()) as db_session: + for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"): + monitor_document_set_taskset(key_bytes, r, db_session) + + for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"): + monitor_usergroup_taskset = ( + fetch_versioned_implementation_with_fallback( + "danswer.background.celery.tasks.vespa.tasks", + "monitor_usergroup_taskset", + noop_fallback, + ) + ) + monitor_usergroup_taskset(key_bytes, r, db_session) + + # uncomment for debugging if needed + # r_celery = celery_app.broker_connection().channel().client + # length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery) + # task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}") + except SoftTimeLimitExceeded: + task_logger.info( + "Soft time limit exceeded, task is being terminated gracefully." 
+ ) + finally: + if lock_beat.owned(): + lock_beat.release() + + +@shared_task( + name="vespa_metadata_sync_task", + bind=True, + soft_time_limit=45, + time_limit=60, + max_retries=3, +) +def vespa_metadata_sync_task(self: Task, document_id: str) -> bool: + task_logger.info(f"document_id={document_id}") + + try: + with Session(get_sqlalchemy_engine()) as db_session: + curr_ind_name, sec_ind_name = get_both_index_names(db_session) + document_index = get_default_document_index( + primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name + ) + + doc = get_document(document_id, db_session) + if not doc: + return False + + # document set sync + doc_sets = fetch_document_sets_for_document(document_id, db_session) + update_doc_sets: set[str] = set(doc_sets) + + # User group sync + doc_access = get_access_for_document( + document_id=document_id, db_session=db_session + ) + update_request = UpdateRequest( + document_ids=[document_id], + document_sets=update_doc_sets, + access=doc_access, + boost=doc.boost, + hidden=doc.hidden, + ) + + # update Vespa + document_index.update(update_requests=[update_request]) + + # update db last. Worst case = we crash right before this and + # the sync might repeat again later + mark_document_as_synced(document_id, db_session) + except SoftTimeLimitExceeded: + task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}") + except Exception as e: + task_logger.exception("Unexpected exception") + + # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64 + countdown = 2 ** (self.request.retries + 4) + self.retry(exc=e, countdown=countdown) + + return True diff --git a/backend/danswer/background/connector_deletion.py b/backend/danswer/background/connector_deletion.py index 47a3477e6..983a3c129 100644 --- a/backend/danswer/background/connector_deletion.py +++ b/backend/danswer/background/connector_deletion.py @@ -10,15 +10,27 @@ are multiple connector / credential pairs that have indexed it connector / credential pair from the access list (6) delete all relevant entries from postgres """ +from celery import shared_task +from celery import Task +from celery.exceptions import SoftTimeLimitExceeded +from celery.utils.log import get_task_logger from sqlalchemy.orm import Session +from danswer.access.access import get_access_for_document from danswer.access.access import get_access_for_documents +from danswer.db.document import delete_document_by_connector_credential_pair__no_commit from danswer.db.document import delete_documents_by_connector_credential_pair__no_commit from danswer.db.document import delete_documents_complete__no_commit +from danswer.db.document import get_document +from danswer.db.document import get_document_connector_count from danswer.db.document import get_document_connector_counts +from danswer.db.document import mark_document_as_synced from danswer.db.document import prepare_to_modify_documents +from danswer.db.document_set import fetch_document_sets_for_document from danswer.db.document_set import fetch_document_sets_for_documents from danswer.db.engine import get_sqlalchemy_engine +from danswer.document_index.document_index_utils import get_both_index_names +from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import DocumentIndex from danswer.document_index.interfaces import UpdateRequest from danswer.server.documents.models import ConnectorCredentialPairIdentifier @@ -26,6 +38,9 @@ from danswer.utils.logger import setup_logger logger = setup_logger() +# use this within 
celery tasks to get celery task specific logging +task_logger = get_task_logger(__name__) + _DELETION_BATCH_SIZE = 1000 @@ -108,3 +123,89 @@ def delete_connector_credential_pair_batch( ), ) db_session.commit() + + +@shared_task( + name="document_by_cc_pair_cleanup_task", + bind=True, + soft_time_limit=45, + time_limit=60, + max_retries=3, +) +def document_by_cc_pair_cleanup_task( + self: Task, document_id: str, connector_id: int, credential_id: int +) -> bool: + task_logger.info(f"document_id={document_id}") + + try: + with Session(get_sqlalchemy_engine()) as db_session: + curr_ind_name, sec_ind_name = get_both_index_names(db_session) + document_index = get_default_document_index( + primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name + ) + + count = get_document_connector_count(db_session, document_id) + if count == 1: + # count == 1 means this is the only remaining cc_pair reference to the doc + # delete it from vespa and the db + document_index.delete(doc_ids=[document_id]) + delete_documents_complete__no_commit( + db_session=db_session, + document_ids=[document_id], + ) + elif count > 1: + # count > 1 means the document still has cc_pair references + doc = get_document(document_id, db_session) + if not doc: + return False + + # the below functions do not include cc_pairs being deleted. + # i.e. they will correctly omit access for the current cc_pair + doc_access = get_access_for_document( + document_id=document_id, db_session=db_session + ) + + doc_sets = fetch_document_sets_for_document(document_id, db_session) + update_doc_sets: set[str] = set(doc_sets) + + update_request = UpdateRequest( + document_ids=[document_id], + document_sets=update_doc_sets, + access=doc_access, + boost=doc.boost, + hidden=doc.hidden, + ) + + # update Vespa. OK if doc doesn't exist. Raises exception otherwise. + document_index.update_single(update_request=update_request) + + # there are still other cc_pair references to the doc, so just resync to Vespa + delete_document_by_connector_credential_pair__no_commit( + db_session=db_session, + document_id=document_id, + connector_credential_pair_identifier=ConnectorCredentialPairIdentifier( + connector_id=connector_id, + credential_id=credential_id, + ), + ) + + mark_document_as_synced(document_id, db_session) + else: + pass + + # update_docs_last_modified__no_commit( + # db_session=db_session, + # document_ids=[document_id], + # ) + + db_session.commit() + except SoftTimeLimitExceeded: + task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}") + except Exception as e: + task_logger.exception("Unexpected exception") + + # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64 + countdown = 2 ** (self.request.retries + 4) + self.retry(exc=e, countdown=countdown) + + return True diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 55eff6d0a..94e703635 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -416,6 +416,7 @@ def update_loop( warm_up_bi_encoder( embedding_model=embedding_model, ) + logger.notice("First inference complete.") client_primary: Client | SimpleJobClient client_secondary: Client | SimpleJobClient @@ -444,6 +445,7 @@ def update_loop( existing_jobs: dict[int, Future | SimpleJob] = {} + logger.notice("Startup complete. 
Waiting for indexing jobs...") while True: start = time.time() start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S") diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 670ad3771..e34b8b894 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -34,7 +34,9 @@ POSTGRES_WEB_APP_NAME = "web" POSTGRES_INDEXER_APP_NAME = "indexer" POSTGRES_CELERY_APP_NAME = "celery" POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat" -POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker" +POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary" +POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light" +POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy" POSTGRES_PERMISSIONS_APP_NAME = "permissions" POSTGRES_UNKNOWN_APP_NAME = "unknown" @@ -62,6 +64,7 @@ KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings" KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__" CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60 +CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120 class DocumentSource(str, Enum): @@ -186,6 +189,7 @@ class DanswerCeleryQueues: class DanswerRedisLocks: + PRIMARY_WORKER = "da_lock:primary_worker" CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat" MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat" CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat" diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index 94b5d0123..af44498be 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -1,8 +1,10 @@ import contextlib +import threading import time from collections.abc import AsyncGenerator from collections.abc import Generator from datetime import datetime +from typing import Any from typing import ContextManager from sqlalchemy import event @@ -32,14 +34,9 @@ logger = setup_logger() SYNC_DB_API = "psycopg2" ASYNC_DB_API = "asyncpg" -POSTGRES_APP_NAME = ( - POSTGRES_UNKNOWN_APP_NAME # helps to diagnose open connections in postgres -) - # global so we don't create more than one engine per process # outside of being best practice, this is needed so we can properly pool # connections and not create a new pool on every request -_SYNC_ENGINE: Engine | None = None _ASYNC_ENGINE: AsyncEngine | None = None SessionFactory: sessionmaker[Session] | None = None @@ -108,6 +105,67 @@ def get_db_current_time(db_session: Session) -> datetime: return result +class SqlEngine: + """Class to manage a global sql alchemy engine (needed for proper resource control) + Will eventually subsume most of the standalone functions in this file. + Sync only for now""" + + _engine: Engine | None = None + _lock: threading.Lock = threading.Lock() + _app_name: str = POSTGRES_UNKNOWN_APP_NAME + + # Default parameters for engine creation + DEFAULT_ENGINE_KWARGS = { + "pool_size": 40, + "max_overflow": 10, + "pool_pre_ping": POSTGRES_POOL_PRE_PING, + "pool_recycle": POSTGRES_POOL_RECYCLE, + } + + def __init__(self) -> None: + pass + + @classmethod + def _init_engine(cls, **engine_kwargs: Any) -> Engine: + """Private helper method to create and return an Engine.""" + connection_string = build_connection_string( + db_api=SYNC_DB_API, app_name=cls._app_name + "_sync" + ) + merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs} + return create_engine(connection_string, **merged_kwargs) + + @classmethod + def init_engine(cls, **engine_kwargs: Any) -> None: + """Allow the caller to init the engine with extra params. 
Different clients + such as the API server and different celery workers and tasks + need different settings.""" + with cls._lock: + if not cls._engine: + cls._engine = cls._init_engine(**engine_kwargs) + + @classmethod + def get_engine(cls) -> Engine: + """Gets the sql alchemy engine. Will init a default engine if init hasn't + already been called. You probably want to init first!""" + if not cls._engine: + with cls._lock: + if not cls._engine: + cls._engine = cls._init_engine() + return cls._engine + + @classmethod + def set_app_name(cls, app_name: str) -> None: + """Class method to set the app name.""" + cls._app_name = app_name + + @classmethod + def get_app_name(cls) -> str: + """Class method to get current app name.""" + if not cls._app_name: + return "" + return cls._app_name + + def build_connection_string( *, db_api: str = ASYNC_DB_API, @@ -125,24 +183,11 @@ def build_connection_string( def init_sqlalchemy_engine(app_name: str) -> None: - global POSTGRES_APP_NAME - POSTGRES_APP_NAME = app_name + SqlEngine.set_app_name(app_name) def get_sqlalchemy_engine() -> Engine: - global _SYNC_ENGINE - if _SYNC_ENGINE is None: - connection_string = build_connection_string( - db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync" - ) - _SYNC_ENGINE = create_engine( - connection_string, - pool_size=40, - max_overflow=10, - pool_pre_ping=POSTGRES_POOL_PRE_PING, - pool_recycle=POSTGRES_POOL_RECYCLE, - ) - return _SYNC_ENGINE + return SqlEngine.get_engine() def get_sqlalchemy_async_engine() -> AsyncEngine: @@ -154,7 +199,9 @@ def get_sqlalchemy_async_engine() -> AsyncEngine: _ASYNC_ENGINE = create_async_engine( connection_string, connect_args={ - "server_settings": {"application_name": POSTGRES_APP_NAME + "_async"} + "server_settings": { + "application_name": SqlEngine.get_app_name() + "_async" + } }, pool_size=40, max_overflow=10, diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py index 9aacc985d..428666751 100644 --- a/backend/danswer/server/documents/cc_pair.py +++ b/backend/danswer/server/documents/cc_pair.py @@ -239,7 +239,7 @@ def prune_cc_pair( db_session: Session = Depends(get_session), ) -> StatusResponse[list[int]]: # avoiding circular refs - from danswer.background.celery.celery_app import prune_documents_task + from danswer.background.celery.tasks.pruning.tasks import prune_documents_task cc_pair = get_connector_credential_pair_from_id( cc_pair_id=cc_pair_id, diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py index 481d2fedb..1ebe5bd06 100644 --- a/backend/danswer/server/manage/administrative.py +++ b/backend/danswer/server/manage/administrative.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user from danswer.auth.users import current_curator_or_admin_user +from danswer.background.celery.celery_app import celery_app from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DocumentSource @@ -146,10 +147,6 @@ def create_deletion_attempt_for_connector_id( user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> None: - from danswer.background.celery.celery_app import ( - check_for_connector_deletion_task, - ) - connector_id = connector_credential_pair_identifier.connector_id credential_id = connector_credential_pair_identifier.credential_id @@ -193,8 +190,11 @@ def 
create_deletion_attempt_for_connector_id( status=ConnectorCredentialPairStatus.DELETING, ) - # run the beat task to pick up this deletion early - check_for_connector_deletion_task.apply_async( + db_session.commit() + + # run the beat task to pick up this deletion from the db immediately + celery_app.send_task( + "check_for_connector_deletion_task", priority=DanswerCeleryPriority.HIGH, ) diff --git a/backend/ee/danswer/background/celery/tasks/vespa/tasks.py b/backend/ee/danswer/background/celery/tasks/vespa/tasks.py new file mode 100644 index 000000000..d194b2ef9 --- /dev/null +++ b/backend/ee/danswer/background/celery/tasks/vespa/tasks.py @@ -0,0 +1,52 @@ +from typing import cast + +from redis import Redis +from sqlalchemy.orm import Session + +from danswer.background.celery.celery_app import task_logger +from danswer.background.celery.celery_redis import RedisUserGroup +from danswer.utils.logger import setup_logger +from ee.danswer.db.user_group import delete_user_group +from ee.danswer.db.user_group import fetch_user_group +from ee.danswer.db.user_group import mark_user_group_as_synced + +logger = setup_logger() + + +def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None: + """This function is likely to move in the worker refactor happening next.""" + key = key_bytes.decode("utf-8") + usergroup_id = RedisUserGroup.get_id_from_fence_key(key) + if not usergroup_id: + task_logger.warning("Could not parse usergroup id from {key}") + return + + rug = RedisUserGroup(usergroup_id) + fence_value = r.get(rug.fence_key) + if fence_value is None: + return + + try: + initial_count = int(cast(int, fence_value)) + except ValueError: + task_logger.error("The value is not an integer.") + return + + count = cast(int, r.scard(rug.taskset_key)) + task_logger.info( + f"User group sync: usergroup_id={usergroup_id} remaining={count} initial={initial_count}" + ) + if count > 0: + return + + user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id) + if user_group: + if user_group.is_up_for_deletion: + delete_user_group(db_session=db_session, user_group=user_group) + task_logger.info(f"Deleted usergroup. id='{usergroup_id}'") + else: + mark_user_group_as_synced(db_session=db_session, user_group=user_group) + task_logger.info(f"Synced usergroup. 
id='{usergroup_id}'") + + r.delete(rug.taskset_key) + r.delete(rug.fence_key) diff --git a/backend/ee/danswer/background/celery_utils.py b/backend/ee/danswer/background/celery_utils.py index afb68f2a2..c42812f81 100644 --- a/backend/ee/danswer/background/celery_utils.py +++ b/backend/ee/danswer/background/celery_utils.py @@ -1,11 +1,5 @@ -from typing import cast - -from redis import Redis from sqlalchemy.orm import Session -from danswer.background.celery.celery_app import task_logger -from danswer.background.celery.celery_redis import RedisUserGroup -from danswer.db.engine import get_sqlalchemy_engine from danswer.db.enums import AccessType from danswer.db.models import ConnectorCredentialPair from danswer.db.tasks import check_task_is_live_and_not_timed_out @@ -18,9 +12,6 @@ from ee.danswer.background.task_name_builders import ( from ee.danswer.background.task_name_builders import ( name_sync_external_group_permissions_task, ) -from ee.danswer.db.user_group import delete_user_group -from ee.danswer.db.user_group import fetch_user_group -from ee.danswer.db.user_group import mark_user_group_as_synced logger = setup_logger() @@ -79,43 +70,3 @@ def should_perform_external_group_permissions_check( return False return True - - -def monitor_usergroup_taskset(key_bytes: bytes, r: Redis) -> None: - """This function is likely to move in the worker refactor happening next.""" - key = key_bytes.decode("utf-8") - usergroup_id = RedisUserGroup.get_id_from_fence_key(key) - if not usergroup_id: - task_logger.warning("Could not parse usergroup id from {key}") - return - - rug = RedisUserGroup(usergroup_id) - fence_value = r.get(rug.fence_key) - if fence_value is None: - return - - try: - initial_count = int(cast(int, fence_value)) - except ValueError: - task_logger.error("The value is not an integer.") - return - - count = cast(int, r.scard(rug.taskset_key)) - task_logger.info( - f"User group sync: usergroup_id={usergroup_id} remaining={count} initial={initial_count}" - ) - if count > 0: - return - - with Session(get_sqlalchemy_engine()) as db_session: - user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id) - if user_group: - if user_group.is_up_for_deletion: - delete_user_group(db_session=db_session, user_group=user_group) - task_logger.info(f"Deleted usergroup. id='{usergroup_id}'") - else: - mark_user_group_as_synced(db_session=db_session, user_group=user_group) - task_logger.info(f"Synced usergroup. 
id='{usergroup_id}'") - - r.delete(rug.taskset_key) - r.delete(rug.fence_key) diff --git a/backend/scripts/dev_run_background_jobs.py b/backend/scripts/dev_run_background_jobs.py index 96fdc2115..a4a253a10 100644 --- a/backend/scripts/dev_run_background_jobs.py +++ b/backend/scripts/dev_run_background_jobs.py @@ -18,7 +18,8 @@ def monitor_process(process_name: str, process: subprocess.Popen) -> None: def run_jobs(exclude_indexing: bool) -> None: - cmd_worker = [ + # command setup + cmd_worker_primary = [ "celery", "-A", "ee.danswer.background.celery.celery_app", @@ -26,8 +27,38 @@ def run_jobs(exclude_indexing: bool) -> None: "--pool=threads", "--concurrency=6", "--loglevel=INFO", + "-n", + "primary@%n", "-Q", - "celery,vespa_metadata_sync,connector_deletion", + "celery", + ] + + cmd_worker_light = [ + "celery", + "-A", + "ee.danswer.background.celery.celery_app", + "worker", + "--pool=threads", + "--concurrency=16", + "--loglevel=INFO", + "-n", + "light@%n", + "-Q", + "vespa_metadata_sync,connector_deletion", + ] + + cmd_worker_heavy = [ + "celery", + "-A", + "ee.danswer.background.celery.celery_app", + "worker", + "--pool=threads", + "--concurrency=6", + "--loglevel=INFO", + "-n", + "heavy@%n", + "-Q", + "connector_pruning", ] cmd_beat = [ @@ -38,19 +69,38 @@ def run_jobs(exclude_indexing: bool) -> None: "--loglevel=INFO", ] - worker_process = subprocess.Popen( - cmd_worker, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True + # spawn processes + worker_primary_process = subprocess.Popen( + cmd_worker_primary, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True ) + + worker_light_process = subprocess.Popen( + cmd_worker_light, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True + ) + + worker_heavy_process = subprocess.Popen( + cmd_worker_heavy, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True + ) + beat_process = subprocess.Popen( cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True ) - worker_thread = threading.Thread( - target=monitor_process, args=("WORKER", worker_process) + # monitor threads + worker_primary_thread = threading.Thread( + target=monitor_process, args=("PRIMARY", worker_primary_process) + ) + worker_light_thread = threading.Thread( + target=monitor_process, args=("LIGHT", worker_light_process) + ) + worker_heavy_thread = threading.Thread( + target=monitor_process, args=("HEAVY", worker_heavy_process) ) beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process)) - worker_thread.start() + worker_primary_thread.start() + worker_light_thread.start() + worker_heavy_thread.start() beat_thread.start() if not exclude_indexing: @@ -93,7 +143,9 @@ def run_jobs(exclude_indexing: bool) -> None: except Exception: pass - worker_thread.join() + worker_primary_thread.join() + worker_light_thread.join() + worker_heavy_thread.join() beat_thread.join() diff --git a/backend/supervisord.conf b/backend/supervisord.conf index ff055a78f..5b9dca95b 100644 --- a/backend/supervisord.conf +++ b/backend/supervisord.conf @@ -24,16 +24,50 @@ autorestart=true # on a system, but this should be okay for now since all our celery tasks are # relatively compute-light (e.g. 
they tend to just make a bunch of requests to # Vespa / Postgres) -[program:celery_worker] +[program:celery_worker_primary] command=celery -A danswer.background.celery.celery_run:celery_app worker --pool=threads - --concurrency=6 + --concurrency=4 + --prefetch-multiplier=1 --loglevel=INFO - --logfile=/var/log/celery_worker_supervisor.log - -Q celery,vespa_metadata_sync,connector_deletion -environment=LOG_FILE_NAME=celery_worker + --logfile=/var/log/celery_worker_primary_supervisor.log + --hostname=primary@%%n + -Q celery +environment=LOG_FILE_NAME=celery_worker_primary redirect_stderr=true autorestart=true +startsecs=10 +stopasgroup=true + +[program:celery_worker_light] +command=celery -A danswer.background.celery.celery_run:celery_app worker + --pool=threads + --concurrency=16 + --prefetch-multiplier=8 + --loglevel=INFO + --logfile=/var/log/celery_worker_light_supervisor.log + --hostname=light@%%n + -Q vespa_metadata_sync,connector_deletion +environment=LOG_FILE_NAME=celery_worker_light +redirect_stderr=true +autorestart=true +startsecs=10 +stopasgroup=true + +[program:celery_worker_heavy] +command=celery -A danswer.background.celery.celery_run:celery_app worker + --pool=threads + --concurrency=4 + --prefetch-multiplier=1 + --loglevel=INFO + --logfile=/var/log/celery_worker_heavy_supervisor.log + --hostname=heavy@%%n + -Q connector_pruning +environment=LOG_FILE_NAME=celery_worker_heavy +redirect_stderr=true +autorestart=true +startsecs=10 +stopasgroup=true # Job scheduler for periodic tasks [program:celery_beat] @@ -41,6 +75,8 @@ command=celery -A danswer.background.celery.celery_run:celery_app beat --logfile=/var/log/celery_beat_supervisor.log environment=LOG_FILE_NAME=celery_beat redirect_stderr=true +startsecs=10 +stopasgroup=true # Listens for Slack messages and responds with answers # for all channels that the DanswerBot has been added to. @@ -60,9 +96,13 @@ startsecs=60 command=tail -qF /var/log/document_indexing_info.log /var/log/celery_beat_supervisor.log - /var/log/celery_worker_supervisor.log + /var/log/celery_worker_primary_supervisor.log + /var/log/celery_worker_light_supervisor.log + /var/log/celery_worker_heavy_supervisor.log /var/log/celery_beat_debug.log - /var/log/celery_worker_debug.log + /var/log/celery_worker_primary_debug.log + /var/log/celery_worker_light_debug.log + /var/log/celery_worker_heavy_debug.log /var/log/slack_bot_debug.log stdout_logfile=/dev/stdout stdout_logfile_maxbytes=0
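
---

Reviewer note (not part of the patch): the coordination trick this PR leans on everywhere is a Redis "fence" key that records how many tasks were queued, paired with a "taskset" set that shrinks as workers finish; a beat task watches the set and tears the fence down when it hits zero. The sketch below is a minimal illustration of that pattern only. The key names (sync_fence, sync_taskset), the bare redis.Redis client, and the helper names are assumptions for readability; the actual code routes through the RedisConnectorCredentialPair / RedisDocumentSet / RedisUserGroup helpers and the shared RedisPool introduced in this patch.

# Illustrative sketch of the fence/taskset pattern (hypothetical key names, not from the patch).
import redis

r = redis.Redis(host="localhost", port=6379, db=0)

FENCE_KEY = "sync_fence"      # assumption: real keys come from the Redis* helper classes
TASKSET_KEY = "sync_taskset"  # assumption

def generate_tasks(doc_ids: list[str]) -> int:
    """Queue one unit of work per document; publish the fence only after everything is queued."""
    r.delete(TASKSET_KEY)
    for doc_id in doc_ids:
        r.sadd(TASKSET_KEY, doc_id)  # track outstanding work
        # celery_app.send_task("vespa_metadata_sync_task", kwargs={"document_id": doc_id})
    r.set(FENCE_KEY, len(doc_ids))   # fence goes up last, as in the try_generate_*_sync_tasks functions
    return len(doc_ids)

def on_task_done(doc_id: str) -> None:
    """Called when one document finishes syncing."""
    r.srem(TASKSET_KEY, doc_id)

def monitor() -> None:
    """Beat-style check: an empty taskset behind an existing fence means the batch is complete."""
    if r.get(FENCE_KEY) is None:
        return
    if r.scard(TASKSET_KEY) == 0:
        # the real monitor_* functions mark the Postgres rows as synced here before cleanup
        r.delete(TASKSET_KEY)
        r.delete(FENCE_KEY)

A second, smaller point from the commit message is that each worker now initializes its own SqlAlchemy engine sized to its concurrency. The fragment below shows one plausible way to call the SqlEngine class this patch adds in backend/danswer/db/engine.py; the use of the worker_init signal and the specific pool numbers are assumptions, not code from the diff.

# Illustrative only: sizing the pool to the light worker's --concurrency=16 at startup.
from celery.signals import worker_init
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.db.engine import SqlEngine

@worker_init.connect
def on_worker_init(sender, **kwargs):  # hypothetical hook; the patch wires this up elsewhere
    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
    SqlEngine.init_engine(pool_size=16, max_overflow=8)  # pool roughly matches worker concurrency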