Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-11 05:19:52 +02:00)
Fix deletion status display + add celery util + fix seg faults (#615)
Commit 89807c8c05 (parent 8403b94722)
@@ -4,10 +4,11 @@ from pathlib import Path
 from typing import cast
 
 from celery import Celery  # type: ignore
-from celery.result import AsyncResult
 from sqlalchemy.orm import Session
 
-from danswer.background.connector_deletion import _delete_connector_credential_pair
+from danswer.background.connector_deletion import delete_connector_credential_pair
+from danswer.background.task_utils import build_celery_task_wrapper
+from danswer.background.task_utils import name_cc_cleanup_task
 from danswer.background.task_utils import name_document_set_sync_task
 from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH
 from danswer.configs.app_configs import JOB_TIMEOUT
@@ -30,9 +31,6 @@ from danswer.db.engine import SYNC_DB_API
 from danswer.db.models import DocumentSet
 from danswer.db.tasks import check_live_task_not_timed_out
 from danswer.db.tasks import get_latest_task
-from danswer.db.tasks import mark_task_finished
-from danswer.db.tasks import mark_task_start
-from danswer.db.tasks import register_task
 from danswer.utils.batching import batch_generator
 from danswer.utils.logger import setup_logger
 
@@ -43,7 +41,6 @@ celery_backend_url = "db+" + build_connection_string(db_api=SYNC_DB_API)
 celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
 
 
-_ExistingTaskCache: dict[int, AsyncResult] = {}
 _SYNC_BATCH_SIZE = 1000
 
 
@@ -52,6 +49,7 @@ _SYNC_BATCH_SIZE = 1000
 #
 # If imports from this module are needed, use local imports to avoid circular importing
 #####
+@build_celery_task_wrapper(name_cc_cleanup_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
 def cleanup_connector_credential_pair_task(
     connector_id: int,
@@ -79,7 +77,7 @@ def cleanup_connector_credential_pair_task(
 
     try:
         # The bulk of the work is in here, updates Postgres and Vespa
-        return _delete_connector_credential_pair(
+        return delete_connector_credential_pair(
             db_session=db_session,
             document_index=get_default_document_index(),
             cc_pair=cc_pair,
@@ -89,6 +87,7 @@ def cleanup_connector_credential_pair_task(
         raise e
 
 
+@build_celery_task_wrapper(name_document_set_sync_task)
 @celery_app.task(soft_time_limit=JOB_TIMEOUT)
 def sync_document_set_task(document_set_id: int) -> None:
     """For document sets marked as not up to date, sync the state from postgres
@@ -125,9 +124,6 @@ def sync_document_set_task(document_set_id: int) -> None:
         )
 
     with Session(get_sqlalchemy_engine()) as db_session:
-        task_name = name_document_set_sync_task(document_set_id)
-        mark_task_start(task_name, db_session)
-
        try:
            document_index = get_default_document_index()
            documents_to_update = fetch_documents_for_document_set(
@@ -166,11 +162,8 @@ def sync_document_set_task(document_set_id: int) -> None:
 
        except Exception:
            logger.exception("Failed to sync document set %s", document_set_id)
-            mark_task_finished(task_name, db_session, success=False)
            raise
 
-        mark_task_finished(task_name, db_session)
-
 
 #####
 # Periodic Tasks
@@ -201,10 +194,9 @@ def check_for_document_sets_sync_task() -> None:
                continue
 
            logger.info(f"Document set {document_set.id} syncing now!")
-            task = sync_document_set_task.apply_async(
+            sync_document_set_task.apply_async(
                kwargs=dict(document_set_id=document_set.id),
            )
-            register_task(task.id, task_name, db_session)
 
 
 @celery_app.task(name="clean_old_temp_files_task", soft_time_limit=JOB_TIMEOUT)

@@ -1,75 +1,23 @@
-import json
-from typing import cast
-
-from celery.result import AsyncResult
-from sqlalchemy import text
 from sqlalchemy.orm import Session
 
-from danswer.background.celery.celery import celery_app
 from danswer.background.task_utils import name_cc_cleanup_task
-from danswer.db.engine import get_sqlalchemy_engine
-from danswer.db.models import DeletionStatus
+from danswer.db.tasks import get_latest_task
 from danswer.server.models import DeletionAttemptSnapshot
 
 
-def get_celery_task(task_id: str) -> AsyncResult:
-    """NOTE: even if the task doesn't exist, celery will still return something
-    with a `PENDING` state"""
-    return AsyncResult(task_id, backend=celery_app.backend)
-
-
-def get_celery_task_status(task_id: str) -> str | None:
-    """NOTE: is tightly coupled to the internals of kombu (which is the
-    translation layer to allow us to use Postgres as a broker). If we change
-    the broker, this will need to be updated.
-
-    This should not be called on any critical flows.
-    """
-    # first check for any pending tasks
-    with Session(get_sqlalchemy_engine()) as session:
-        rows = session.execute(text("SELECT payload FROM kombu_message WHERE visible"))
-        for row in rows:
-            payload = json.loads(row[0])
-            if payload["headers"]["id"] == task_id:
-                return "PENDING"
-
-    task = get_celery_task(task_id)
-    # if not pending, then we know the task really exists
-    if task.status != "PENDING":
-        return task.status
-
-    return None
-
-
 def get_deletion_status(
-    connector_id: int, credential_id: int
+    connector_id: int, credential_id: int, db_session: Session
 ) -> DeletionAttemptSnapshot | None:
-    cleanup_task_id = name_cc_cleanup_task(
+    cleanup_task_name = name_cc_cleanup_task(
        connector_id=connector_id, credential_id=credential_id
    )
-    deletion_task = get_celery_task(task_id=cleanup_task_id)
-    deletion_task_status = get_celery_task_status(task_id=cleanup_task_id)
+    task_state = get_latest_task(task_name=cleanup_task_name, db_session=db_session)
 
-    deletion_status = None
-    error_msg = None
-    num_docs_deleted = 0
-    if deletion_task_status == "SUCCESS":
-        deletion_status = DeletionStatus.SUCCESS
-        num_docs_deleted = cast(int, deletion_task.get(propagate=False))
-    elif deletion_task_status == "FAILURE":
-        deletion_status = DeletionStatus.FAILED
-        error_msg = deletion_task.get(propagate=False)
-    elif deletion_task_status == "STARTED" or deletion_task_status == "PENDING":
-        deletion_status = DeletionStatus.IN_PROGRESS
+    if not task_state:
+        return None
 
-    return (
-        DeletionAttemptSnapshot(
-            connector_id=connector_id,
-            credential_id=credential_id,
-            status=deletion_status,
-            error_msg=str(error_msg),
-            num_docs_deleted=num_docs_deleted,
-        )
-        if deletion_status
-        else None
-    )
+    return DeletionAttemptSnapshot(
+        connector_id=connector_id,
+        credential_id=credential_id,
+        status=task_state.status,
+    )

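Since deletion status now comes straight from the `task_queue_jobs` table rather than from Celery/kombu internals, callers only need a database session. A minimal usage sketch; the ids and the session scoping are illustrative, not from the commit:

from sqlalchemy.orm import Session

from danswer.background.celery.celery_utils import get_deletion_status
from danswer.db.engine import get_sqlalchemy_engine

with Session(get_sqlalchemy_engine()) as db_session:
    # Returns None when no cleanup task was ever registered for this pair
    snapshot = get_deletion_status(
        connector_id=1, credential_id=1, db_session=db_session
    )
    if snapshot:
        # status is a TaskStatus: PENDING / STARTED / SUCCESS / FAILURE
        print(snapshot.status)
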
@@ -144,7 +144,7 @@ def cleanup_synced_entities(
    )
 
 
-def _delete_connector_credential_pair(
+def delete_connector_credential_pair(
    db_session: Session,
    document_index: DocumentIndex,
    cc_pair: ConnectorCredentialPair,

@@ -1,6 +1,117 @@
+from collections.abc import Callable
+from functools import wraps
+from typing import Any
+from typing import cast
+from typing import TypeVar
+
+from celery import Task
+from celery.result import AsyncResult
+from sqlalchemy.orm import Session
+
+from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.tasks import mark_task_finished
+from danswer.db.tasks import mark_task_start
+from danswer.db.tasks import register_task
+
+
 def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
     return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
 
 
 def name_document_set_sync_task(document_set_id: int) -> str:
     return f"sync_doc_set_{document_set_id}"
+
+
+T = TypeVar("T", bound=Callable)
+
+
+def build_run_wrapper(build_name_fn: Callable[..., str]) -> Callable[[T], T]:
+    """Utility meant to wrap the celery task `run` function in order to
+    automatically update our custom `task_queue_jobs` table appropriately"""
+
+    def wrap_task_fn(task_fn: T) -> T:
+        @wraps(task_fn)
+        def wrapped_task_fn(*args: list, **kwargs: dict) -> Any:
+            engine = get_sqlalchemy_engine()
+
+            task_name = build_name_fn(*args, **kwargs)
+            with Session(engine) as db_session:
+                # mark the task as started
+                mark_task_start(task_name=task_name, db_session=db_session)
+
+            result = None
+            exception = None
+            try:
+                result = task_fn(*args, **kwargs)
+            except Exception as e:
+                exception = e
+
+            with Session(engine) as db_session:
+                mark_task_finished(
+                    task_name=task_name,
+                    db_session=db_session,
+                    success=exception is None,
+                )
+
+            if not exception:
+                return result
+            else:
+                raise exception
+
+        return cast(T, wrapped_task_fn)
+
+    return wrap_task_fn
+
+
+# rough type signature for `apply_async`
+AA = TypeVar("AA", bound=Callable[..., AsyncResult])
+
+
+def build_apply_async_wrapper(build_name_fn: Callable[..., str]) -> Callable[[AA], AA]:
+    """Utility meant to wrap celery `apply_async` function in order to automatically
+    create an entry in our `task_queue_jobs` table"""
+
+    def wrapper(fn: AA) -> AA:
+        @wraps(fn)
+        def wrapped_fn(
+            args: tuple | None = None,
+            kwargs: dict[str, Any] | None = None,
+            *other_args: list,
+            **other_kwargs: dict[str, Any],
+        ) -> Any:
+            # `apply_async` takes in args / kwargs directly as arguments
+            args_for_build_name = args or tuple()
+            kwargs_for_build_name = kwargs or {}
+            task_name = build_name_fn(*args_for_build_name, **kwargs_for_build_name)
+            with Session(get_sqlalchemy_engine()) as db_session:
+                # kick off the task, then register it in the `task_queue_jobs` table
+                task = fn(args, kwargs, *other_args, **other_kwargs)
+                register_task(task.id, task_name, db_session)
+
+            return task
+
+        return cast(AA, wrapped_fn)
+
+    return wrapper
+
+
+def build_celery_task_wrapper(
+    build_name_fn: Callable[..., str]
+) -> Callable[[Task], Task]:
+    """Utility meant to wrap celery task functions in order to automatically
+    update our custom `task_queue_jobs` table appropriately.
+
+    On task creation (e.g. `apply_async`), a row is inserted into the table with
+    status `PENDING`.
+    On task start, the latest row is updated to have status `STARTED`.
+    On task success, the latest row is updated to have status `SUCCESS`.
+    On the task raising an unhandled exception, the latest row is updated to have
+    status `FAILURE`.
+    """
+
+    def wrap_task(task: Task) -> Task:
+        task.run = build_run_wrapper(build_name_fn)(task.run)  # type: ignore
+        task.apply_async = build_apply_async_wrapper(build_name_fn)(task.apply_async)  # type: ignore
+        return task
+
+    return wrap_task

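To see how the new wrapper is meant to be used, here is a sketch of a wrapped task; the task body and time limit are hypothetical, but the decorator stacking mirrors the celery.py changes above:

from danswer.background.celery.celery import celery_app
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_document_set_sync_task


# `celery_app.task` (the inner decorator) builds the Task; the wrapper then
# patches its `run` and `apply_async` to keep `task_queue_jobs` up to date.
@build_celery_task_wrapper(name_document_set_sync_task)
@celery_app.task(soft_time_limit=60)
def example_sync_task(document_set_id: int) -> None:
    ...  # hypothetical task body


# Inserts a PENDING row via `register_task`; the run wrapper later flips it
# to STARTED, then SUCCESS or FAILURE, with no manual bookkeeping in the task.
example_sync_task.apply_async(kwargs=dict(document_set_id=1))
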
@@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
 from danswer.auth.users import current_admin_user
 from danswer.auth.users import current_user
 from danswer.background.celery.celery_utils import get_deletion_status
-from danswer.background.task_utils import name_cc_cleanup_task
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
 from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
@@ -366,7 +365,9 @@ def get_connector_indexing_status(
            if latest_index_attempt
            else None,
            deletion_attempt=get_deletion_status(
-                connector_id=connector.id, credential_id=credential.id
+                connector_id=connector.id,
+                credential_id=credential.id,
+                db_session=db_session,
            ),
            is_deletable=check_deletion_attempt_is_allowed(
                connector_credential_pair=cc_pair
@@ -590,12 +591,8 @@ def create_deletion_attempt_for_connector_id(
            "no ongoing / planned indexing attempts.",
        )
 
-    task_id = name_cc_cleanup_task(
-        connector_id=connector_id, credential_id=credential_id
-    )
    cleanup_connector_credential_pair_task.apply_async(
        kwargs=dict(connector_id=connector_id, credential_id=credential_id),
-        task_id=task_id,
    )
 
 
@@ -22,10 +22,10 @@ from danswer.db.models import AllowedAnswerFilters
 from danswer.db.models import ChannelConfig
 from danswer.db.models import Connector
 from danswer.db.models import Credential
-from danswer.db.models import DeletionStatus
 from danswer.db.models import DocumentSet as DocumentSetDBModel
 from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
+from danswer.db.models import TaskStatus
 from danswer.direct_qa.interfaces import DanswerQuote
 from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
@@ -326,9 +326,7 @@ class IndexAttemptSnapshot(BaseModel):
 class DeletionAttemptSnapshot(BaseModel):
     connector_id: int
     credential_id: int
-    status: DeletionStatus
-    error_msg: str | None
-    num_docs_deleted: int
+    status: TaskStatus
 
 
 class ConnectorBase(BaseModel):

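The snapshot is now just the latest `task_queue_jobs` state for the pair. A quick sketch of constructing one, with made-up values and assuming `TaskStatus` is the enum behind the statuses listed in the docstring above:

from danswer.db.models import TaskStatus
from danswer.server.models import DeletionAttemptSnapshot

# The three remaining fields are exactly what the UI needs to decide
# between "Disabled" and "Deleting..." (see the web changes below).
snapshot = DeletionAttemptSnapshot(
    connector_id=1,
    credential_id=1,
    status=TaskStatus.STARTED,
)
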
@@ -11,8 +11,18 @@ redirect_stderr=true
 autorestart=true
 
 # Background jobs that must be run async due to long time to completion
+# NOTE: due to an issue with Celery + SQLAlchemy
+# (https://github.com/celery/celery/issues/7007#issuecomment-1740139367)
+# we must use the threads pool instead of the default prefork pool for now
+# in order to avoid intermittent errors like:
+# `billiard.exceptions.WorkerLostError: Worker exited prematurely: signal 11 (SIGSEGV)`.
+#
+# This means workers will not be able to take advantage of multiple CPU cores
+# on a system, but this should be okay for now since all our celery tasks are
+# relatively compute-light (e.g. they tend to just make a bunch of requests to
+# Vespa / Postgres)
 [program:celery_worker]
-command=celery -A danswer.background.celery worker --loglevel=INFO --logfile=/var/log/celery_worker.log
+command=celery -A danswer.background.celery worker --pool=threads --loglevel=INFO --logfile=/var/log/celery_worker.log
 stdout_logfile=/var/log/celery_worker_supervisor.log
 stdout_logfile_maxbytes=52428800
 redirect_stderr=true

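If the pool choice ever needs to live in code rather than on the command line, Celery exposes the same knob as a configuration option. A minimal sketch, assuming the `celery_app` object from danswer.background.celery.celery:

from danswer.background.celery.celery import celery_app

# Equivalent to the --pool=threads CLI flag above: sidesteps the prefork
# SIGSEGV issue from the linked Celery + SQLAlchemy bug report.
celery_app.conf.worker_pool = "threads"
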
@@ -37,11 +37,7 @@ export function StatusRow<ConnectorConfigType, ConnectorCredentialType>({
   }
   if (connector.disabled) {
     const deletionAttempt = connectorIndexingStatus.deletion_attempt;
-    if (
-      !deletionAttempt ||
-      deletionAttempt.status === "not_started" ||
-      deletionAttempt.status === "failed"
-    ) {
+    if (!deletionAttempt || deletionAttempt.status === "FAILURE") {
       statusDisplay = <div className="text-red-700">Disabled</div>;
     } else {
       statusDisplay = <div className="text-red-700">Deleting...</div>;

@@ -17,8 +17,8 @@ const SingleUseConnectorStatus = ({
 }) => {
   if (
     deletionAttempt &&
-    (deletionAttempt.status === "in_progress" ||
-      deletionAttempt.status === "not_started")
+    (deletionAttempt.status === "PENDING" ||
+      deletionAttempt.status === "STARTED")
   ) {
     return <div className="text-red-500">Deleting...</div>;
   }

@@ -33,6 +33,7 @@ export type ValidStatuses =
   | "failed"
   | "in_progress"
   | "not_started";
+export type TaskStatus = "PENDING" | "STARTED" | "SUCCESS" | "FAILURE";
 
 export interface DocumentBoostStatus {
   document_id: string;
@@ -244,9 +245,7 @@ export interface Document360CredentialJson {
 export interface DeletionAttemptSnapshot {
   connector_id: number;
   credential_id: number;
-  status: ValidStatuses;
-  error_msg?: string;
-  num_docs_deleted: number;
+  status: TaskStatus;
 }
 
 // DOCUMENT SETS