mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-26 17:51:54 +01:00
Celery Beat (#575)
This commit is contained in:
parent
a7ddb22e50
commit
b5982c10c3
@ -7,6 +7,7 @@ from danswer.db.models import Base
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
@ -21,7 +22,7 @@ if config.config_file_name is not None:
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = Base.metadata
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
@ -44,7 +45,7 @@ def run_migrations_offline() -> None:
|
||||
url = build_connection_string()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
@ -54,7 +55,7 @@ def run_migrations_offline() -> None:
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
context.configure(connection=connection, target_metadata=target_metadata) # type: ignore
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
48
backend/alembic/versions/78dbe7e38469_task_tracking.py
Normal file
48
backend/alembic/versions/78dbe7e38469_task_tracking.py
Normal file
@ -0,0 +1,48 @@
|
||||
"""Task Tracking
|
||||
|
||||
Revision ID: 78dbe7e38469
|
||||
Revises: 7ccea01261f6
|
||||
Create Date: 2023-10-15 23:40:50.593262
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "78dbe7e38469"
|
||||
down_revision = "7ccea01261f6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"task_queue_jobs",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("task_id", sa.String(), nullable=False),
|
||||
sa.Column("task_name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.Enum(
|
||||
"PENDING",
|
||||
"STARTED",
|
||||
"SUCCESS",
|
||||
"FAILURE",
|
||||
name="taskstatus",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("start_time", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"register_time",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("task_queue_jobs")
|
@ -1,9 +1,39 @@
|
||||
from celery import Celery
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
from danswer.background.connector_deletion import cleanup_connector_credential_pair
|
||||
from celery import Celery # type: ignore
|
||||
from celery.result import AsyncResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.connector_deletion import _delete_connector_credential_pair
|
||||
from danswer.background.task_utils import name_document_set_sync_task
|
||||
from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.connectors.file.utils import file_age_in_hours
|
||||
from danswer.datastores.document_index import get_default_document_index
|
||||
from danswer.datastores.interfaces import DocumentIndex
|
||||
from danswer.datastores.interfaces import UpdateRequest
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document_set import delete_document_set
|
||||
from danswer.db.document_set import fetch_document_sets
|
||||
from danswer.db.document_set import fetch_document_sets_for_documents
|
||||
from danswer.db.document_set import fetch_documents_for_document_set
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import SYNC_DB_API
|
||||
from danswer.document_set.document_set import sync_document_set
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.tasks import check_live_task_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.db.tasks import mark_task_finished
|
||||
from danswer.db.tasks import mark_task_start
|
||||
from danswer.db.tasks import register_task
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@ -13,17 +43,193 @@ celery_backend_url = "db+" + build_connection_string(db_api=SYNC_DB_API)
|
||||
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
|
||||
|
||||
|
||||
@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit
|
||||
_ExistingTaskCache: dict[int, AsyncResult] = {}
|
||||
_SYNC_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
#####
|
||||
# Tasks that need to be run in job queue, registered via APIs
|
||||
#
|
||||
# If imports from this module are needed, use local imports to avoid circular importing
|
||||
#####
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def cleanup_connector_credential_pair_task(
|
||||
connector_id: int, credential_id: int
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> int:
|
||||
return cleanup_connector_credential_pair(connector_id, credential_id)
|
||||
"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
|
||||
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
|
||||
or updating the ACL"""
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
# validate that the connector / credential pair is deletable
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
if not cc_pair or not check_deletion_attempt_is_allowed(
|
||||
connector_credential_pair=cc_pair
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot run deletion attempt - connector_credential_pair is not deletable. "
|
||||
"This is likely because there is an ongoing / planned indexing attempt OR the "
|
||||
"connector is not disabled."
|
||||
)
|
||||
|
||||
try:
|
||||
# The bulk of the work is in here, updates Postgres and Vespa
|
||||
return _delete_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
document_index=get_default_document_index(),
|
||||
cc_pair=cc_pair,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run connector_deletion due to {e}")
|
||||
raise e
|
||||
|
||||
|
||||
@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_document_set_task(document_set_id: int) -> None:
|
||||
try:
|
||||
return sync_document_set(document_set_id=document_set_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to sync document set %s", document_set_id)
|
||||
raise
|
||||
"""For document sets marked as not up to date, sync the state from postgres
|
||||
into the datastore. Also handles deletions."""
|
||||
|
||||
def _sync_document_batch(
|
||||
document_ids: list[str], document_index: DocumentIndex
|
||||
) -> None:
|
||||
logger.debug(f"Syncing document sets for: {document_ids}")
|
||||
# begin a transaction, release lock at the end
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# acquires a lock on the documents so that no other process can modify them
|
||||
prepare_to_modify_documents(
|
||||
db_session=db_session, document_ids=document_ids
|
||||
)
|
||||
|
||||
# get current state of document sets for these documents
|
||||
document_set_map = {
|
||||
document_id: document_sets
|
||||
for document_id, document_sets in fetch_document_sets_for_documents(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
)
|
||||
}
|
||||
|
||||
# update Vespa
|
||||
document_index.update(
|
||||
update_requests=[
|
||||
UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
document_sets=set(document_set_map.get(document_id, [])),
|
||||
)
|
||||
for document_id in document_ids
|
||||
]
|
||||
)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
task_name = name_document_set_sync_task(document_set_id)
|
||||
mark_task_start(task_name, db_session)
|
||||
|
||||
try:
|
||||
document_index = get_default_document_index()
|
||||
documents_to_update = fetch_documents_for_document_set(
|
||||
document_set_id=document_set_id,
|
||||
db_session=db_session,
|
||||
current_only=False,
|
||||
)
|
||||
for document_batch in batch_generator(
|
||||
documents_to_update, _SYNC_BATCH_SIZE
|
||||
):
|
||||
_sync_document_batch(
|
||||
document_ids=[document.id for document in document_batch],
|
||||
document_index=document_index,
|
||||
)
|
||||
|
||||
# if there are no connectors, then delete the document set. Otherwise, just
|
||||
# mark it as successfully synced.
|
||||
document_set = cast(
|
||||
DocumentSet,
|
||||
get_document_set_by_id(
|
||||
db_session=db_session, document_set_id=document_set_id
|
||||
),
|
||||
) # casting since we "know" a document set with this ID exists
|
||||
if not document_set.connector_credential_pairs:
|
||||
delete_document_set(
|
||||
document_set_row=document_set, db_session=db_session
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully deleted document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
else:
|
||||
mark_document_set_as_synced(
|
||||
document_set_id=document_set_id, db_session=db_session
|
||||
)
|
||||
logger.info(f"Document set sync for '{document_set_id}' complete!")
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to sync document set %s", document_set_id)
|
||||
mark_task_finished(task_name, db_session, success=False)
|
||||
raise
|
||||
|
||||
mark_task_finished(task_name, db_session)
|
||||
|
||||
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
@celery_app.task(
|
||||
name="check_for_document_sets_sync_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_document_sets_sync_task() -> None:
|
||||
"""Runs periodically to check if any document sets are out of sync
|
||||
Creates a task to sync the set if needed"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(
|
||||
db_session=db_session, include_outdated=True
|
||||
)
|
||||
for document_set, _ in document_set_info:
|
||||
if not document_set.is_up_to_date:
|
||||
task_name = name_document_set_sync_task(document_set.id)
|
||||
latest_sync = get_latest_task(task_name, db_session)
|
||||
|
||||
if latest_sync and check_live_task_not_timed_out(
|
||||
latest_sync, db_session
|
||||
):
|
||||
logger.info(
|
||||
f"Document set '{document_set.id}' is already syncing. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"Document set {document_set.id} syncing now!")
|
||||
task = sync_document_set_task.apply_async(
|
||||
kwargs=dict(document_set_id=document_set.id),
|
||||
)
|
||||
register_task(task.id, task_name, db_session)
|
||||
|
||||
|
||||
@celery_app.task(name="clean_old_temp_files_task", soft_time_limit=JOB_TIMEOUT)
|
||||
def clean_old_temp_files_task(
|
||||
age_threshold_in_hours: float | int = 24 * 7, # 1 week,
|
||||
base_path: Path | str = FILE_CONNECTOR_TMP_STORAGE_PATH,
|
||||
) -> None:
|
||||
"""Files added via the File connector need to be deleted after ingestion
|
||||
Currently handled async of the indexing job"""
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
for file in os.listdir(base_path):
|
||||
if file_age_in_hours(file) > age_threshold_in_hours:
|
||||
os.remove(Path(base_path) / file)
|
||||
|
||||
|
||||
#####
|
||||
# Celery Beat (Periodic Tasks) Settings
|
||||
#####
|
||||
celery_app.conf.beat_schedule = {
|
||||
"check-for-document-set-sync": {
|
||||
"task": "check_for_document_sets_sync_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
},
|
||||
"clean-old-temp-files": {
|
||||
"task": "clean_old_temp_files_task",
|
||||
"schedule": timedelta(minutes=30),
|
||||
},
|
||||
}
|
||||
|
@ -1,11 +1,15 @@
|
||||
import json
|
||||
from typing import cast
|
||||
|
||||
from celery.result import AsyncResult
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery import celery_app
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import DeletionStatus
|
||||
from danswer.server.models import DeletionAttemptSnapshot
|
||||
|
||||
|
||||
def get_celery_task(task_id: str) -> AsyncResult:
|
||||
@ -35,3 +39,37 @@ def get_celery_task_status(task_id: str) -> str | None:
|
||||
return task.status
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_deletion_status(
|
||||
connector_id: int, credential_id: int
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
cleanup_task_id = name_cc_cleanup_task(
|
||||
connector_id=connector_id, credential_id=credential_id
|
||||
)
|
||||
deletion_task = get_celery_task(task_id=cleanup_task_id)
|
||||
deletion_task_status = get_celery_task_status(task_id=cleanup_task_id)
|
||||
|
||||
deletion_status = None
|
||||
error_msg = None
|
||||
num_docs_deleted = 0
|
||||
if deletion_task_status == "SUCCESS":
|
||||
deletion_status = DeletionStatus.SUCCESS
|
||||
num_docs_deleted = cast(int, deletion_task.get(propagate=False))
|
||||
elif deletion_task_status == "FAILURE":
|
||||
deletion_status = DeletionStatus.FAILED
|
||||
error_msg = deletion_task.get(propagate=False)
|
||||
elif deletion_task_status == "STARTED" or deletion_task_status == "PENDING":
|
||||
deletion_status = DeletionStatus.IN_PROGRESS
|
||||
|
||||
return (
|
||||
DeletionAttemptSnapshot(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
status=deletion_status,
|
||||
error_msg=str(error_msg),
|
||||
num_docs_deleted=num_docs_deleted,
|
||||
)
|
||||
if deletion_status
|
||||
else None
|
||||
)
|
||||
|
@ -1,41 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from danswer.background.celery.celery_utils import get_celery_task
|
||||
from danswer.background.celery.celery_utils import get_celery_task_status
|
||||
from danswer.background.connector_deletion import get_cleanup_task_id
|
||||
from danswer.db.models import DeletionStatus
|
||||
from danswer.server.models import DeletionAttemptSnapshot
|
||||
|
||||
|
||||
def get_deletion_status(
|
||||
connector_id: int, credential_id: int
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
cleanup_task_id = get_cleanup_task_id(
|
||||
connector_id=connector_id, credential_id=credential_id
|
||||
)
|
||||
deletion_task = get_celery_task(task_id=cleanup_task_id)
|
||||
deletion_task_status = get_celery_task_status(task_id=cleanup_task_id)
|
||||
|
||||
deletion_status = None
|
||||
error_msg = None
|
||||
num_docs_deleted = 0
|
||||
if deletion_task_status == "SUCCESS":
|
||||
deletion_status = DeletionStatus.SUCCESS
|
||||
num_docs_deleted = cast(int, deletion_task.get(propagate=False))
|
||||
elif deletion_task_status == "FAILURE":
|
||||
deletion_status = DeletionStatus.FAILED
|
||||
error_msg = deletion_task.get(propagate=False)
|
||||
elif deletion_task_status == "STARTED" or deletion_task_status == "PENDING":
|
||||
deletion_status = DeletionStatus.IN_PROGRESS
|
||||
|
||||
return (
|
||||
DeletionAttemptSnapshot(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
status=deletion_status,
|
||||
error_msg=str(error_msg),
|
||||
num_docs_deleted=num_docs_deleted,
|
||||
)
|
||||
if deletion_status
|
||||
else None
|
||||
)
|
@ -17,15 +17,12 @@ from typing import cast
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.datastores.document_index import get_default_document_index
|
||||
from danswer.datastores.interfaces import DocumentIndex
|
||||
from danswer.datastores.interfaces import UpdateRequest
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair
|
||||
from danswer.db.document import delete_documents_complete
|
||||
from danswer.db.document import get_document_connector_cnts
|
||||
@ -211,39 +208,3 @@ def _delete_connector_credential_pair(
|
||||
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
|
||||
)
|
||||
return num_docs_deleted
|
||||
|
||||
|
||||
def cleanup_connector_credential_pair(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> int:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
# validate that the connector / credential pair is deletable
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
if not cc_pair or not check_deletion_attempt_is_allowed(
|
||||
connector_credential_pair=cc_pair
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot run deletion attempt - connector_credential_pair is not deletable. "
|
||||
"This is likely because there is an ongoing / planned indexing attempt OR the "
|
||||
"connector is not disabled."
|
||||
)
|
||||
|
||||
try:
|
||||
return _delete_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
document_index=get_default_document_index(),
|
||||
cc_pair=cc_pair,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run connector_deletion due to {e}")
|
||||
raise e
|
||||
|
||||
|
||||
def get_cleanup_task_id(connector_id: int, credential_id: int) -> str:
|
||||
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
|
||||
|
@ -1,55 +0,0 @@
|
||||
from celery.result import AsyncResult
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery import sync_document_set_task
|
||||
from danswer.background.utils import interval_run_job
|
||||
from danswer.db.document_set import (
|
||||
fetch_document_sets,
|
||||
)
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_ExistingTaskCache: dict[int, AsyncResult] = {}
|
||||
|
||||
|
||||
def _document_sync_loop() -> None:
|
||||
# cleanup tasks
|
||||
existing_tasks = list(_ExistingTaskCache.items())
|
||||
for document_set_id, task in existing_tasks:
|
||||
if task.ready():
|
||||
logger.info(
|
||||
f"Document set '{document_set_id}' is complete with status "
|
||||
f"{task.status}. Cleaning up."
|
||||
)
|
||||
del _ExistingTaskCache[document_set_id]
|
||||
|
||||
# kick off new tasks
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(
|
||||
db_session=db_session, include_outdated=True
|
||||
)
|
||||
for document_set, _ in document_set_info:
|
||||
if not document_set.is_up_to_date:
|
||||
if document_set.id in _ExistingTaskCache:
|
||||
logger.info(
|
||||
f"Document set '{document_set.id}' is already syncing. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"Document set {document_set.id} is not synced. Syncing now!"
|
||||
)
|
||||
task = sync_document_set_task.apply_async(
|
||||
kwargs=dict(document_set_id=document_set.id),
|
||||
)
|
||||
_ExistingTaskCache[document_set.id] = task
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
interval_run_job(
|
||||
job=_document_sync_loop, delay=5, emit_job_start_log=False
|
||||
) # run every 5 seconds
|
@ -1,6 +0,0 @@
|
||||
from danswer.background.utils import interval_run_job
|
||||
from danswer.connectors.file.utils import clean_old_temp_files
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
interval_run_job(clean_old_temp_files, 30 * 60) # run every 30 minutes
|
6
backend/danswer/background/task_utils.py
Normal file
6
backend/danswer/background/task_utils.py
Normal file
@ -0,0 +1,6 @@
|
||||
def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
|
||||
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
|
||||
|
||||
|
||||
def name_document_set_sync_task(document_set_id: int) -> str:
|
||||
return f"sync_doc_set_{document_set_id}"
|
@ -1,24 +0,0 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def interval_run_job(
|
||||
job: Callable[[], Any], delay: int | float, emit_job_start_log: bool = True
|
||||
) -> None:
|
||||
while True:
|
||||
start = time.time()
|
||||
if emit_job_start_log:
|
||||
logger.info(f"Running '{job.__name__}', current time: {time.ctime(start)}")
|
||||
try:
|
||||
job()
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run update due to {e}")
|
||||
sleep_time = delay - (time.time() - start)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
@ -211,7 +211,7 @@ CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get(
|
||||
# fairly large amount of memory in order to increase substantially, since
|
||||
# each worker loads the embedding models into memory.
|
||||
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
|
||||
|
||||
JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
|
||||
# Logs every model prompt and output, mostly used for development or exploration purposes
|
||||
LOG_ALL_MODEL_INTERACTIONS = (
|
||||
os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true"
|
||||
|
@ -8,7 +8,6 @@ from typing import IO
|
||||
|
||||
from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH
|
||||
|
||||
_FILE_AGE_CLEANUP_THRESHOLD_HOURS = 24 * 7 # 1 week
|
||||
_VALID_FILE_EXTENSIONS = [".txt", ".zip", ".pdf"]
|
||||
|
||||
|
||||
@ -53,13 +52,3 @@ def write_temp_files(
|
||||
|
||||
def file_age_in_hours(filepath: str | Path) -> float:
|
||||
return (time.time() - os.path.getmtime(filepath)) / (60 * 60)
|
||||
|
||||
|
||||
def clean_old_temp_files(
|
||||
age_threshold_in_hours: float | int = _FILE_AGE_CLEANUP_THRESHOLD_HOURS,
|
||||
base_path: Path | str = FILE_CONNECTOR_TMP_STORAGE_PATH,
|
||||
) -> None:
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
for file in os.listdir(base_path):
|
||||
if file_age_in_hours(file) > age_threshold_in_hours:
|
||||
os.remove(Path(base_path) / file)
|
||||
|
@ -52,6 +52,14 @@ class DeletionStatus(str, PyEnum):
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
# Consistent with Celery task statuses
|
||||
class TaskStatus(str, PyEnum):
|
||||
PENDING = "PENDING"
|
||||
STARTED = "STARTED"
|
||||
SUCCESS = "SUCCESS"
|
||||
FAILURE = "FAILURE"
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
@ -566,3 +574,22 @@ class SlackBotConfig(Base):
|
||||
)
|
||||
|
||||
persona: Mapped[Persona | None] = relationship("Persona")
|
||||
|
||||
|
||||
class TaskQueueState(Base):
|
||||
# Currently refers to Celery Tasks
|
||||
__tablename__ = "task_queue_jobs"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
# Celery task id
|
||||
task_id: Mapped[str] = mapped_column(String)
|
||||
# For any job type, this would be the same
|
||||
task_name: Mapped[str] = mapped_column(String)
|
||||
# Note that if the task dies, this won't necessarily be marked FAILED correctly
|
||||
status: Mapped[TaskStatus] = mapped_column(Enum(TaskStatus))
|
||||
start_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True)
|
||||
)
|
||||
register_time: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
85
backend/danswer/db/tasks.py
Normal file
85
backend/danswer/db/tasks.py
Normal file
@ -0,0 +1,85 @@
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.models import TaskQueueState
|
||||
from danswer.db.models import TaskStatus
|
||||
|
||||
|
||||
def get_latest_task(
|
||||
task_name: str,
|
||||
db_session: Session,
|
||||
) -> TaskQueueState | None:
|
||||
stmt = (
|
||||
select(TaskQueueState)
|
||||
.where(TaskQueueState.task_name == task_name)
|
||||
.order_by(desc(TaskQueueState.id))
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
latest_task = result.scalars().first()
|
||||
|
||||
return latest_task
|
||||
|
||||
|
||||
def register_task(
|
||||
task_id: str,
|
||||
task_name: str,
|
||||
db_session: Session,
|
||||
) -> TaskQueueState:
|
||||
new_task = TaskQueueState(
|
||||
task_id=task_id, task_name=task_name, status=TaskStatus.PENDING
|
||||
)
|
||||
|
||||
db_session.add(new_task)
|
||||
db_session.commit()
|
||||
|
||||
return new_task
|
||||
|
||||
|
||||
def mark_task_start(
|
||||
task_name: str,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
task = get_latest_task(task_name, db_session)
|
||||
if not task:
|
||||
raise ValueError(f"No task found with name {task_name}")
|
||||
|
||||
task.start_time = func.now() # type: ignore
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_task_finished(
|
||||
task_name: str,
|
||||
db_session: Session,
|
||||
success: bool = True,
|
||||
) -> None:
|
||||
latest_task = get_latest_task(task_name, db_session)
|
||||
if latest_task is None:
|
||||
raise ValueError(f"tasks for {task_name} do not exist")
|
||||
|
||||
latest_task.status = TaskStatus.SUCCESS if success else TaskStatus.FAILURE
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def check_live_task_not_timed_out(
|
||||
task: TaskQueueState,
|
||||
db_session: Session,
|
||||
timeout: int = JOB_TIMEOUT,
|
||||
) -> bool:
|
||||
# We only care for live tasks to not create new periodic tasks
|
||||
if task.status in [TaskStatus.SUCCESS, TaskStatus.FAILURE]:
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session=db_session)
|
||||
|
||||
last_update_time = task.register_time
|
||||
if task.start_time:
|
||||
last_update_time = max(task.register_time, task.start_time)
|
||||
|
||||
time_elapsed = current_db_time - last_update_time
|
||||
return time_elapsed.total_seconds() < timeout
|
@ -1,84 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.datastores.document_index import get_default_document_index
|
||||
from danswer.datastores.interfaces import DocumentIndex
|
||||
from danswer.datastores.interfaces import UpdateRequest
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document_set import delete_document_set
|
||||
from danswer.db.document_set import fetch_document_sets_for_documents
|
||||
from danswer.db.document_set import fetch_documents_for_document_set
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_SYNC_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
def _sync_document_batch(
|
||||
document_ids: list[str], document_index: DocumentIndex
|
||||
) -> None:
|
||||
logger.debug(f"Syncing document sets for: {document_ids}")
|
||||
# begin a transaction, release lock at the end
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# acquires a lock on the documents so that no other process can modify them
|
||||
prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
|
||||
|
||||
# get current state of document sets for these documents
|
||||
document_set_map = {
|
||||
document_id: document_sets
|
||||
for document_id, document_sets in fetch_document_sets_for_documents(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
)
|
||||
}
|
||||
|
||||
# update Vespa
|
||||
document_index.update(
|
||||
update_requests=[
|
||||
UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
document_sets=set(document_set_map.get(document_id, [])),
|
||||
)
|
||||
for document_id in document_ids
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def sync_document_set(document_set_id: int) -> None:
|
||||
document_index = get_default_document_index()
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
documents_to_update = fetch_documents_for_document_set(
|
||||
document_set_id=document_set_id,
|
||||
db_session=db_session,
|
||||
current_only=False,
|
||||
)
|
||||
for document_batch in batch_generator(documents_to_update, _SYNC_BATCH_SIZE):
|
||||
_sync_document_batch(
|
||||
document_ids=[document.id for document in document_batch],
|
||||
document_index=document_index,
|
||||
)
|
||||
|
||||
# if there are no connectors, then delete the document set. Otherwise, just
|
||||
# mark it as successfully synced.
|
||||
document_set = cast(
|
||||
DocumentSet,
|
||||
get_document_set_by_id(
|
||||
db_session=db_session, document_set_id=document_set_id
|
||||
),
|
||||
) # casting since we "know" a document set with this ID exists
|
||||
if not document_set.connector_credential_pairs:
|
||||
delete_document_set(document_set_row=document_set, db_session=db_session)
|
||||
logger.info(
|
||||
f"Successfully deleted document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
else:
|
||||
mark_document_set_as_synced(
|
||||
document_set_id=document_set_id, db_session=db_session
|
||||
)
|
||||
logger.info(f"Document set sync for '{document_set_id}' complete!")
|
@ -14,11 +14,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.background.celery.celery import cleanup_connector_credential_pair_task
|
||||
from danswer.background.celery.deletion_utils import get_deletion_status
|
||||
from danswer.background.connector_deletion import (
|
||||
get_cleanup_task_id,
|
||||
)
|
||||
from danswer.background.celery.celery_utils import get_deletion_status
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
|
||||
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
|
||||
@ -536,6 +533,8 @@ def create_deletion_attempt_for_connector_id(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
from danswer.background.celery.celery import cleanup_connector_credential_pair_task
|
||||
|
||||
connector_id = connector_credential_pair_identifier.connector_id
|
||||
credential_id = connector_credential_pair_identifier.credential_id
|
||||
|
||||
@ -559,7 +558,7 @@ def create_deletion_attempt_for_connector_id(
|
||||
"no ongoing / planned indexing attempts.",
|
||||
)
|
||||
|
||||
task_id = get_cleanup_task_id(
|
||||
task_id = name_cc_cleanup_task(
|
||||
connector_id=connector_id, credential_id=credential_id
|
||||
)
|
||||
cleanup_connector_credential_pair_task.apply_async(
|
||||
|
@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any
|
||||
|
||||
@ -52,7 +53,9 @@ class _IndexAttemptLoggingAdapter(logging.LoggerAdapter):
|
||||
|
||||
|
||||
def setup_logger(
|
||||
name: str = __name__, log_level: int = get_log_level_from_str()
|
||||
name: str = __name__,
|
||||
log_level: int = get_log_level_from_str(),
|
||||
logfile_name: str | None = None,
|
||||
) -> logging.LoggerAdapter:
|
||||
logger = logging.getLogger(name)
|
||||
|
||||
@ -73,4 +76,12 @@ def setup_logger(
|
||||
|
||||
logger.addHandler(handler)
|
||||
|
||||
if logfile_name:
|
||||
is_containerized = os.path.exists("/.dockerenv")
|
||||
file_name_template = (
|
||||
"/var/log/{name}.log" if is_containerized else "./log/{name}.log"
|
||||
)
|
||||
file_handler = logging.FileHandler(file_name_template.format(name=logfile_name))
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
return _IndexAttemptLoggingAdapter(logger)
|
||||
|
52
backend/scripts/dev_run_celery.py
Normal file
52
backend/scripts/dev_run_celery.py
Normal file
@ -0,0 +1,52 @@
|
||||
import subprocess
|
||||
import threading
|
||||
|
||||
|
||||
def monitor_process(process_name: str, process: subprocess.Popen) -> None:
|
||||
assert process.stdout is not None
|
||||
|
||||
while True:
|
||||
output = process.stdout.readline()
|
||||
|
||||
if output:
|
||||
print(f"{process_name}: {output.strip()}")
|
||||
|
||||
if process.poll() is not None:
|
||||
break
|
||||
|
||||
|
||||
def run_celery() -> None:
|
||||
cmd_worker = [
|
||||
"celery",
|
||||
"-A",
|
||||
"danswer.background.celery",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--concurrency=1",
|
||||
]
|
||||
cmd_beat = ["celery", "-A", "danswer.background.celery", "beat", "--loglevel=INFO"]
|
||||
|
||||
# Redirect stderr to stdout for both processes
|
||||
worker_process = subprocess.Popen(
|
||||
cmd_worker, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
beat_process = subprocess.Popen(
|
||||
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
|
||||
# Monitor outputs using threads
|
||||
worker_thread = threading.Thread(
|
||||
target=monitor_process, args=("WORKER", worker_process)
|
||||
)
|
||||
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
|
||||
|
||||
worker_thread.start()
|
||||
beat_thread.start()
|
||||
|
||||
# Wait for threads to finish
|
||||
worker_thread.join()
|
||||
beat_thread.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_celery()
|
@ -10,28 +10,23 @@ stdout_logfile_maxbytes=52428800
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
|
||||
[program:celery]
|
||||
command=celery -A danswer.background.celery worker --loglevel=INFO
|
||||
stdout_logfile=/var/log/celery.log
|
||||
# Background jobs that must be run async due to long time to completion
|
||||
[program:celery_worker]
|
||||
command=celery -A danswer.background.celery worker --loglevel=INFO --logfile=/var/log/celery_worker.log
|
||||
stdout_logfile=/var/log/celery_worker_supervisor.log
|
||||
stdout_logfile_maxbytes=52428800
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
|
||||
[program:file_deletion]
|
||||
command=python danswer/background/file_deletion.py
|
||||
stdout_logfile=/var/log/file_deletion.log
|
||||
# Job scheduler for periodic tasks
|
||||
[program:celery_beat]
|
||||
command=celery -A danswer.background.celery beat --loglevel=INFO --logfile=/var/log/celery_beat.log
|
||||
stdout_logfile=/var/log/celery_beat_supervisor.log
|
||||
stdout_logfile_maxbytes=52428800
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
|
||||
[program:document_set_sync]
|
||||
command=python danswer/background/document_set_sync_script.py
|
||||
stdout_logfile=/var/log/document_set_sync.log
|
||||
stdout_logfile_maxbytes=52428800
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
|
||||
# Listens for slack messages and responds with answers
|
||||
# Listens for Slack messages and responds with answers
|
||||
# for all channels that the DanswerBot has been added to.
|
||||
# If not setup, this will just fail 5 times and then stop.
|
||||
# More details on setup here: https://docs.danswer.dev/slack_bot_setup
|
||||
@ -44,9 +39,9 @@ autorestart=true
|
||||
startretries=5
|
||||
startsecs=60
|
||||
|
||||
# pushes all logs from the above programs to stdout
|
||||
# Pushes all logs from the above programs to stdout
|
||||
[program:log-redirect-handler]
|
||||
command=tail -qF /var/log/update.log /var/log/celery.log /var/log/file_deletion.log /var/log/slack_bot_listener.log /var/log/document_set_sync.log
|
||||
command=tail -qF /var/log/update.log /var/log/celery_worker.log /var/log/celery_worker_supervisor.log /var/log/celery_beat.log /var/log/celery_beat_supervisor.log /var/log/slack_bot_listener.log
|
||||
stdout_logfile=/dev/stdout
|
||||
stdout_logfile_maxbytes=0
|
||||
redirect_stderr=true
|
||||
|
Loading…
x
Reference in New Issue
Block a user