Celery Beat (#575)

Yuhong Sun
2023-10-16 14:59:42 -07:00
committed by GitHub
parent a7ddb22e50
commit b5982c10c3
19 changed files with 507 additions and 299 deletions

View File

@@ -7,6 +7,7 @@ from danswer.db.models import Base
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
@@ -21,7 +22,7 @@ if config.config_file_name is not None:
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
-target_metadata = Base.metadata
+target_metadata = [Base.metadata, ResultModelBase.metadata]
# other values from the config, defined by the needs of env.py,
# can be acquired:
@@ -44,7 +45,7 @@ def run_migrations_offline() -> None:
url = build_connection_string()
context.configure(
url=url,
-target_metadata=target_metadata,
+target_metadata=target_metadata, # type: ignore
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
@@ -54,7 +55,7 @@ def run_migrations_offline() -> None:
def do_run_migrations(connection: Connection) -> None:
-context.configure(connection=connection, target_metadata=target_metadata)
+context.configure(connection=connection, target_metadata=target_metadata) # type: ignore
with context.begin_transaction():
context.run_migrations()

View File

@@ -0,0 +1,48 @@
"""Task Tracking
Revision ID: 78dbe7e38469
Revises: 7ccea01261f6
Create Date: 2023-10-15 23:40:50.593262
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "78dbe7e38469"
down_revision = "7ccea01261f6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"task_queue_jobs",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("task_id", sa.String(), nullable=False),
sa.Column("task_name", sa.String(), nullable=False),
sa.Column(
"status",
sa.Enum(
"PENDING",
"STARTED",
"SUCCESS",
"FAILURE",
name="taskstatus",
native_enum=False,
),
nullable=False,
),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"register_time",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
op.drop_table("task_queue_jobs")

View File

@@ -1,9 +1,39 @@
-from celery import Celery
+import os
from datetime import timedelta
from pathlib import Path
from typing import cast
-from danswer.background.connector_deletion import cleanup_connector_credential_pair
+from celery import Celery # type: ignore
from celery.result import AsyncResult
from sqlalchemy.orm import Session
from danswer.background.connector_deletion import _delete_connector_credential_pair
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.connectors.file.utils import file_age_in_hours
from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import UpdateRequest
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.document_set import fetch_documents_for_document_set
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import SYNC_DB_API
-from danswer.document_set.document_set import sync_document_set
+from danswer.db.models import DocumentSet
from danswer.db.tasks import check_live_task_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import mark_task_finished
from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -13,17 +43,193 @@ celery_backend_url = "db+" + build_connection_string(db_api=SYNC_DB_API)
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
-@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit
+_ExistingTaskCache: dict[int, AsyncResult] = {}
_SYNC_BATCH_SIZE = 1000
#####
# Tasks that need to be run in job queue, registered via APIs
#
# If imports from this module are needed, use local imports to avoid circular importing
#####
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def cleanup_connector_credential_pair_task(
-connector_id: int, credential_id: int
+connector_id: int,
credential_id: int,
) -> int:
-return cleanup_connector_credential_pair(connector_id, credential_id)
+"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
or updating the ACL"""
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
# validate that the connector / credential pair is deletable
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair or not check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair
):
raise ValueError(
"Cannot run deletion attempt - connector_credential_pair is not deletable. "
"This is likely because there is an ongoing / planned indexing attempt OR the "
"connector is not disabled."
)
try:
# The bulk of the work is in here, updates Postgres and Vespa
return _delete_connector_credential_pair(
db_session=db_session,
document_index=get_default_document_index(),
cc_pair=cc_pair,
)
except Exception as e:
logger.exception(f"Failed to run connector_deletion due to {e}")
raise e
-@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit
+@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_document_set_task(document_set_id: int) -> None:
-try:
-return sync_document_set(document_set_id=document_set_id)
-except Exception:
-logger.exception("Failed to sync document set %s", document_set_id)
-raise
+"""For document sets marked as not up to date, sync the state from postgres
+into the datastore. Also handles deletions."""
+def _sync_document_batch(
+document_ids: list[str], document_index: DocumentIndex
) -> None:
logger.debug(f"Syncing document sets for: {document_ids}")
# begin a transaction, release lock at the end
with Session(get_sqlalchemy_engine()) as db_session:
# acquires a lock on the documents so that no other process can modify them
prepare_to_modify_documents(
db_session=db_session, document_ids=document_ids
)
# get current state of document sets for these documents
document_set_map = {
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=document_ids, db_session=db_session
)
}
# update Vespa
document_index.update(
update_requests=[
UpdateRequest(
document_ids=[document_id],
document_sets=set(document_set_map.get(document_id, [])),
)
for document_id in document_ids
]
)
with Session(get_sqlalchemy_engine()) as db_session:
task_name = name_document_set_sync_task(document_set_id)
mark_task_start(task_name, db_session)
try:
document_index = get_default_document_index()
documents_to_update = fetch_documents_for_document_set(
document_set_id=document_set_id,
db_session=db_session,
current_only=False,
)
for document_batch in batch_generator(
documents_to_update, _SYNC_BATCH_SIZE
):
_sync_document_batch(
document_ids=[document.id for document in document_batch],
document_index=document_index,
)
# if there are no connectors, then delete the document set. Otherwise, just
# mark it as successfully synced.
document_set = cast(
DocumentSet,
get_document_set_by_id(
db_session=db_session, document_set_id=document_set_id
),
) # casting since we "know" a document set with this ID exists
if not document_set.connector_credential_pairs:
delete_document_set(
document_set_row=document_set, db_session=db_session
)
logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(
document_set_id=document_set_id, db_session=db_session
)
logger.info(f"Document set sync for '{document_set_id}' complete!")
except Exception:
logger.exception("Failed to sync document set %s", document_set_id)
mark_task_finished(task_name, db_session, success=False)
raise
mark_task_finished(task_name, db_session)
#####
# Periodic Tasks
#####
@celery_app.task(
name="check_for_document_sets_sync_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_document_sets_sync_task() -> None:
"""Runs periodically to check if any document sets are out of sync
Creates a task to sync the set if needed"""
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
if not document_set.is_up_to_date:
task_name = name_document_set_sync_task(document_set.id)
latest_sync = get_latest_task(task_name, db_session)
if latest_sync and check_live_task_not_timed_out(
latest_sync, db_session
):
logger.info(
f"Document set '{document_set.id}' is already syncing. Skipping."
)
continue
logger.info(f"Document set {document_set.id} syncing now!")
task = sync_document_set_task.apply_async(
kwargs=dict(document_set_id=document_set.id),
)
register_task(task.id, task_name, db_session)
@celery_app.task(name="clean_old_temp_files_task", soft_time_limit=JOB_TIMEOUT)
def clean_old_temp_files_task(
age_threshold_in_hours: float | int = 24 * 7, # 1 week,
base_path: Path | str = FILE_CONNECTOR_TMP_STORAGE_PATH,
) -> None:
"""Files added via the File connector need to be deleted after ingestion
Currently handled async of the indexing job"""
os.makedirs(base_path, exist_ok=True)
for file in os.listdir(base_path):
if file_age_in_hours(file) > age_threshold_in_hours:
os.remove(Path(base_path) / file)
#####
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"check-for-document-set-sync": {
"task": "check_for_document_sets_sync_task",
"schedule": timedelta(seconds=5),
},
"clean-old-temp-files": {
"task": "clean_old_temp_files_task",
"schedule": timedelta(minutes=30),
},
}
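
Note on the schedule above: the dict key of each entry is arbitrary, but the "task" value must match the name the task is registered under in Celery (the explicit name= passed to @celery_app.task, or the default name derived from the module path). A minimal sketch of wiring up one more periodic job, using a hypothetical task that is not part of this commit:

# Sketch only: a hypothetical extra periodic task, following the same pattern as above.
@celery_app.task(name="hypothetical_heartbeat_task", soft_time_limit=JOB_TIMEOUT)
def hypothetical_heartbeat_task() -> None:
    logger.info("Celery beat heartbeat")

# The beat_schedule key is arbitrary; the "task" value must match the registered name.
celery_app.conf.beat_schedule["hypothetical-heartbeat"] = {
    "task": "hypothetical_heartbeat_task",
    "schedule": timedelta(minutes=5),
}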

View File

@@ -1,11 +1,15 @@
import json
from typing import cast
from celery.result import AsyncResult
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.celery import celery_app
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DeletionStatus
from danswer.server.models import DeletionAttemptSnapshot
def get_celery_task(task_id: str) -> AsyncResult:
@@ -35,3 +39,37 @@ def get_celery_task_status(task_id: str) -> str | None:
return task.status
return None
def get_deletion_status(
connector_id: int, credential_id: int
) -> DeletionAttemptSnapshot | None:
cleanup_task_id = name_cc_cleanup_task(
connector_id=connector_id, credential_id=credential_id
)
deletion_task = get_celery_task(task_id=cleanup_task_id)
deletion_task_status = get_celery_task_status(task_id=cleanup_task_id)
deletion_status = None
error_msg = None
num_docs_deleted = 0
if deletion_task_status == "SUCCESS":
deletion_status = DeletionStatus.SUCCESS
num_docs_deleted = cast(int, deletion_task.get(propagate=False))
elif deletion_task_status == "FAILURE":
deletion_status = DeletionStatus.FAILED
error_msg = deletion_task.get(propagate=False)
elif deletion_task_status == "STARTED" or deletion_task_status == "PENDING":
deletion_status = DeletionStatus.IN_PROGRESS
return (
DeletionAttemptSnapshot(
connector_id=connector_id,
credential_id=credential_id,
status=deletion_status,
error_msg=str(error_msg),
num_docs_deleted=num_docs_deleted,
)
if deletion_status
else None
)
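
Because the cleanup task id is derived deterministically from the connector / credential pair, the status can be polled at any time without persisting the Celery task id anywhere. A rough usage sketch (the ids here are illustrative):

# Sketch: polling deletion progress for a connector / credential pair.
snapshot = get_deletion_status(connector_id=1, credential_id=2)
if snapshot is None:
    print("no deletion attempt found for this pair")
else:
    print(snapshot.status, snapshot.num_docs_deleted)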

View File

@@ -1,41 +0,0 @@
from typing import cast
from danswer.background.celery.celery_utils import get_celery_task
from danswer.background.celery.celery_utils import get_celery_task_status
from danswer.background.connector_deletion import get_cleanup_task_id
from danswer.db.models import DeletionStatus
from danswer.server.models import DeletionAttemptSnapshot
def get_deletion_status(
connector_id: int, credential_id: int
) -> DeletionAttemptSnapshot | None:
cleanup_task_id = get_cleanup_task_id(
connector_id=connector_id, credential_id=credential_id
)
deletion_task = get_celery_task(task_id=cleanup_task_id)
deletion_task_status = get_celery_task_status(task_id=cleanup_task_id)
deletion_status = None
error_msg = None
num_docs_deleted = 0
if deletion_task_status == "SUCCESS":
deletion_status = DeletionStatus.SUCCESS
num_docs_deleted = cast(int, deletion_task.get(propagate=False))
elif deletion_task_status == "FAILURE":
deletion_status = DeletionStatus.FAILED
error_msg = deletion_task.get(propagate=False)
elif deletion_task_status == "STARTED" or deletion_task_status == "PENDING":
deletion_status = DeletionStatus.IN_PROGRESS
return (
DeletionAttemptSnapshot(
connector_id=connector_id,
credential_id=credential_id,
status=deletion_status,
error_msg=str(error_msg),
num_docs_deleted=num_docs_deleted,
)
if deletion_status
else None
)

View File

@@ -17,15 +17,12 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import UpdateRequest
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import delete_document_by_connector_credential_pair
from danswer.db.document import delete_documents_complete
from danswer.db.document import get_document_connector_cnts
@@ -211,39 +208,3 @@ def _delete_connector_credential_pair(
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
)
return num_docs_deleted
def cleanup_connector_credential_pair(
connector_id: int,
credential_id: int,
) -> int:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
# validate that the connector / credential pair is deletable
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair or not check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair
):
raise ValueError(
"Cannot run deletion attempt - connector_credential_pair is not deletable. "
"This is likely because there is an ongoing / planned indexing attempt OR the "
"connector is not disabled."
)
try:
return _delete_connector_credential_pair(
db_session=db_session,
document_index=get_default_document_index(),
cc_pair=cc_pair,
)
except Exception as e:
logger.exception(f"Failed to run connector_deletion due to {e}")
raise e
def get_cleanup_task_id(connector_id: int, credential_id: int) -> str:
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"

View File

@@ -1,55 +0,0 @@
from celery.result import AsyncResult
from sqlalchemy.orm import Session
from danswer.background.celery.celery import sync_document_set_task
from danswer.background.utils import interval_run_job
from danswer.db.document_set import (
fetch_document_sets,
)
from danswer.db.engine import get_sqlalchemy_engine
from danswer.utils.logger import setup_logger
logger = setup_logger()
_ExistingTaskCache: dict[int, AsyncResult] = {}
def _document_sync_loop() -> None:
# cleanup tasks
existing_tasks = list(_ExistingTaskCache.items())
for document_set_id, task in existing_tasks:
if task.ready():
logger.info(
f"Document set '{document_set_id}' is complete with status "
f"{task.status}. Cleaning up."
)
del _ExistingTaskCache[document_set_id]
# kick off new tasks
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
if not document_set.is_up_to_date:
if document_set.id in _ExistingTaskCache:
logger.info(
f"Document set '{document_set.id}' is already syncing. Skipping."
)
continue
logger.info(
f"Document set {document_set.id} is not synced. Syncing now!"
)
task = sync_document_set_task.apply_async(
kwargs=dict(document_set_id=document_set.id),
)
_ExistingTaskCache[document_set.id] = task
if __name__ == "__main__":
interval_run_job(
job=_document_sync_loop, delay=5, emit_job_start_log=False
) # run every 5 seconds

View File

@@ -1,6 +0,0 @@
from danswer.background.utils import interval_run_job
from danswer.connectors.file.utils import clean_old_temp_files
if __name__ == "__main__":
interval_run_job(clean_old_temp_files, 30 * 60) # run every 30 minutes

View File

@@ -0,0 +1,6 @@
def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
def name_document_set_sync_task(document_set_id: int) -> str:
return f"sync_doc_set_{document_set_id}"

View File

@@ -1,24 +0,0 @@
import time
from collections.abc import Callable
from typing import Any
from danswer.utils.logger import setup_logger
logger = setup_logger()
def interval_run_job(
job: Callable[[], Any], delay: int | float, emit_job_start_log: bool = True
) -> None:
while True:
start = time.time()
if emit_job_start_log:
logger.info(f"Running '{job.__name__}', current time: {time.ctime(start)}")
try:
job()
except Exception as e:
logger.exception(f"Failed to run update due to {e}")
sleep_time = delay - (time.time() - start)
if sleep_time > 0:
time.sleep(sleep_time)

View File

@@ -211,7 +211,7 @@ CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get(
# fairly large amount of memory in order to increase substantially, since
# each worker loads the embedding models into memory.
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
# Logs every model prompt and output, mostly used for development or exploration purposes
LOG_ALL_MODEL_INTERACTIONS = (
os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true"

View File

@@ -8,7 +8,6 @@ from typing import IO
from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH
_FILE_AGE_CLEANUP_THRESHOLD_HOURS = 24 * 7 # 1 week
_VALID_FILE_EXTENSIONS = [".txt", ".zip", ".pdf"]
@@ -53,13 +52,3 @@ def write_temp_files(
def file_age_in_hours(filepath: str | Path) -> float:
return (time.time() - os.path.getmtime(filepath)) / (60 * 60)
def clean_old_temp_files(
age_threshold_in_hours: float | int = _FILE_AGE_CLEANUP_THRESHOLD_HOURS,
base_path: Path | str = FILE_CONNECTOR_TMP_STORAGE_PATH,
) -> None:
os.makedirs(base_path, exist_ok=True)
for file in os.listdir(base_path):
if file_age_in_hours(file) > age_threshold_in_hours:
os.remove(Path(base_path) / file)

View File

@@ -52,6 +52,14 @@ class DeletionStatus(str, PyEnum):
FAILED = "failed"
# Consistent with Celery task statuses
class TaskStatus(str, PyEnum):
PENDING = "PENDING"
STARTED = "STARTED"
SUCCESS = "SUCCESS"
FAILURE = "FAILURE"
class Base(DeclarativeBase):
pass
@@ -566,3 +574,22 @@ class SlackBotConfig(Base):
)
persona: Mapped[Persona | None] = relationship("Persona")
class TaskQueueState(Base):
# Currently refers to Celery Tasks
__tablename__ = "task_queue_jobs"
id: Mapped[int] = mapped_column(primary_key=True)
# Celery task id
task_id: Mapped[str] = mapped_column(String)
# For any job type, this would be the same
task_name: Mapped[str] = mapped_column(String)
# Note that if the task dies, this won't necessarily be marked FAILED correctly
status: Mapped[TaskStatus] = mapped_column(Enum(TaskStatus))
start_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True)
)
register_time: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)

View File

@@ -0,0 +1,85 @@
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.db.engine import get_db_current_time
from danswer.db.models import TaskQueueState
from danswer.db.models import TaskStatus
def get_latest_task(
task_name: str,
db_session: Session,
) -> TaskQueueState | None:
stmt = (
select(TaskQueueState)
.where(TaskQueueState.task_name == task_name)
.order_by(desc(TaskQueueState.id))
.limit(1)
)
result = db_session.execute(stmt)
latest_task = result.scalars().first()
return latest_task
def register_task(
task_id: str,
task_name: str,
db_session: Session,
) -> TaskQueueState:
new_task = TaskQueueState(
task_id=task_id, task_name=task_name, status=TaskStatus.PENDING
)
db_session.add(new_task)
db_session.commit()
return new_task
def mark_task_start(
task_name: str,
db_session: Session,
) -> None:
task = get_latest_task(task_name, db_session)
if not task:
raise ValueError(f"No task found with name {task_name}")
task.start_time = func.now() # type: ignore
db_session.commit()
def mark_task_finished(
task_name: str,
db_session: Session,
success: bool = True,
) -> None:
latest_task = get_latest_task(task_name, db_session)
if latest_task is None:
raise ValueError(f"tasks for {task_name} do not exist")
latest_task.status = TaskStatus.SUCCESS if success else TaskStatus.FAILURE
db_session.commit()
def check_live_task_not_timed_out(
task: TaskQueueState,
db_session: Session,
timeout: int = JOB_TIMEOUT,
) -> bool:
# We only care for live tasks to not create new periodic tasks
if task.status in [TaskStatus.SUCCESS, TaskStatus.FAILURE]:
return False
current_db_time = get_db_current_time(db_session=db_session)
last_update_time = task.register_time
if task.start_time:
last_update_time = max(task.register_time, task.start_time)
time_elapsed = current_db_time - last_update_time
return time_elapsed.total_seconds() < timeout
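
Together these helpers give the beat-scheduled jobs a Postgres-visible lifecycle: register_task when the job is queued, mark_task_start when it begins, mark_task_finished on success or failure, and check_live_task_not_timed_out to avoid double-scheduling. A minimal sketch of the intended call pattern inside a task body (hypothetical task name; register_task is assumed to have already been called by the scheduler that queued the job):

# Sketch of the lifecycle used by sync_document_set_task above.
from sqlalchemy.orm import Session
from danswer.db.engine import get_sqlalchemy_engine

def run_tracked_job() -> None:
    task_name = "hypothetical_tracked_job"
    with Session(get_sqlalchemy_engine()) as db_session:
        mark_task_start(task_name, db_session)
        try:
            ...  # the actual work goes here
        except Exception:
            mark_task_finished(task_name, db_session, success=False)
            raise
        mark_task_finished(task_name, db_session)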

View File

@@ -1,84 +0,0 @@
from typing import cast
from sqlalchemy.orm import Session
from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import UpdateRequest
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.document_set import fetch_documents_for_document_set
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
logger = setup_logger()
_SYNC_BATCH_SIZE = 1000
def _sync_document_batch(
document_ids: list[str], document_index: DocumentIndex
) -> None:
logger.debug(f"Syncing document sets for: {document_ids}")
# begin a transaction, release lock at the end
with Session(get_sqlalchemy_engine()) as db_session:
# acquires a lock on the documents so that no other process can modify them
prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
# get current state of document sets for these documents
document_set_map = {
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=document_ids, db_session=db_session
)
}
# update Vespa
document_index.update(
update_requests=[
UpdateRequest(
document_ids=[document_id],
document_sets=set(document_set_map.get(document_id, [])),
)
for document_id in document_ids
]
)
def sync_document_set(document_set_id: int) -> None:
document_index = get_default_document_index()
with Session(get_sqlalchemy_engine()) as db_session:
documents_to_update = fetch_documents_for_document_set(
document_set_id=document_set_id,
db_session=db_session,
current_only=False,
)
for document_batch in batch_generator(documents_to_update, _SYNC_BATCH_SIZE):
_sync_document_batch(
document_ids=[document.id for document in document_batch],
document_index=document_index,
)
# if there are no connectors, then delete the document set. Otherwise, just
# mark it as successfully synced.
document_set = cast(
DocumentSet,
get_document_set_by_id(
db_session=db_session, document_set_id=document_set_id
),
) # casting since we "know" a document set with this ID exists
if not document_set.connector_credential_pairs:
delete_document_set(document_set_row=document_set, db_session=db_session)
logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(
document_set_id=document_set_id, db_session=db_session
)
logger.info(f"Document set sync for '{document_set_id}' complete!")

View File

@@ -14,11 +14,8 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
-from danswer.background.celery.celery import cleanup_connector_credential_pair_task
-from danswer.background.celery.deletion_utils import get_deletion_status
-from danswer.background.connector_deletion import (
-get_cleanup_task_id,
-)
+from danswer.background.celery.celery_utils import get_deletion_status
+from danswer.background.task_utils import name_cc_cleanup_task
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
@@ -536,6 +533,8 @@ def create_deletion_attempt_for_connector_id(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
from danswer.background.celery.celery import cleanup_connector_credential_pair_task
connector_id = connector_credential_pair_identifier.connector_id
credential_id = connector_credential_pair_identifier.credential_id
@@ -559,7 +558,7 @@ def create_deletion_attempt_for_connector_id(
"no ongoing / planned indexing attempts.",
)
-task_id = get_cleanup_task_id(
+task_id = name_cc_cleanup_task(
connector_id=connector_id, credential_id=credential_id
)
cleanup_connector_credential_pair_task.apply_async(

View File

@@ -1,4 +1,5 @@
import logging
import os
from collections.abc import MutableMapping
from typing import Any
@@ -52,7 +53,9 @@ class _IndexAttemptLoggingAdapter(logging.LoggerAdapter):
def setup_logger(
-name: str = __name__, log_level: int = get_log_level_from_str()
+name: str = __name__,
log_level: int = get_log_level_from_str(),
logfile_name: str | None = None,
) -> logging.LoggerAdapter:
logger = logging.getLogger(name)
@@ -73,4 +76,12 @@ def setup_logger(
logger.addHandler(handler)
if logfile_name:
is_containerized = os.path.exists("/.dockerenv")
file_name_template = (
"/var/log/{name}.log" if is_containerized else "./log/{name}.log"
)
file_handler = logging.FileHandler(file_name_template.format(name=logfile_name))
logger.addHandler(file_handler)
return _IndexAttemptLoggingAdapter(logger)
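
The new logfile_name parameter is opt-in, so existing call sites are unchanged. A hedged example of a background process turning on file logging (the name is illustrative; outside a container this writes to ./log/, which must already exist):

# Sketch: stdout logging plus a file handler at /var/log/my_worker.log (in Docker)
# or ./log/my_worker.log (locally, assuming the ./log directory exists).
logger = setup_logger(logfile_name="my_worker")
logger.info("this line goes to the stream handler and the log file")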

View File

@@ -0,0 +1,52 @@
import subprocess
import threading
def monitor_process(process_name: str, process: subprocess.Popen) -> None:
assert process.stdout is not None
while True:
output = process.stdout.readline()
if output:
print(f"{process_name}: {output.strip()}")
if process.poll() is not None:
break
def run_celery() -> None:
cmd_worker = [
"celery",
"-A",
"danswer.background.celery",
"worker",
"--loglevel=INFO",
"--concurrency=1",
]
cmd_beat = ["celery", "-A", "danswer.background.celery", "beat", "--loglevel=INFO"]
# Redirect stderr to stdout for both processes
worker_process = subprocess.Popen(
cmd_worker, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
beat_process = subprocess.Popen(
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
# Monitor outputs using threads
worker_thread = threading.Thread(
target=monitor_process, args=("WORKER", worker_process)
)
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
worker_thread.start()
beat_thread.start()
# Wait for threads to finish
worker_thread.join()
beat_thread.join()
if __name__ == "__main__":
run_celery()

View File

@@ -10,28 +10,23 @@ stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true
-[program:celery]
-command=celery -A danswer.background.celery worker --loglevel=INFO
-stdout_logfile=/var/log/celery.log
+# Background jobs that must be run async due to long time to completion
+[program:celery_worker]
+command=celery -A danswer.background.celery worker --loglevel=INFO --logfile=/var/log/celery_worker.log
+stdout_logfile=/var/log/celery_worker_supervisor.log
stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true
-[program:file_deletion]
-command=python danswer/background/file_deletion.py
-stdout_logfile=/var/log/file_deletion.log
+# Job scheduler for periodic tasks
+[program:celery_beat]
+command=celery -A danswer.background.celery beat --loglevel=INFO --logfile=/var/log/celery_beat.log
+stdout_logfile=/var/log/celery_beat_supervisor.log
stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true
-[program:document_set_sync]
-command=python danswer/background/document_set_sync_script.py
-stdout_logfile=/var/log/document_set_sync.log
-stdout_logfile_maxbytes=52428800
-redirect_stderr=true
-autorestart=true
-# Listens for slack messages and responds with answers
+# Listens for Slack messages and responds with answers
# for all channels that the DanswerBot has been added to.
# If not setup, this will just fail 5 times and then stop.
# More details on setup here: https://docs.danswer.dev/slack_bot_setup
@@ -44,9 +39,9 @@ autorestart=true
startretries=5
startsecs=60
-# pushes all logs from the above programs to stdout
+# Pushes all logs from the above programs to stdout
[program:log-redirect-handler]
-command=tail -qF /var/log/update.log /var/log/celery.log /var/log/file_deletion.log /var/log/slack_bot_listener.log /var/log/document_set_sync.log
+command=tail -qF /var/log/update.log /var/log/celery_worker.log /var/log/celery_worker_supervisor.log /var/log/celery_beat.log /var/log/celery_beat_supervisor.log /var/log/slack_bot_listener.log
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
redirect_stderr=true