Fix deletion status display + add celery util + fix seg faults (#615)

Chris Weaver 2023-10-22 19:41:29 -07:00 committed by GitHub
parent 8403b94722
commit 89807c8c05
10 changed files with 150 additions and 99 deletions

View File

@@ -4,10 +4,11 @@ from pathlib import Path
from typing import cast
from celery import Celery # type: ignore
from celery.result import AsyncResult
from sqlalchemy.orm import Session
from danswer.background.connector_deletion import _delete_connector_credential_pair
from danswer.background.connector_deletion import delete_connector_credential_pair
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH
from danswer.configs.app_configs import JOB_TIMEOUT
@@ -30,9 +31,6 @@ from danswer.db.engine import SYNC_DB_API
from danswer.db.models import DocumentSet
from danswer.db.tasks import check_live_task_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import mark_task_finished
from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
@@ -43,7 +41,6 @@ celery_backend_url = "db+" + build_connection_string(db_api=SYNC_DB_API)
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
_ExistingTaskCache: dict[int, AsyncResult] = {}
_SYNC_BATCH_SIZE = 1000
@@ -52,6 +49,7 @@ _SYNC_BATCH_SIZE = 1000
#
# If imports from this module are needed, use local imports to avoid circular importing
#####
@build_celery_task_wrapper(name_cc_cleanup_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def cleanup_connector_credential_pair_task(
connector_id: int,
@@ -79,7 +77,7 @@ def cleanup_connector_credential_pair_task(
try:
# The bulk of the work is in here, updates Postgres and Vespa
return _delete_connector_credential_pair(
return delete_connector_credential_pair(
db_session=db_session,
document_index=get_default_document_index(),
cc_pair=cc_pair,
@@ -89,6 +87,7 @@ def cleanup_connector_credential_pair_task(
raise e
@build_celery_task_wrapper(name_document_set_sync_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_document_set_task(document_set_id: int) -> None:
"""For document sets marked as not up to date, sync the state from postgres
@@ -125,9 +124,6 @@ def sync_document_set_task(document_set_id: int) -> None:
)
with Session(get_sqlalchemy_engine()) as db_session:
task_name = name_document_set_sync_task(document_set_id)
mark_task_start(task_name, db_session)
try:
document_index = get_default_document_index()
documents_to_update = fetch_documents_for_document_set(
@@ -166,11 +162,8 @@ def sync_document_set_task(document_set_id: int) -> None:
except Exception:
logger.exception("Failed to sync document set %s", document_set_id)
mark_task_finished(task_name, db_session, success=False)
raise
mark_task_finished(task_name, db_session)
#####
# Periodic Tasks
@@ -201,10 +194,9 @@ def check_for_document_sets_sync_task() -> None:
continue
logger.info(f"Document set {document_set.id} syncing now!")
task = sync_document_set_task.apply_async(
sync_document_set_task.apply_async(
kwargs=dict(document_set_id=document_set.id),
)
register_task(task.id, task_name, db_session)
@celery_app.task(name="clean_old_temp_files_task", soft_time_limit=JOB_TIMEOUT)

View File

@@ -1,75 +1,23 @@
import json
from typing import cast
from celery.result import AsyncResult
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.celery import celery_app
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DeletionStatus
from danswer.db.tasks import get_latest_task
from danswer.server.models import DeletionAttemptSnapshot
def get_celery_task(task_id: str) -> AsyncResult:
"""NOTE: even if the task doesn't exist, celery will still return something
with a `PENDING` state"""
return AsyncResult(task_id, backend=celery_app.backend)
def get_celery_task_status(task_id: str) -> str | None:
"""NOTE: is tightly coupled to the internals of kombu (which is the
translation layer to allow us to use Postgres as a broker). If we change
the broker, this will need to be updated.
This should not be called on any critical flows.
"""
# first check for any pending tasks
with Session(get_sqlalchemy_engine()) as session:
rows = session.execute(text("SELECT payload FROM kombu_message WHERE visible"))
for row in rows:
payload = json.loads(row[0])
if payload["headers"]["id"] == task_id:
return "PENDING"
task = get_celery_task(task_id)
# if not pending, then we know the task really exists
if task.status != "PENDING":
return task.status
return None
def get_deletion_status(
connector_id: int, credential_id: int
connector_id: int, credential_id: int, db_session: Session
) -> DeletionAttemptSnapshot | None:
cleanup_task_id = name_cc_cleanup_task(
cleanup_task_name = name_cc_cleanup_task(
connector_id=connector_id, credential_id=credential_id
)
deletion_task = get_celery_task(task_id=cleanup_task_id)
deletion_task_status = get_celery_task_status(task_id=cleanup_task_id)
task_state = get_latest_task(task_name=cleanup_task_name, db_session=db_session)
deletion_status = None
error_msg = None
num_docs_deleted = 0
if deletion_task_status == "SUCCESS":
deletion_status = DeletionStatus.SUCCESS
num_docs_deleted = cast(int, deletion_task.get(propagate=False))
elif deletion_task_status == "FAILURE":
deletion_status = DeletionStatus.FAILED
error_msg = deletion_task.get(propagate=False)
elif deletion_task_status == "STARTED" or deletion_task_status == "PENDING":
deletion_status = DeletionStatus.IN_PROGRESS
if not task_state:
return None
return (
DeletionAttemptSnapshot(
connector_id=connector_id,
credential_id=credential_id,
status=deletion_status,
error_msg=str(error_msg),
num_docs_deleted=num_docs_deleted,
)
if deletion_status
else None
return DeletionAttemptSnapshot(
connector_id=connector_id,
credential_id=credential_id,
status=task_state.status,
)
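For context: the rewritten lookup above leans on `get_latest_task` from `danswer.db.tasks` instead of celery/kombu internals, so the deletion status is read straight from our own task-tracking table in Postgres. A rough, hypothetical sketch of what such a helper looks like (the model and column names here are assumptions for illustration, not taken from this diff):

# Hypothetical sketch only -- the real helper lives in danswer.db.tasks and its
# model / column names may differ.
from sqlalchemy import desc
from sqlalchemy.orm import Session

from danswer.db.models import TaskQueueState  # assumed ORM model for task bookkeeping


def get_latest_task(task_name: str, db_session: Session) -> TaskQueueState | None:
    # newest bookkeeping row for this task name, or None if it was never registered
    return (
        db_session.query(TaskQueueState)
        .filter(TaskQueueState.task_name == task_name)
        .order_by(desc(TaskQueueState.register_time))
        .first()
    )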

View File

@@ -144,7 +144,7 @@ def cleanup_synced_entities(
)
def _delete_connector_credential_pair(
def delete_connector_credential_pair(
db_session: Session,
document_index: DocumentIndex,
cc_pair: ConnectorCredentialPair,

View File

@@ -1,6 +1,117 @@
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import cast
from typing import TypeVar
from celery import Task
from celery.result import AsyncResult
from sqlalchemy.orm import Session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.tasks import mark_task_finished
from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
def name_document_set_sync_task(document_set_id: int) -> str:
return f"sync_doc_set_{document_set_id}"
T = TypeVar("T", bound=Callable)
def build_run_wrapper(build_name_fn: Callable[..., str]) -> Callable[[T], T]:
"""Utility meant to wrap the celery task `run` function in order to
automatically update our custom `task_queue_jobs` table appropriately"""
def wrap_task_fn(task_fn: T) -> T:
@wraps(task_fn)
def wrapped_task_fn(*args: list, **kwargs: dict) -> Any:
engine = get_sqlalchemy_engine()
task_name = build_name_fn(*args, **kwargs)
with Session(engine) as db_session:
# mark the task as started
mark_task_start(task_name=task_name, db_session=db_session)
result = None
exception = None
try:
result = task_fn(*args, **kwargs)
except Exception as e:
exception = e
with Session(engine) as db_session:
mark_task_finished(
task_name=task_name,
db_session=db_session,
success=exception is None,
)
if not exception:
return result
else:
raise exception
return cast(T, wrapped_task_fn)
return wrap_task_fn
# rough type signature for `apply_async`
AA = TypeVar("AA", bound=Callable[..., AsyncResult])
def build_apply_async_wrapper(build_name_fn: Callable[..., str]) -> Callable[[AA], AA]:
"""Utility meant to wrap celery `apply_async` function in order to automatically
create an entry in our `task_queue_jobs` table"""
def wrapper(fn: AA) -> AA:
@wraps(fn)
def wrapped_fn(
args: tuple | None = None,
kwargs: dict[str, Any] | None = None,
*other_args: list,
**other_kwargs: dict[str, Any],
) -> Any:
# `apply_async` takes in args / kwargs directly as arguments
args_for_build_name = args or tuple()
kwargs_for_build_name = kwargs or {}
task_name = build_name_fn(*args_for_build_name, **kwargs_for_build_name)
with Session(get_sqlalchemy_engine()) as db_session:
# dispatch the task and register it in our `task_queue_jobs` table
task = fn(args, kwargs, *other_args, **other_kwargs)
register_task(task.id, task_name, db_session)
return task
return cast(AA, wrapped_fn)
return wrapper
def build_celery_task_wrapper(
build_name_fn: Callable[..., str]
) -> Callable[[Task], Task]:
"""Utility meant to wrap celery task functions in order to automatically
update our custom `task_queue_jobs` table appropriately.
On task creation (e.g. `apply_async`), a row is inserted into the table with
status `PENDING`.
On task start, the latest row is updated to have status `STARTED`.
On task success, the latest row is updated to have status `SUCCESS`.
On the task raising an unhandled exception, the latest row is updated to have
status `FAILURE`.
"""
def wrap_task(task: Task) -> Task:
task.run = build_run_wrapper(build_name_fn)(task.run) # type: ignore
task.apply_async = build_apply_async_wrapper(build_name_fn)(task.apply_async) # type: ignore
return task
return wrap_task
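Put together, a task opts into this bookkeeping by stacking `build_celery_task_wrapper` on top of the usual celery registration, exactly as the hunks in the celery module above now do; a minimal sketch:

@build_celery_task_wrapper(name_document_set_sync_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_document_set_task(document_set_id: int) -> None:
    # the task body only does the actual work; the PENDING / STARTED / SUCCESS /
    # FAILURE rows in task_queue_jobs are written by the wrapper around
    # `apply_async` and `run`
    ...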

View File

@@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_utils import get_deletion_status
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
@@ -366,7 +365,9 @@ def get_connector_indexing_status(
if latest_index_attempt
else None,
deletion_attempt=get_deletion_status(
connector_id=connector.id, credential_id=credential.id
connector_id=connector.id,
credential_id=credential.id,
db_session=db_session,
),
is_deletable=check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair
@@ -590,12 +591,8 @@ def create_deletion_attempt_for_connector_id(
"no ongoing / planned indexing attempts.",
)
task_id = name_cc_cleanup_task(
connector_id=connector_id, credential_id=credential_id
)
cleanup_connector_credential_pair_task.apply_async(
kwargs=dict(connector_id=connector_id, credential_id=credential_id),
task_id=task_id,
)

View File

@@ -22,10 +22,10 @@ from danswer.db.models import AllowedAnswerFilters
from danswer.db.models import ChannelConfig
from danswer.db.models import Connector
from danswer.db.models import Credential
from danswer.db.models import DeletionStatus
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import TaskStatus
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
@@ -326,9 +326,7 @@ class IndexAttemptSnapshot(BaseModel):
class DeletionAttemptSnapshot(BaseModel):
connector_id: int
credential_id: int
status: DeletionStatus
error_msg: str | None
num_docs_deleted: int
status: TaskStatus
class ConnectorBase(BaseModel):

View File

@@ -11,8 +11,18 @@ redirect_stderr=true
autorestart=true
# Background jobs that must be run async due to long time to completion
# NOTE: due to an issue with Celery + SQLAlchemy
# (https://github.com/celery/celery/issues/7007#issuecomment-1740139367)
# we must use the threads pool instead of the default prefork pool for now
# in order to avoid intermittent errors like:
# `billiard.exceptions.WorkerLostError: Worker exited prematurely: signal 11 (SIGSEGV)`.
#
# This means workers will not be able to take advantage of multiple CPU cores
# on a system, but this should be okay for now since all our celery tasks are
# relatively compute-light (e.g. they tend to just make a bunch of requests to
# Vespa / Postgres)
[program:celery_worker]
command=celery -A danswer.background.celery worker --loglevel=INFO --logfile=/var/log/celery_worker.log
command=celery -A danswer.background.celery worker --pool=threads --loglevel=INFO --logfile=/var/log/celery_worker.log
stdout_logfile=/var/log/celery_worker_supervisor.log
stdout_logfile_maxbytes=52428800
redirect_stderr=true
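Since the threads pool keeps the worker in a single process, throughput for these I/O-bound tasks can still be tuned by raising the thread count, e.g. via `--concurrency` (the value below is illustrative and not part of this commit):

command=celery -A danswer.background.celery worker --pool=threads --concurrency=16 --loglevel=INFO --logfile=/var/log/celery_worker.log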

View File

@@ -37,11 +37,7 @@ export function StatusRow<ConnectorConfigType, ConnectorCredentialType>({
}
if (connector.disabled) {
const deletionAttempt = connectorIndexingStatus.deletion_attempt;
if (
!deletionAttempt ||
deletionAttempt.status === "not_started" ||
deletionAttempt.status === "failed"
) {
if (!deletionAttempt || deletionAttempt.status === "FAILURE") {
statusDisplay = <div className="text-red-700">Disabled</div>;
} else {
statusDisplay = <div className="text-red-700">Deleting...</div>;

View File

@@ -17,8 +17,8 @@ const SingleUseConnectorStatus = ({
}) => {
if (
deletionAttempt &&
(deletionAttempt.status === "in_progress" ||
deletionAttempt.status === "not_started")
(deletionAttempt.status === "PENDING" ||
deletionAttempt.status === "STARTED")
) {
return <div className="text-red-500">Deleting...</div>;
}

View File

@@ -33,6 +33,7 @@ export type ValidStatuses =
| "failed"
| "in_progress"
| "not_started";
export type TaskStatus = "PENDING" | "STARTED" | "SUCCESS" | "FAILURE";
export interface DocumentBoostStatus {
document_id: string;
@@ -244,9 +245,7 @@ export interface Document360CredentialJson {
export interface DeletionAttemptSnapshot {
connector_id: number;
credential_id: number;
status: ValidStatuses;
error_msg?: string;
num_docs_deleted: number;
status: TaskStatus;
}
// DOCUMENT SETS