Celery Beat (#575)

This commit is contained in:
Yuhong Sun 2023-10-16 14:59:42 -07:00 committed by GitHub
parent a7ddb22e50
commit b5982c10c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 507 additions and 299 deletions

View File

@ -7,6 +7,7 @@ from danswer.db.models import Base
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
@ -21,7 +22,7 @@ if config.config_file_name is not None:
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Base.metadata
target_metadata = [Base.metadata, ResultModelBase.metadata]
# other values from the config, defined by the needs of env.py,
# can be acquired:
@ -44,7 +45,7 @@ def run_migrations_offline() -> None:
url = build_connection_string()
context.configure(
url=url,
target_metadata=target_metadata,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
@ -54,7 +55,7 @@ def run_migrations_offline() -> None:
def do_run_migrations(connection: Connection) -> None:
context.configure(connection=connection, target_metadata=target_metadata)
context.configure(connection=connection, target_metadata=target_metadata) # type: ignore
with context.begin_transaction():
context.run_migrations()

View File

@ -0,0 +1,48 @@
"""Task Tracking
Revision ID: 78dbe7e38469
Revises: 7ccea01261f6
Create Date: 2023-10-15 23:40:50.593262
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "78dbe7e38469"
down_revision = "7ccea01261f6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"task_queue_jobs",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("task_id", sa.String(), nullable=False),
sa.Column("task_name", sa.String(), nullable=False),
sa.Column(
"status",
sa.Enum(
"PENDING",
"STARTED",
"SUCCESS",
"FAILURE",
name="taskstatus",
native_enum=False,
),
nullable=False,
),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=True),
sa.Column(
"register_time",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
op.drop_table("task_queue_jobs")

View File

@ -1,9 +1,39 @@
from celery import Celery
import os
from datetime import timedelta
from pathlib import Path
from typing import cast
from danswer.background.connector_deletion import cleanup_connector_credential_pair
from celery import Celery # type: ignore
from celery.result import AsyncResult
from sqlalchemy.orm import Session
from danswer.background.connector_deletion import _delete_connector_credential_pair
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.connectors.file.utils import file_age_in_hours
from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import UpdateRequest
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.document_set import fetch_documents_for_document_set
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import SYNC_DB_API
from danswer.document_set.document_set import sync_document_set
from danswer.db.models import DocumentSet
from danswer.db.tasks import check_live_task_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import mark_task_finished
from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -13,17 +43,193 @@ celery_backend_url = "db+" + build_connection_string(db_api=SYNC_DB_API)
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit
_ExistingTaskCache: dict[int, AsyncResult] = {}
_SYNC_BATCH_SIZE = 1000
#####
# Tasks that need to be run in job queue, registered via APIs
#
# If imports from this module are needed, use local imports to avoid circular importing
#####
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def cleanup_connector_credential_pair_task(
connector_id: int, credential_id: int
connector_id: int,
credential_id: int,
) -> int:
return cleanup_connector_credential_pair(connector_id, credential_id)
"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
or updating the ACL"""
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
# validate that the connector / credential pair is deletable
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair or not check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair
):
raise ValueError(
"Cannot run deletion attempt - connector_credential_pair is not deletable. "
"This is likely because there is an ongoing / planned indexing attempt OR the "
"connector is not disabled."
)
try:
# The bulk of the work is in here, updates Postgres and Vespa
return _delete_connector_credential_pair(
db_session=db_session,
document_index=get_default_document_index(),
cc_pair=cc_pair,
)
except Exception as e:
logger.exception(f"Failed to run connector_deletion due to {e}")
raise e
@celery_app.task(soft_time_limit=60 * 60 * 6) # 6 hour time limit
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_document_set_task(document_set_id: int) -> None:
try:
return sync_document_set(document_set_id=document_set_id)
except Exception:
logger.exception("Failed to sync document set %s", document_set_id)
raise
"""For document sets marked as not up to date, sync the state from postgres
into the datastore. Also handles deletions."""
def _sync_document_batch(
document_ids: list[str], document_index: DocumentIndex
) -> None:
logger.debug(f"Syncing document sets for: {document_ids}")
# begin a transaction, release lock at the end
with Session(get_sqlalchemy_engine()) as db_session:
# acquires a lock on the documents so that no other process can modify them
prepare_to_modify_documents(
db_session=db_session, document_ids=document_ids
)
# get current state of document sets for these documents
document_set_map = {
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=document_ids, db_session=db_session
)
}
# update Vespa
document_index.update(
update_requests=[
UpdateRequest(
document_ids=[document_id],
document_sets=set(document_set_map.get(document_id, [])),
)
for document_id in document_ids
]
)
with Session(get_sqlalchemy_engine()) as db_session:
task_name = name_document_set_sync_task(document_set_id)
mark_task_start(task_name, db_session)
try:
document_index = get_default_document_index()
documents_to_update = fetch_documents_for_document_set(
document_set_id=document_set_id,
db_session=db_session,
current_only=False,
)
for document_batch in batch_generator(
documents_to_update, _SYNC_BATCH_SIZE
):
_sync_document_batch(
document_ids=[document.id for document in document_batch],
document_index=document_index,
)
# if there are no connectors, then delete the document set. Otherwise, just
# mark it as successfully synced.
document_set = cast(
DocumentSet,
get_document_set_by_id(
db_session=db_session, document_set_id=document_set_id
),
) # casting since we "know" a document set with this ID exists
if not document_set.connector_credential_pairs:
delete_document_set(
document_set_row=document_set, db_session=db_session
)
logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(
document_set_id=document_set_id, db_session=db_session
)
logger.info(f"Document set sync for '{document_set_id}' complete!")
except Exception:
logger.exception("Failed to sync document set %s", document_set_id)
mark_task_finished(task_name, db_session, success=False)
raise
mark_task_finished(task_name, db_session)
#####
# Periodic Tasks
#####
@celery_app.task(
name="check_for_document_sets_sync_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_document_sets_sync_task() -> None:
"""Runs periodically to check if any document sets are out of sync
Creates a task to sync the set if needed"""
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
if not document_set.is_up_to_date:
task_name = name_document_set_sync_task(document_set.id)
latest_sync = get_latest_task(task_name, db_session)
if latest_sync and check_live_task_not_timed_out(
latest_sync, db_session
):
logger.info(
f"Document set '{document_set.id}' is already syncing. Skipping."
)
continue
logger.info(f"Document set {document_set.id} syncing now!")
task = sync_document_set_task.apply_async(
kwargs=dict(document_set_id=document_set.id),
)
register_task(task.id, task_name, db_session)
@celery_app.task(name="clean_old_temp_files_task", soft_time_limit=JOB_TIMEOUT)
def clean_old_temp_files_task(
age_threshold_in_hours: float | int = 24 * 7, # 1 week,
base_path: Path | str = FILE_CONNECTOR_TMP_STORAGE_PATH,
) -> None:
"""Files added via the File connector need to be deleted after ingestion
Currently handled async of the indexing job"""
os.makedirs(base_path, exist_ok=True)
for file in os.listdir(base_path):
if file_age_in_hours(file) > age_threshold_in_hours:
os.remove(Path(base_path) / file)
#####
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"check-for-document-set-sync": {
"task": "check_for_document_sets_sync_task",
"schedule": timedelta(seconds=5),
},
"clean-old-temp-files": {
"task": "clean_old_temp_files_task",
"schedule": timedelta(minutes=30),
},
}

View File

@ -1,11 +1,15 @@
import json
from typing import cast
from celery.result import AsyncResult
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.celery import celery_app
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DeletionStatus
from danswer.server.models import DeletionAttemptSnapshot
def get_celery_task(task_id: str) -> AsyncResult:
@ -35,3 +39,37 @@ def get_celery_task_status(task_id: str) -> str | None:
return task.status
return None
def get_deletion_status(
connector_id: int, credential_id: int
) -> DeletionAttemptSnapshot | None:
cleanup_task_id = name_cc_cleanup_task(
connector_id=connector_id, credential_id=credential_id
)
deletion_task = get_celery_task(task_id=cleanup_task_id)
deletion_task_status = get_celery_task_status(task_id=cleanup_task_id)
deletion_status = None
error_msg = None
num_docs_deleted = 0
if deletion_task_status == "SUCCESS":
deletion_status = DeletionStatus.SUCCESS
num_docs_deleted = cast(int, deletion_task.get(propagate=False))
elif deletion_task_status == "FAILURE":
deletion_status = DeletionStatus.FAILED
error_msg = deletion_task.get(propagate=False)
elif deletion_task_status == "STARTED" or deletion_task_status == "PENDING":
deletion_status = DeletionStatus.IN_PROGRESS
return (
DeletionAttemptSnapshot(
connector_id=connector_id,
credential_id=credential_id,
status=deletion_status,
error_msg=str(error_msg),
num_docs_deleted=num_docs_deleted,
)
if deletion_status
else None
)

View File

@ -1,41 +0,0 @@
from typing import cast
from danswer.background.celery.celery_utils import get_celery_task
from danswer.background.celery.celery_utils import get_celery_task_status
from danswer.background.connector_deletion import get_cleanup_task_id
from danswer.db.models import DeletionStatus
from danswer.server.models import DeletionAttemptSnapshot
def get_deletion_status(
connector_id: int, credential_id: int
) -> DeletionAttemptSnapshot | None:
cleanup_task_id = get_cleanup_task_id(
connector_id=connector_id, credential_id=credential_id
)
deletion_task = get_celery_task(task_id=cleanup_task_id)
deletion_task_status = get_celery_task_status(task_id=cleanup_task_id)
deletion_status = None
error_msg = None
num_docs_deleted = 0
if deletion_task_status == "SUCCESS":
deletion_status = DeletionStatus.SUCCESS
num_docs_deleted = cast(int, deletion_task.get(propagate=False))
elif deletion_task_status == "FAILURE":
deletion_status = DeletionStatus.FAILED
error_msg = deletion_task.get(propagate=False)
elif deletion_task_status == "STARTED" or deletion_task_status == "PENDING":
deletion_status = DeletionStatus.IN_PROGRESS
return (
DeletionAttemptSnapshot(
connector_id=connector_id,
credential_id=credential_id,
status=deletion_status,
error_msg=str(error_msg),
num_docs_deleted=num_docs_deleted,
)
if deletion_status
else None
)

View File

@ -17,15 +17,12 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import UpdateRequest
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import delete_document_by_connector_credential_pair
from danswer.db.document import delete_documents_complete
from danswer.db.document import get_document_connector_cnts
@ -211,39 +208,3 @@ def _delete_connector_credential_pair(
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
)
return num_docs_deleted
def cleanup_connector_credential_pair(
connector_id: int,
credential_id: int,
) -> int:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
# validate that the connector / credential pair is deletable
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair or not check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair
):
raise ValueError(
"Cannot run deletion attempt - connector_credential_pair is not deletable. "
"This is likely because there is an ongoing / planned indexing attempt OR the "
"connector is not disabled."
)
try:
return _delete_connector_credential_pair(
db_session=db_session,
document_index=get_default_document_index(),
cc_pair=cc_pair,
)
except Exception as e:
logger.exception(f"Failed to run connector_deletion due to {e}")
raise e
def get_cleanup_task_id(connector_id: int, credential_id: int) -> str:
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"

View File

@ -1,55 +0,0 @@
from celery.result import AsyncResult
from sqlalchemy.orm import Session
from danswer.background.celery.celery import sync_document_set_task
from danswer.background.utils import interval_run_job
from danswer.db.document_set import (
fetch_document_sets,
)
from danswer.db.engine import get_sqlalchemy_engine
from danswer.utils.logger import setup_logger
logger = setup_logger()
_ExistingTaskCache: dict[int, AsyncResult] = {}
def _document_sync_loop() -> None:
# cleanup tasks
existing_tasks = list(_ExistingTaskCache.items())
for document_set_id, task in existing_tasks:
if task.ready():
logger.info(
f"Document set '{document_set_id}' is complete with status "
f"{task.status}. Cleaning up."
)
del _ExistingTaskCache[document_set_id]
# kick off new tasks
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
db_session=db_session, include_outdated=True
)
for document_set, _ in document_set_info:
if not document_set.is_up_to_date:
if document_set.id in _ExistingTaskCache:
logger.info(
f"Document set '{document_set.id}' is already syncing. Skipping."
)
continue
logger.info(
f"Document set {document_set.id} is not synced. Syncing now!"
)
task = sync_document_set_task.apply_async(
kwargs=dict(document_set_id=document_set.id),
)
_ExistingTaskCache[document_set.id] = task
if __name__ == "__main__":
interval_run_job(
job=_document_sync_loop, delay=5, emit_job_start_log=False
) # run every 5 seconds

View File

@ -1,6 +0,0 @@
from danswer.background.utils import interval_run_job
from danswer.connectors.file.utils import clean_old_temp_files
if __name__ == "__main__":
interval_run_job(clean_old_temp_files, 30 * 60) # run every 30 minutes

View File

@ -0,0 +1,6 @@
def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
def name_document_set_sync_task(document_set_id: int) -> str:
return f"sync_doc_set_{document_set_id}"

View File

@ -1,24 +0,0 @@
import time
from collections.abc import Callable
from typing import Any
from danswer.utils.logger import setup_logger
logger = setup_logger()
def interval_run_job(
job: Callable[[], Any], delay: int | float, emit_job_start_log: bool = True
) -> None:
while True:
start = time.time()
if emit_job_start_log:
logger.info(f"Running '{job.__name__}', current time: {time.ctime(start)}")
try:
job()
except Exception as e:
logger.exception(f"Failed to run update due to {e}")
sleep_time = delay - (time.time() - start)
if sleep_time > 0:
time.sleep(sleep_time)

View File

@ -211,7 +211,7 @@ CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get(
# fairly large amount of memory in order to increase substantially, since
# each worker loads the embedding models into memory.
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
# Logs every model prompt and output, mostly used for development or exploration purposes
LOG_ALL_MODEL_INTERACTIONS = (
os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true"

View File

@ -8,7 +8,6 @@ from typing import IO
from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH
_FILE_AGE_CLEANUP_THRESHOLD_HOURS = 24 * 7 # 1 week
_VALID_FILE_EXTENSIONS = [".txt", ".zip", ".pdf"]
@ -53,13 +52,3 @@ def write_temp_files(
def file_age_in_hours(filepath: str | Path) -> float:
return (time.time() - os.path.getmtime(filepath)) / (60 * 60)
def clean_old_temp_files(
age_threshold_in_hours: float | int = _FILE_AGE_CLEANUP_THRESHOLD_HOURS,
base_path: Path | str = FILE_CONNECTOR_TMP_STORAGE_PATH,
) -> None:
os.makedirs(base_path, exist_ok=True)
for file in os.listdir(base_path):
if file_age_in_hours(file) > age_threshold_in_hours:
os.remove(Path(base_path) / file)

View File

@ -52,6 +52,14 @@ class DeletionStatus(str, PyEnum):
FAILED = "failed"
# Consistent with Celery task statuses
class TaskStatus(str, PyEnum):
PENDING = "PENDING"
STARTED = "STARTED"
SUCCESS = "SUCCESS"
FAILURE = "FAILURE"
class Base(DeclarativeBase):
pass
@ -566,3 +574,22 @@ class SlackBotConfig(Base):
)
persona: Mapped[Persona | None] = relationship("Persona")
class TaskQueueState(Base):
# Currently refers to Celery Tasks
__tablename__ = "task_queue_jobs"
id: Mapped[int] = mapped_column(primary_key=True)
# Celery task id
task_id: Mapped[str] = mapped_column(String)
# For any job type, this would be the same
task_name: Mapped[str] = mapped_column(String)
# Note that if the task dies, this won't necessarily be marked FAILED correctly
status: Mapped[TaskStatus] = mapped_column(Enum(TaskStatus))
start_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True)
)
register_time: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)

View File

@ -0,0 +1,85 @@
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.db.engine import get_db_current_time
from danswer.db.models import TaskQueueState
from danswer.db.models import TaskStatus
def get_latest_task(
task_name: str,
db_session: Session,
) -> TaskQueueState | None:
stmt = (
select(TaskQueueState)
.where(TaskQueueState.task_name == task_name)
.order_by(desc(TaskQueueState.id))
.limit(1)
)
result = db_session.execute(stmt)
latest_task = result.scalars().first()
return latest_task
def register_task(
task_id: str,
task_name: str,
db_session: Session,
) -> TaskQueueState:
new_task = TaskQueueState(
task_id=task_id, task_name=task_name, status=TaskStatus.PENDING
)
db_session.add(new_task)
db_session.commit()
return new_task
def mark_task_start(
task_name: str,
db_session: Session,
) -> None:
task = get_latest_task(task_name, db_session)
if not task:
raise ValueError(f"No task found with name {task_name}")
task.start_time = func.now() # type: ignore
db_session.commit()
def mark_task_finished(
task_name: str,
db_session: Session,
success: bool = True,
) -> None:
latest_task = get_latest_task(task_name, db_session)
if latest_task is None:
raise ValueError(f"tasks for {task_name} do not exist")
latest_task.status = TaskStatus.SUCCESS if success else TaskStatus.FAILURE
db_session.commit()
def check_live_task_not_timed_out(
task: TaskQueueState,
db_session: Session,
timeout: int = JOB_TIMEOUT,
) -> bool:
# We only care for live tasks to not create new periodic tasks
if task.status in [TaskStatus.SUCCESS, TaskStatus.FAILURE]:
return False
current_db_time = get_db_current_time(db_session=db_session)
last_update_time = task.register_time
if task.start_time:
last_update_time = max(task.register_time, task.start_time)
time_elapsed = current_db_time - last_update_time
return time_elapsed.total_seconds() < timeout

View File

@ -1,84 +0,0 @@
from typing import cast
from sqlalchemy.orm import Session
from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import UpdateRequest
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.document_set import fetch_documents_for_document_set
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
logger = setup_logger()
_SYNC_BATCH_SIZE = 1000
def _sync_document_batch(
document_ids: list[str], document_index: DocumentIndex
) -> None:
logger.debug(f"Syncing document sets for: {document_ids}")
# begin a transaction, release lock at the end
with Session(get_sqlalchemy_engine()) as db_session:
# acquires a lock on the documents so that no other process can modify them
prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
# get current state of document sets for these documents
document_set_map = {
document_id: document_sets
for document_id, document_sets in fetch_document_sets_for_documents(
document_ids=document_ids, db_session=db_session
)
}
# update Vespa
document_index.update(
update_requests=[
UpdateRequest(
document_ids=[document_id],
document_sets=set(document_set_map.get(document_id, [])),
)
for document_id in document_ids
]
)
def sync_document_set(document_set_id: int) -> None:
document_index = get_default_document_index()
with Session(get_sqlalchemy_engine()) as db_session:
documents_to_update = fetch_documents_for_document_set(
document_set_id=document_set_id,
db_session=db_session,
current_only=False,
)
for document_batch in batch_generator(documents_to_update, _SYNC_BATCH_SIZE):
_sync_document_batch(
document_ids=[document.id for document in document_batch],
document_index=document_index,
)
# if there are no connectors, then delete the document set. Otherwise, just
# mark it as successfully synced.
document_set = cast(
DocumentSet,
get_document_set_by_id(
db_session=db_session, document_set_id=document_set_id
),
) # casting since we "know" a document set with this ID exists
if not document_set.connector_credential_pairs:
delete_document_set(document_set_row=document_set, db_session=db_session)
logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(
document_set_id=document_set_id, db_session=db_session
)
logger.info(f"Document set sync for '{document_set_id}' complete!")

View File

@ -14,11 +14,8 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery import cleanup_connector_credential_pair_task
from danswer.background.celery.deletion_utils import get_deletion_status
from danswer.background.connector_deletion import (
get_cleanup_task_id,
)
from danswer.background.celery.celery_utils import get_deletion_status
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
@ -536,6 +533,8 @@ def create_deletion_attempt_for_connector_id(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
from danswer.background.celery.celery import cleanup_connector_credential_pair_task
connector_id = connector_credential_pair_identifier.connector_id
credential_id = connector_credential_pair_identifier.credential_id
@ -559,7 +558,7 @@ def create_deletion_attempt_for_connector_id(
"no ongoing / planned indexing attempts.",
)
task_id = get_cleanup_task_id(
task_id = name_cc_cleanup_task(
connector_id=connector_id, credential_id=credential_id
)
cleanup_connector_credential_pair_task.apply_async(

View File

@ -1,4 +1,5 @@
import logging
import os
from collections.abc import MutableMapping
from typing import Any
@ -52,7 +53,9 @@ class _IndexAttemptLoggingAdapter(logging.LoggerAdapter):
def setup_logger(
name: str = __name__, log_level: int = get_log_level_from_str()
name: str = __name__,
log_level: int = get_log_level_from_str(),
logfile_name: str | None = None,
) -> logging.LoggerAdapter:
logger = logging.getLogger(name)
@ -73,4 +76,12 @@ def setup_logger(
logger.addHandler(handler)
if logfile_name:
is_containerized = os.path.exists("/.dockerenv")
file_name_template = (
"/var/log/{name}.log" if is_containerized else "./log/{name}.log"
)
file_handler = logging.FileHandler(file_name_template.format(name=logfile_name))
logger.addHandler(file_handler)
return _IndexAttemptLoggingAdapter(logger)

View File

@ -0,0 +1,52 @@
import subprocess
import threading
def monitor_process(process_name: str, process: subprocess.Popen) -> None:
assert process.stdout is not None
while True:
output = process.stdout.readline()
if output:
print(f"{process_name}: {output.strip()}")
if process.poll() is not None:
break
def run_celery() -> None:
cmd_worker = [
"celery",
"-A",
"danswer.background.celery",
"worker",
"--loglevel=INFO",
"--concurrency=1",
]
cmd_beat = ["celery", "-A", "danswer.background.celery", "beat", "--loglevel=INFO"]
# Redirect stderr to stdout for both processes
worker_process = subprocess.Popen(
cmd_worker, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
beat_process = subprocess.Popen(
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
)
# Monitor outputs using threads
worker_thread = threading.Thread(
target=monitor_process, args=("WORKER", worker_process)
)
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
worker_thread.start()
beat_thread.start()
# Wait for threads to finish
worker_thread.join()
beat_thread.join()
if __name__ == "__main__":
run_celery()

View File

@ -10,28 +10,23 @@ stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true
[program:celery]
command=celery -A danswer.background.celery worker --loglevel=INFO
stdout_logfile=/var/log/celery.log
# Background jobs that must be run async due to long time to completion
[program:celery_worker]
command=celery -A danswer.background.celery worker --loglevel=INFO --logfile=/var/log/celery_worker.log
stdout_logfile=/var/log/celery_worker_supervisor.log
stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true
[program:file_deletion]
command=python danswer/background/file_deletion.py
stdout_logfile=/var/log/file_deletion.log
# Job scheduler for periodic tasks
[program:celery_beat]
command=celery -A danswer.background.celery beat --loglevel=INFO --logfile=/var/log/celery_beat.log
stdout_logfile=/var/log/celery_beat_supervisor.log
stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true
[program:document_set_sync]
command=python danswer/background/document_set_sync_script.py
stdout_logfile=/var/log/document_set_sync.log
stdout_logfile_maxbytes=52428800
redirect_stderr=true
autorestart=true
# Listens for slack messages and responds with answers
# Listens for Slack messages and responds with answers
# for all channels that the DanswerBot has been added to.
# If not setup, this will just fail 5 times and then stop.
# More details on setup here: https://docs.danswer.dev/slack_bot_setup
@ -44,9 +39,9 @@ autorestart=true
startretries=5
startsecs=60
# pushes all logs from the above programs to stdout
# Pushes all logs from the above programs to stdout
[program:log-redirect-handler]
command=tail -qF /var/log/update.log /var/log/celery.log /var/log/file_deletion.log /var/log/slack_bot_listener.log /var/log/document_set_sync.log
command=tail -qF /var/log/update.log /var/log/celery_worker.log /var/log/celery_worker_supervisor.log /var/log/celery_beat.log /var/log/celery_beat_supervisor.log /var/log/slack_bot_listener.log
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes=0
redirect_stderr=true