Disallowed simultaneous pruning jobs (#1704)

* Added TTL to EE Celery tasks

* fixed alembic files

* fixed frontend build issue and reworked file deletion

* FileD

* revert change

* reworked delete chatmessage

* added orphan cleanup

* ensured syntax

* Disallowed simultaneous pruning jobs

* added rate limiting and env vars

* i hope this is how you use decorators

* nonsense

* cleaned up names, added config

* renamed other utils

* Update celery_utils.py

* reverted changes
Authored by hagen-danswer on 2024-06-28 16:26:00 -07:00, committed by GitHub
parent 3fe5313b02
commit 60dd77393d
5 changed files with 74 additions and 13 deletions
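
In outline, the change skips a new pruning run whenever the most recent pruning task of any connector is still live. A minimal sketch of that guard, with hypothetical stand-ins for the task-queue helpers (the real code queries TaskQueueState rows via the helpers in danswer.db.tasks):

from datetime import datetime, timezone

JOB_TIMEOUT = 60 * 60  # assumed timeout window, in seconds
PREVENT_SIMULTANEOUS_PRUNING = True

def task_is_live_and_not_timed_out(started_at: datetime, finished: bool) -> bool:
    # A prior task blocks a new pruning run only while it is unfinished
    # and younger than the timeout window. started_at must be tz-aware.
    age = (datetime.now(timezone.utc) - started_at).total_seconds()
    return not finished and age < JOB_TIMEOUT

def should_start_pruning(last_start: datetime | None, last_finished: bool) -> bool:
    if (
        PREVENT_SIMULTANEOUS_PRUNING
        and last_start is not None
        and task_is_live_and_not_timed_out(last_start, last_finished)
    ):
        return False  # another connector is already pruning; skip this run
    return True

The second mechanism, a per-minute cap on document retrieval during pruning, is applied in the first file below.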

View File

@@ -6,16 +6,23 @@ from sqlalchemy.orm import Session
 from danswer.background.task_utils import name_cc_cleanup_task
 from danswer.background.task_utils import name_cc_prune_task
 from danswer.background.task_utils import name_document_set_sync_task
+from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
+from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
+from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
+    rate_limit_builder,
+)
 from danswer.connectors.interfaces import BaseConnector
 from danswer.connectors.interfaces import IdConnector
 from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.interfaces import PollConnector
+from danswer.connectors.models import Document
 from danswer.db.engine import get_db_current_time
 from danswer.db.models import Connector
 from danswer.db.models import Credential
 from danswer.db.models import DocumentSet
-from danswer.db.tasks import check_live_task_not_timed_out
+from danswer.db.tasks import check_task_is_live_and_not_timed_out
 from danswer.db.tasks import get_latest_task
+from danswer.db.tasks import get_latest_task_by_type
 from danswer.server.documents.models import DeletionAttemptSnapshot
 from danswer.utils.logger import setup_logger
@@ -47,7 +54,7 @@ def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
     task_name = name_document_set_sync_task(document_set.id)
     latest_sync = get_latest_task(task_name, db_session)
-    if latest_sync and check_live_task_not_timed_out(latest_sync, db_session):
+    if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
         logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
         return False
@@ -73,7 +80,19 @@ def should_prune_cc_pair(
             return True
         return False

-    if check_live_task_not_timed_out(last_pruning_task, db_session):
+    if PREVENT_SIMULTANEOUS_PRUNING:
+        pruning_type_task_name = name_cc_prune_task()
+        last_pruning_type_task = get_latest_task_by_type(
+            pruning_type_task_name, db_session
+        )
+        if last_pruning_type_task and check_task_is_live_and_not_timed_out(
+            last_pruning_type_task, db_session
+        ):
+            logger.info("Another Connector is already pruning. Skipping.")
+            return False
+
+    if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
         logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
         return False
@@ -84,25 +103,36 @@ def should_prune_cc_pair(
     return time_since_last_pruning.total_seconds() >= connector.prune_freq


+def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
+    return {doc.id for doc in doc_batch}
+
+
 def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
     """
     If the PruneConnector hasnt been implemented for the given connector, just pull
     all docs using the load_from_state and grab out the IDs
     """
     all_connector_doc_ids: set[str] = set()
+
+    doc_batch_generator = None
     if isinstance(runnable_connector, IdConnector):
         all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
     elif isinstance(runnable_connector, LoadConnector):
         doc_batch_generator = runnable_connector.load_from_state()
-        for doc_batch in doc_batch_generator:
-            all_connector_doc_ids.update(doc.id for doc in doc_batch)
     elif isinstance(runnable_connector, PollConnector):
         start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
         end = datetime.now(timezone.utc).timestamp()
         doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
-        for doc_batch in doc_batch_generator:
-            all_connector_doc_ids.update(doc.id for doc in doc_batch)
     else:
         raise RuntimeError("Pruning job could not find a valid runnable_connector.")
+
+    if doc_batch_generator:
+        doc_batch_processing_func = document_batch_to_ids
+        if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
+            doc_batch_processing_func = rate_limit_builder(
+                max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
+            )(document_batch_to_ids)
+        for doc_batch in doc_batch_generator:
+            all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
+
     return all_connector_doc_ids
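
rate_limit_builder lives in danswer.connectors.cross_connector_utils.rate_limit_wrapper and its implementation is not part of this diff. A minimal sketch of what such a decorator builder could look like, assuming a simple sliding-window strategy (the real module may differ):

import time
from collections.abc import Callable
from functools import wraps
from typing import Any

def rate_limit_builder(max_calls: int, period: float) -> Callable:
    """Sliding-window rate limiter sketch; the real rate_limit_wrapper may differ."""

    def decorator(func: Callable) -> Callable:
        call_times: list[float] = []

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            now = time.monotonic()
            # Forget calls that have aged out of the window.
            while call_times and now - call_times[0] > period:
                call_times.pop(0)
            if len(call_times) >= max_calls:
                # Wait until the oldest call in the window expires.
                time.sleep(period - (now - call_times[0]))
            call_times.append(time.monotonic())
            return func(*args, **kwargs)

        return wrapper

    return decorator

Applied as in the diff, rate_limit_builder(max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60)(document_batch_to_ids) caps ID extraction at that many batches per minute.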

View File

@@ -22,8 +22,13 @@ def name_document_set_sync_task(document_set_id: int) -> str:
     return f"sync_doc_set_{document_set_id}"


-def name_cc_prune_task(connector_id: int, credential_id: int) -> str:
-    return f"prune_connector_credential_pair_{connector_id}_{credential_id}"
+def name_cc_prune_task(
+    connector_id: int | None = None, credential_id: int | None = None
+) -> str:
+    task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
+    if not connector_id or not credential_id:
+        task_name = "prune_connector_credential_pair"
+    return task_name


 T = TypeVar("T", bound=Callable)
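
With the reworked signature, omitting the IDs yields the generic, type-level name that the simultaneous-pruning guard searches for, while passing both IDs preserves the old per-pair name:

name_cc_prune_task(connector_id=1, credential_id=2)
# -> "prune_connector_credential_pair_1_2"

name_cc_prune_task()
# -> "prune_connector_credential_pair"

Note that because the fallback checks `not connector_id`, a falsy ID such as 0 also maps to the generic name.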

View File

@@ -214,6 +214,15 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
 DEFAULT_PRUNING_FREQ = 60 * 60 * 24  # Once a day

+PREVENT_SIMULTANEOUS_PRUNING = (
+    os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
+)
+
+# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
+MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
+    os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
+)
+

 #####
 # Indexing Configs
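
Both knobs are plain environment variables read at import time. A quick demonstration of the parsing semantics, using the exact expressions from the diff:

import os

# Example values; any casing of "true" enables the guard, and a retrieval
# rate of 0 (the default) disables rate limiting entirely.
os.environ["PREVENT_SIMULTANEOUS_PRUNING"] = "True"
os.environ["MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE"] = "120"

PREVENT_SIMULTANEOUS_PRUNING = (
    os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
)
MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
    os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
)

assert PREVENT_SIMULTANEOUS_PRUNING is True
assert MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE == 120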

View File

@@ -26,6 +26,23 @@ def get_latest_task(
     return latest_task


+def get_latest_task_by_type(
+    task_name: str,
+    db_session: Session,
+) -> TaskQueueState | None:
+    stmt = (
+        select(TaskQueueState)
+        .where(TaskQueueState.task_name.like(f"%{task_name}%"))
+        .order_by(desc(TaskQueueState.id))
+        .limit(1)
+    )
+
+    result = db_session.execute(stmt)
+
+    latest_task = result.scalars().first()
+    return latest_task
+
+
 def register_task(
     task_id: str,
     task_name: str,
@@ -66,7 +83,7 @@ def mark_task_finished(
     db_session.commit()


-def check_live_task_not_timed_out(
+def check_task_is_live_and_not_timed_out(
     task: TaskQueueState,
     db_session: Session,
     timeout: int = JOB_TIMEOUT,
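
Unlike get_latest_task, which matches a task name exactly, get_latest_task_by_type filters with SQL LIKE '%<name>%', so the generic name returned by name_cc_prune_task() matches every per-pair pruning task. The substring semantics, illustrated in plain Python:

pattern = "prune_connector_credential_pair"  # generic type-level name

task_names = [
    "prune_connector_credential_pair_1_2",
    "prune_connector_credential_pair_7_9",
    "sync_doc_set_3",
]

# LIKE '%pattern%' is a substring match: both per-pair names qualify,
# the document-set sync task does not.
assert [n for n in task_names if pattern in n] == task_names[:2]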

View File

@@ -1,7 +1,7 @@
 from sqlalchemy.orm import Session

 from danswer.db.models import UserGroup
-from danswer.db.tasks import check_live_task_not_timed_out
+from danswer.db.tasks import check_task_is_live_and_not_timed_out
 from danswer.db.tasks import get_latest_task
 from danswer.utils.logger import setup_logger
 from ee.danswer.background.task_name_builders import name_chat_ttl_task
@@ -16,7 +16,7 @@ def should_sync_user_groups(user_group: UserGroup, db_session: Session) -> bool:
     task_name = name_user_group_sync_task(user_group.id)
     latest_sync = get_latest_task(task_name, db_session)
-    if latest_sync and check_live_task_not_timed_out(latest_sync, db_session):
+    if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
         logger.info("TTL check is already being performed. Skipping.")
         return False
     return True
@@ -34,7 +34,7 @@ def should_perform_chat_ttl_check(
     if not latest_task:
         return True

-    if latest_task and check_live_task_not_timed_out(latest_task, db_session):
+    if latest_task and check_task_is_live_and_not_timed_out(latest_task, db_session):
         logger.info("TTL check is already being performed. Skipping.")
         return False
     return True