Disallowed simultaneous pruning jobs (#1704)

* Added TTL to EE Celery tasks

* fixed alembic files

* fixed frontend build issue and reworked file deletion

* FileD

* revert change

* reworked delete chatmessage

* added orphan cleanup

* ensured syntax

* Disallowed simultaneous pruning jobs

* added rate limiting and env vars

* i hope this is how you use decorators

* nonsense

* cleaned up names, added config

* renamed other utils

* Update celery_utils.py

* reverted changes
This commit is contained in:
hagen-danswer 2024-06-28 16:26:00 -07:00 committed by GitHub
parent 3fe5313b02
commit 60dd77393d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 74 additions and 13 deletions

View File

@ -6,16 +6,23 @@ from sqlalchemy.orm import Session
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from danswer.connectors.interfaces import BaseConnector
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.engine import get_db_current_time
from danswer.db.models import Connector
from danswer.db.models import Credential
from danswer.db.models import DocumentSet
from danswer.db.tasks import check_live_task_not_timed_out
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import get_latest_task_by_type
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
@ -47,7 +54,7 @@ def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
task_name = name_document_set_sync_task(document_set.id)
latest_sync = get_latest_task(task_name, db_session)
if latest_sync and check_live_task_not_timed_out(latest_sync, db_session):
if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
return False
@ -73,7 +80,19 @@ def should_prune_cc_pair(
return True
return False
if check_live_task_not_timed_out(last_pruning_task, db_session):
if PREVENT_SIMULTANEOUS_PRUNING:
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
)
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
logger.info("Another Connector is already pruning. Skipping.")
return False
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
return False
@ -84,25 +103,36 @@ def should_prune_cc_pair(
return time_since_last_pruning.total_seconds() >= connector.prune_freq
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
return {doc.id for doc in doc_batch}
def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
"""
If the PruneConnector hasnt been implemented for the given connector, just pull
all docs using the load_from_state and grab out the IDs
"""
all_connector_doc_ids: set[str] = set()
doc_batch_generator = None
if isinstance(runnable_connector, IdConnector):
all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
elif isinstance(runnable_connector, LoadConnector):
doc_batch_generator = runnable_connector.load_from_state()
for doc_batch in doc_batch_generator:
all_connector_doc_ids.update(doc.id for doc in doc_batch)
elif isinstance(runnable_connector, PollConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
for doc_batch in doc_batch_generator:
all_connector_doc_ids.update(doc.id for doc in doc_batch)
else:
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
if doc_batch_generator:
doc_batch_processing_func = document_batch_to_ids
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
doc_batch_processing_func = rate_limit_builder(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
return all_connector_doc_ids

View File

@ -22,8 +22,13 @@ def name_document_set_sync_task(document_set_id: int) -> str:
return f"sync_doc_set_{document_set_id}"
def name_cc_prune_task(connector_id: int, credential_id: int) -> str:
return f"prune_connector_credential_pair_{connector_id}_{credential_id}"
def name_cc_prune_task(
connector_id: int | None = None, credential_id: int | None = None
) -> str:
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
if not connector_id or not credential_id:
task_name = "prune_connector_credential_pair"
return task_name
T = TypeVar("T", bound=Callable)

View File

@ -214,6 +214,15 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day
PREVENT_SIMULTANEOUS_PRUNING = (
os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
)
# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
)
#####
# Indexing Configs

View File

@ -26,6 +26,23 @@ def get_latest_task(
return latest_task
def get_latest_task_by_type(
task_name: str,
db_session: Session,
) -> TaskQueueState | None:
stmt = (
select(TaskQueueState)
.where(TaskQueueState.task_name.like(f"%{task_name}%"))
.order_by(desc(TaskQueueState.id))
.limit(1)
)
result = db_session.execute(stmt)
latest_task = result.scalars().first()
return latest_task
def register_task(
task_id: str,
task_name: str,
@ -66,7 +83,7 @@ def mark_task_finished(
db_session.commit()
def check_live_task_not_timed_out(
def check_task_is_live_and_not_timed_out(
task: TaskQueueState,
db_session: Session,
timeout: int = JOB_TIMEOUT,

View File

@ -1,7 +1,7 @@
from sqlalchemy.orm import Session
from danswer.db.models import UserGroup
from danswer.db.tasks import check_live_task_not_timed_out
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.utils.logger import setup_logger
from ee.danswer.background.task_name_builders import name_chat_ttl_task
@ -16,7 +16,7 @@ def should_sync_user_groups(user_group: UserGroup, db_session: Session) -> bool:
task_name = name_user_group_sync_task(user_group.id)
latest_sync = get_latest_task(task_name, db_session)
if latest_sync and check_live_task_not_timed_out(latest_sync, db_session):
if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
logger.info("TTL check is already being performed. Skipping.")
return False
return True
@ -34,7 +34,7 @@ def should_perform_chat_ttl_check(
if not latest_task:
return True
if latest_task and check_live_task_not_timed_out(latest_task, db_session):
if latest_task and check_task_is_live_and_not_timed_out(latest_task, db_session):
logger.info("TTL check is already being performed. Skipping.")
return False
return True