diff --git a/backend/danswer/background/celery/celery.py b/backend/danswer/background/celery/celery.py
index 8d232cf3c..4aa16a2f0 100644
--- a/backend/danswer/background/celery/celery.py
+++ b/backend/danswer/background/celery/celery.py
@@ -4,10 +4,11 @@ from pathlib import Path
 from typing import cast
 
 from celery import Celery  # type: ignore
-from celery.result import AsyncResult
 from sqlalchemy.orm import Session
 
-from danswer.background.connector_deletion import _delete_connector_credential_pair
+from danswer.background.connector_deletion import delete_connector_credential_pair
+from danswer.background.task_utils import build_celery_task_wrapper
+from danswer.background.task_utils import name_cc_cleanup_task
 from danswer.background.task_utils import name_document_set_sync_task
 from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH
 from danswer.configs.app_configs import JOB_TIMEOUT
@@ -30,9 +31,6 @@ from danswer.db.engine import SYNC_DB_API
 from danswer.db.models import DocumentSet
 from danswer.db.tasks import check_live_task_not_timed_out
 from danswer.db.tasks import get_latest_task
-from danswer.db.tasks import mark_task_finished
-from danswer.db.tasks import mark_task_start
-from danswer.db.tasks import register_task
 from danswer.utils.batching import batch_generator
 from danswer.utils.logger import setup_logger
 
@@ -43,7 +41,6 @@ celery_backend_url = "db+" + build_connection_string(db_api=SYNC_DB_API)
 celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
 
 
-_ExistingTaskCache: dict[int, AsyncResult] = {}
 _SYNC_BATCH_SIZE = 1000
 
 
@@ -52,6 +49,7 @@ _SYNC_BATCH_SIZE = 1000
 #
 # If imports from this module are needed, use local imports to avoid circular importing
 #####
+@build_celery_task_wrapper(name_cc_cleanup_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
 def cleanup_connector_credential_pair_task(
     connector_id: int,
@@ -79,7 +77,7 @@ def cleanup_connector_credential_pair_task(
 
     try:
         # The bulk of the work is in here, updates Postgres and Vespa
-        return _delete_connector_credential_pair(
+        return delete_connector_credential_pair(
             db_session=db_session,
             document_index=get_default_document_index(),
             cc_pair=cc_pair,
@@ -89,6 +87,7 @@ def cleanup_connector_credential_pair_task(
         raise e
 
 
+@build_celery_task_wrapper(name_document_set_sync_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
 def sync_document_set_task(document_set_id: int) -> None:
     """For document sets marked as not up to date, sync the state from postgres
@@ -125,9 +124,6 @@ def sync_document_set_task(document_set_id: int) -> None:
         )
 
     with Session(get_sqlalchemy_engine()) as db_session:
-        task_name = name_document_set_sync_task(document_set_id)
-        mark_task_start(task_name, db_session)
-
        try:
            document_index = get_default_document_index()
            documents_to_update = fetch_documents_for_document_set(
@@ -166,11 +162,8 @@ def sync_document_set_task(document_set_id: int) -> None:
 
        except Exception:
            logger.exception("Failed to sync document set %s", document_set_id)
-            mark_task_finished(task_name, db_session, success=False)
            raise
 
-        mark_task_finished(task_name, db_session)
-
 
 #####
 # Periodic Tasks
@@ -201,10 +194,9 @@ def check_for_document_sets_sync_task() -> None:
                continue
 
            logger.info(f"Document set {document_set.id} syncing now!")
-            task = sync_document_set_task.apply_async(
+            sync_document_set_task.apply_async(
                kwargs=dict(document_set_id=document_set.id),
            )
-            register_task(task.id, task_name, db_session)
 
 
 @celery_app.task(name="clean_old_temp_files_task", soft_time_limit=JOB_TIMEOUT)
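With `_ExistingTaskCache` and the Celery `AsyncResult` lookups gone, de-duplication of queued work in `check_for_document_sets_sync_task` rests entirely on the `task_queue_jobs` rows maintained by the new wrapper. A minimal sketch of that check, assuming `get_latest_task` returns the newest row for a given task name and `check_live_task_not_timed_out` takes that row plus a session (both signatures are inferred from the imports above, not spelled out in this diff):

```python
from sqlalchemy.orm import Session

from danswer.background.task_utils import name_document_set_sync_task
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.tasks import check_live_task_not_timed_out
from danswer.db.tasks import get_latest_task


def sync_already_in_flight(document_set_id: int) -> bool:
    """Hypothetical helper: True if a sync for this document set is still live."""
    task_name = name_document_set_sync_task(document_set_id)
    with Session(get_sqlalchemy_engine()) as db_session:
        latest_task = get_latest_task(task_name=task_name, db_session=db_session)
        # a PENDING / STARTED row that has not exceeded its timeout means a
        # worker is already handling (or about to handle) this document set
        return latest_task is not None and check_live_task_not_timed_out(
            latest_task, db_session
        )
```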
diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py
index 516b8a2dd..21933cbcb 100644
--- a/backend/danswer/background/celery/celery_utils.py
+++ b/backend/danswer/background/celery/celery_utils.py
@@ -1,75 +1,23 @@
-import json
-from typing import cast
-
-from celery.result import AsyncResult
-from sqlalchemy import text
 from sqlalchemy.orm import Session
 
-from danswer.background.celery.celery import celery_app
 from danswer.background.task_utils import name_cc_cleanup_task
-from danswer.db.engine import get_sqlalchemy_engine
-from danswer.db.models import DeletionStatus
+from danswer.db.tasks import get_latest_task
 from danswer.server.models import DeletionAttemptSnapshot
 
 
-def get_celery_task(task_id: str) -> AsyncResult:
-    """NOTE: even if the task doesn't exist, celery will still return something
-    with a `PENDING` state"""
-    return AsyncResult(task_id, backend=celery_app.backend)
-
-
-def get_celery_task_status(task_id: str) -> str | None:
-    """NOTE: is tightly coupled to the internals of kombu (which is the
-    translation layer to allow us to use Postgres as a broker). If we change
-    the broker, this will need to be updated.
-
-    This should not be called on any critical flows.
-    """
-    # first check for any pending tasks
-    with Session(get_sqlalchemy_engine()) as session:
-        rows = session.execute(text("SELECT payload FROM kombu_message WHERE visible"))
-        for row in rows:
-            payload = json.loads(row[0])
-            if payload["headers"]["id"] == task_id:
-                return "PENDING"
-
-    task = get_celery_task(task_id)
-    # if not pending, then we know the task really exists
-    if task.status != "PENDING":
-        return task.status
-
-    return None
-
-
 def get_deletion_status(
-    connector_id: int, credential_id: int
+    connector_id: int, credential_id: int, db_session: Session
 ) -> DeletionAttemptSnapshot | None:
-    cleanup_task_id = name_cc_cleanup_task(
+    cleanup_task_name = name_cc_cleanup_task(
        connector_id=connector_id, credential_id=credential_id
    )
-    deletion_task = get_celery_task(task_id=cleanup_task_id)
-    deletion_task_status = get_celery_task_status(task_id=cleanup_task_id)
+    task_state = get_latest_task(task_name=cleanup_task_name, db_session=db_session)
 
-    deletion_status = None
-    error_msg = None
-    num_docs_deleted = 0
-    if deletion_task_status == "SUCCESS":
-        deletion_status = DeletionStatus.SUCCESS
-        num_docs_deleted = cast(int, deletion_task.get(propagate=False))
-    elif deletion_task_status == "FAILURE":
-        deletion_status = DeletionStatus.FAILED
-        error_msg = deletion_task.get(propagate=False)
-    elif deletion_task_status == "STARTED" or deletion_task_status == "PENDING":
-        deletion_status = DeletionStatus.IN_PROGRESS
+    if not task_state:
+        return None
 
-    return (
-        DeletionAttemptSnapshot(
-            connector_id=connector_id,
-            credential_id=credential_id,
-            status=deletion_status,
-            error_msg=str(error_msg),
-            num_docs_deleted=num_docs_deleted,
-        )
-        if deletion_status
-        else None
+    return DeletionAttemptSnapshot(
+        connector_id=connector_id,
+        credential_id=credential_id,
+        status=task_state.status,
    )
diff --git a/backend/danswer/background/connector_deletion.py b/backend/danswer/background/connector_deletion.py
index 6aea62850..ae2cfec18 100644
--- a/backend/danswer/background/connector_deletion.py
+++ b/backend/danswer/background/connector_deletion.py
@@ -144,7 +144,7 @@ def cleanup_synced_entities(
    )
 
 
-def _delete_connector_credential_pair(
+def delete_connector_credential_pair(
    db_session: Session,
    document_index: DocumentIndex,
    cc_pair: ConnectorCredentialPair,
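For context, `get_latest_task` (imported from `danswer.db.tasks` above) only needs to return the newest `task_queue_jobs` row for a task name. A sketch of its likely shape — the `TaskQueueState` model and column names here are assumptions, not shown in this diff:

```python
from sqlalchemy import desc
from sqlalchemy.orm import Session

from danswer.db.models import TaskQueueState  # assumed model backing `task_queue_jobs`


def get_latest_task(task_name: str, db_session: Session) -> TaskQueueState | None:
    # each queueing of a task inserts a fresh row, so the newest row
    # (highest autoincrement id) reflects the current attempt
    return (
        db_session.query(TaskQueueState)
        .filter(TaskQueueState.task_name == task_name)
        .order_by(desc(TaskQueueState.id))
        .first()
    )
```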
diff --git a/backend/danswer/background/task_utils.py b/backend/danswer/background/task_utils.py
index ec7f79f56..78a2938c3 100644
--- a/backend/danswer/background/task_utils.py
+++ b/backend/danswer/background/task_utils.py
@@ -1,6 +1,117 @@
+from collections.abc import Callable
+from functools import wraps
+from typing import Any
+from typing import cast
+from typing import TypeVar
+
+from celery import Task
+from celery.result import AsyncResult
+from sqlalchemy.orm import Session
+
+from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.tasks import mark_task_finished
+from danswer.db.tasks import mark_task_start
+from danswer.db.tasks import register_task
+
+
 def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
     return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
 
 
 def name_document_set_sync_task(document_set_id: int) -> str:
     return f"sync_doc_set_{document_set_id}"
+
+
+T = TypeVar("T", bound=Callable)
+
+
+def build_run_wrapper(build_name_fn: Callable[..., str]) -> Callable[[T], T]:
+    """Utility meant to wrap the celery task `run` function in order to
+    automatically update our custom `task_queue_jobs` table appropriately"""
+
+    def wrap_task_fn(task_fn: T) -> T:
+        @wraps(task_fn)
+        def wrapped_task_fn(*args: list, **kwargs: dict) -> Any:
+            engine = get_sqlalchemy_engine()
+
+            task_name = build_name_fn(*args, **kwargs)
+            with Session(engine) as db_session:
+                # mark the task as started
+                mark_task_start(task_name=task_name, db_session=db_session)
+
+            result = None
+            exception = None
+            try:
+                result = task_fn(*args, **kwargs)
+            except Exception as e:
+                exception = e
+
+            with Session(engine) as db_session:
+                mark_task_finished(
+                    task_name=task_name,
+                    db_session=db_session,
+                    success=exception is None,
+                )
+
+            if not exception:
+                return result
+            else:
+                raise exception
+
+        return cast(T, wrapped_task_fn)
+
+    return wrap_task_fn
+
+
+# rough type signature for `apply_async`
+AA = TypeVar("AA", bound=Callable[..., AsyncResult])
+
+
+def build_apply_async_wrapper(build_name_fn: Callable[..., str]) -> Callable[[AA], AA]:
+    """Utility meant to wrap celery `apply_async` function in order to
+    automatically create an entry in our `task_queue_jobs` table"""
+
+    def wrapper(fn: AA) -> AA:
+        @wraps(fn)
+        def wrapped_fn(
+            args: tuple | None = None,
+            kwargs: dict[str, Any] | None = None,
+            *other_args: list,
+            **other_kwargs: dict[str, Any],
+        ) -> Any:
+            # `apply_async` takes in args / kwargs directly as arguments
+            args_for_build_name = args or tuple()
+            kwargs_for_build_name = kwargs or {}
+            task_name = build_name_fn(*args_for_build_name, **kwargs_for_build_name)
+            with Session(get_sqlalchemy_engine()) as db_session:
+                # queue the task and create a PENDING row for it in `task_queue_jobs`
+                task = fn(args, kwargs, *other_args, **other_kwargs)
+                register_task(task.id, task_name, db_session)
+
+            return task
+
+        return cast(AA, wrapped_fn)
+
+    return wrapper
+
+
+def build_celery_task_wrapper(
+    build_name_fn: Callable[..., str]
+) -> Callable[[Task], Task]:
+    """Utility meant to wrap celery task functions in order to automatically
+    update our custom `task_queue_jobs` table appropriately.
+
+    On task creation (e.g. `apply_async`), a row is inserted into the table with
+    status `PENDING`.
+    On task start, the latest row is updated to have status `STARTED`.
+    On task success, the latest row is updated to have status `SUCCESS`.
+    On the task raising an unhandled exception, the latest row is updated to have
+    status `FAILURE`.
+    """
+
+    def wrap_task(task: Task) -> Task:
+        task.run = build_run_wrapper(build_name_fn)(task.run)  # type: ignore
+        task.apply_async = build_apply_async_wrapper(build_name_fn)(task.apply_async)  # type: ignore
+        return task
+
+    return wrap_task
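To see how the wrapper is meant to be used: it sits *above* `@celery_app.task` (exactly as done in `celery.py` earlier in this diff), so it receives the fully-built `Task` object and can patch both its `run` and its `apply_async`. A condensed sketch; the `soft_time_limit` value here is arbitrary:

```python
from celery import Celery

from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_document_set_sync_task

celery_app = Celery(__name__)


@build_celery_task_wrapper(name_document_set_sync_task)
@celery_app.task(soft_time_limit=3600)
def sync_document_set_task(document_set_id: int) -> None:
    ...


# queuing inserts a PENDING row named "sync_doc_set_42"; the worker then
# flips it to STARTED and finally to SUCCESS (or FAILURE on an exception)
sync_document_set_task.apply_async(kwargs=dict(document_set_id=42))
```

Note that the decorator order matters: placed below `@celery_app.task`, `build_celery_task_wrapper` would receive the plain function, which has no `apply_async` to wrap.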
+ """ + + def wrap_task(task: Task) -> Task: + task.run = build_run_wrapper(build_name_fn)(task.run) # type: ignore + task.apply_async = build_apply_async_wrapper(build_name_fn)(task.apply_async) # type: ignore + return task + + return wrap_task diff --git a/backend/danswer/server/manage.py b/backend/danswer/server/manage.py index 9e788ce2b..d7954e639 100644 --- a/backend/danswer/server/manage.py +++ b/backend/danswer/server/manage.py @@ -15,7 +15,6 @@ from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user from danswer.auth.users import current_user from danswer.background.celery.celery_utils import get_deletion_status -from danswer.background.task_utils import name_cc_cleanup_task from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY @@ -366,7 +365,9 @@ def get_connector_indexing_status( if latest_index_attempt else None, deletion_attempt=get_deletion_status( - connector_id=connector.id, credential_id=credential.id + connector_id=connector.id, + credential_id=credential.id, + db_session=db_session, ), is_deletable=check_deletion_attempt_is_allowed( connector_credential_pair=cc_pair @@ -590,12 +591,8 @@ def create_deletion_attempt_for_connector_id( "no ongoing / planned indexing attempts.", ) - task_id = name_cc_cleanup_task( - connector_id=connector_id, credential_id=credential_id - ) cleanup_connector_credential_pair_task.apply_async( kwargs=dict(connector_id=connector_id, credential_id=credential_id), - task_id=task_id, ) diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index 217855373..35194d9a5 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -22,10 +22,10 @@ from danswer.db.models import AllowedAnswerFilters from danswer.db.models import ChannelConfig from danswer.db.models import Connector from danswer.db.models import Credential -from danswer.db.models import DeletionStatus from danswer.db.models import DocumentSet as DocumentSetDBModel from danswer.db.models import IndexAttempt from danswer.db.models import IndexingStatus +from danswer.db.models import TaskStatus from danswer.direct_qa.interfaces import DanswerQuote from danswer.search.models import QueryFlow from danswer.search.models import SearchType @@ -326,9 +326,7 @@ class IndexAttemptSnapshot(BaseModel): class DeletionAttemptSnapshot(BaseModel): connector_id: int credential_id: int - status: DeletionStatus - error_msg: str | None - num_docs_deleted: int + status: TaskStatus class ConnectorBase(BaseModel): diff --git a/backend/supervisord.conf b/backend/supervisord.conf index 3613f5f50..69b13bd64 100644 --- a/backend/supervisord.conf +++ b/backend/supervisord.conf @@ -11,8 +11,18 @@ redirect_stderr=true autorestart=true # Background jobs that must be run async due to long time to completion +# NOTE: due to an issue with Celery + SQLAlchemy +# (https://github.com/celery/celery/issues/7007#issuecomment-1740139367) +# we must use the threads pool instead of the default prefork pool for now +# in order to avoid intermittent errors like: +# `billiard.exceptions.WorkerLostError: Worker exited prematurely: signal 11 (SIGSEGV)`. +# +# This means workers will not be able take advantage of multiple CPU cores +# on a system, but this should be okay for now since all our celery tasks are +# relatively compute-light (e.g. 
diff --git a/backend/supervisord.conf b/backend/supervisord.conf
index 3613f5f50..69b13bd64 100644
--- a/backend/supervisord.conf
+++ b/backend/supervisord.conf
@@ -11,8 +11,18 @@ redirect_stderr=true
 autorestart=true
 
 # Background jobs that must be run async due to long time to completion
+# NOTE: due to an issue with Celery + SQLAlchemy
+# (https://github.com/celery/celery/issues/7007#issuecomment-1740139367)
+# we must use the threads pool instead of the default prefork pool for now
+# in order to avoid intermittent errors like:
+# `billiard.exceptions.WorkerLostError: Worker exited prematurely: signal 11 (SIGSEGV)`.
+#
+# This means workers will not be able to take advantage of multiple CPU cores
+# on a system, but this should be okay for now since all our celery tasks are
+# relatively compute-light (e.g. they tend to just make a bunch of requests to
+# Vespa / Postgres)
 [program:celery_worker]
-command=celery -A danswer.background.celery worker --loglevel=INFO --logfile=/var/log/celery_worker.log
+command=celery -A danswer.background.celery worker --pool=threads --loglevel=INFO --logfile=/var/log/celery_worker.log
 stdout_logfile=/var/log/celery_worker_supervisor.log
 stdout_logfile_maxbytes=52428800
 redirect_stderr=true
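To sanity-check that the worker really came up with the thread pool, one option (not part of this change) is Celery's inspect API. The exact layout of the stats payload varies across Celery versions, so treat the `"pool"` lookup as illustrative:

```python
import json

from danswer.background.celery.celery import celery_app

# returns {worker_name: stats_dict} for every worker that responds
stats = celery_app.control.inspect().stats() or {}
for worker_name, worker_stats in stats.items():
    # the "pool" section describes the concurrency backend in use; with
    # --pool=threads it should no longer reference prefork processes
    print(worker_name, json.dumps(worker_stats.get("pool"), indent=2))
```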
diff --git a/web/src/components/admin/connectors/table/ConnectorsTable.tsx b/web/src/components/admin/connectors/table/ConnectorsTable.tsx
index 31d6a5ae3..8716c31f2 100644
--- a/web/src/components/admin/connectors/table/ConnectorsTable.tsx
+++ b/web/src/components/admin/connectors/table/ConnectorsTable.tsx
@@ -37,11 +37,7 @@ export function StatusRow({
   }
   if (connector.disabled) {
     const deletionAttempt = connectorIndexingStatus.deletion_attempt;
-    if (
-      !deletionAttempt ||
-      deletionAttempt.status === "not_started" ||
-      deletionAttempt.status === "failed"
-    ) {
+    if (!deletionAttempt || deletionAttempt.status === "FAILURE") {
       statusDisplay = <div>Disabled</div>;
     } else {
       statusDisplay = <div>Deleting...</div>;
diff --git a/web/src/components/admin/connectors/table/SingleUseConnectorsTable.tsx b/web/src/components/admin/connectors/table/SingleUseConnectorsTable.tsx
index 5a174d1d8..a69b42a8c 100644
--- a/web/src/components/admin/connectors/table/SingleUseConnectorsTable.tsx
+++ b/web/src/components/admin/connectors/table/SingleUseConnectorsTable.tsx
@@ -17,8 +17,8 @@ const SingleUseConnectorStatus = ({
 }) => {
   if (
     deletionAttempt &&
-    (deletionAttempt.status === "in_progress" ||
-      deletionAttempt.status === "not_started")
+    (deletionAttempt.status === "PENDING" ||
+      deletionAttempt.status === "STARTED")
   ) {
     return <div>Deleting...</div>;
   }
diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts
index 38720fde9..fb43d3297 100644
--- a/web/src/lib/types.ts
+++ b/web/src/lib/types.ts
@@ -33,6 +33,7 @@ export type ValidStatuses =
   | "failed"
   | "in_progress"
   | "not_started";
+export type TaskStatus = "PENDING" | "STARTED" | "SUCCESS" | "FAILURE";
 
 export interface DocumentBoostStatus {
   document_id: string;
@@ -244,9 +245,7 @@ export interface Document360CredentialJson {
 export interface DeletionAttemptSnapshot {
   connector_id: number;
   credential_id: number;
-  status: ValidStatuses;
-  error_msg?: string;
-  num_docs_deleted: number;
+  status: TaskStatus;
 }
 
 // DOCUMENT SETS
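For reference, the four string literals in `TaskStatus` mirror the backend enum of the same name imported in `models.py` above (which in turn follow Celery's built-in state names). Its definition is not part of this diff; presumably something like:

```python
from enum import Enum


class TaskStatus(str, Enum):  # assumed definition, not shown in this diff
    PENDING = "PENDING"
    STARTED = "STARTED"
    SUCCESS = "SUCCESS"
    FAILURE = "FAILURE"
```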