diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py
index 48f0295cd..6b9b5a896 100644
--- a/backend/danswer/background/celery/celery_utils.py
+++ b/backend/danswer/background/celery/celery_utils.py
@@ -6,16 +6,23 @@ from sqlalchemy.orm import Session
 from danswer.background.task_utils import name_cc_cleanup_task
 from danswer.background.task_utils import name_cc_prune_task
 from danswer.background.task_utils import name_document_set_sync_task
+from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
+from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
+from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
+    rate_limit_builder,
+)
 from danswer.connectors.interfaces import BaseConnector
 from danswer.connectors.interfaces import IdConnector
 from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.interfaces import PollConnector
+from danswer.connectors.models import Document
 from danswer.db.engine import get_db_current_time
 from danswer.db.models import Connector
 from danswer.db.models import Credential
 from danswer.db.models import DocumentSet
-from danswer.db.tasks import check_live_task_not_timed_out
+from danswer.db.tasks import check_task_is_live_and_not_timed_out
 from danswer.db.tasks import get_latest_task
+from danswer.db.tasks import get_latest_task_by_type
 from danswer.server.documents.models import DeletionAttemptSnapshot
 from danswer.utils.logger import setup_logger
 
@@ -47,7 +54,7 @@ def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
     task_name = name_document_set_sync_task(document_set.id)
     latest_sync = get_latest_task(task_name, db_session)
 
-    if latest_sync and check_live_task_not_timed_out(latest_sync, db_session):
+    if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
         logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
         return False
 
@@ -73,7 +80,19 @@ def should_prune_cc_pair(
             return True
         return False
 
-    if check_live_task_not_timed_out(last_pruning_task, db_session):
+    if PREVENT_SIMULTANEOUS_PRUNING:
+        pruning_type_task_name = name_cc_prune_task()
+        last_pruning_type_task = get_latest_task_by_type(
+            pruning_type_task_name, db_session
+        )
+
+        if last_pruning_type_task and check_task_is_live_and_not_timed_out(
+            last_pruning_type_task, db_session
+        ):
+            logger.info("Another Connector is already pruning. Skipping.")
+            return False
+
+    if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
         logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
         return False
 
@@ -84,25 +103,36 @@ def should_prune_cc_pair(
     return time_since_last_pruning.total_seconds() >= connector.prune_freq
 
 
+def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
+    return {doc.id for doc in doc_batch}
+
+
 def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
     """
     If the PruneConnector hasnt been implemented for the given connector,
     just pull all docs using the load_from_state and grab out the IDs
     """
     all_connector_doc_ids: set[str] = set()
+
+    doc_batch_generator = None
     if isinstance(runnable_connector, IdConnector):
         all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
     elif isinstance(runnable_connector, LoadConnector):
         doc_batch_generator = runnable_connector.load_from_state()
-        for doc_batch in doc_batch_generator:
-            all_connector_doc_ids.update(doc.id for doc in doc_batch)
     elif isinstance(runnable_connector, PollConnector):
         start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
         end = datetime.now(timezone.utc).timestamp()
         doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
-        for doc_batch in doc_batch_generator:
-            all_connector_doc_ids.update(doc.id for doc in doc_batch)
     else:
         raise RuntimeError("Pruning job could not find a valid runnable_connector.")
 
+    if doc_batch_generator:
+        doc_batch_processing_func = document_batch_to_ids
+        if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
+            doc_batch_processing_func = rate_limit_builder(
+                max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
+            )(document_batch_to_ids)
+        for doc_batch in doc_batch_generator:
+            all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
+
     return all_connector_doc_ids
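Note on the throttling added above: when MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE is non-zero, document_batch_to_ids is wrapped with rate_limit_builder(max_calls=..., period=60), so each call that turns a fetched batch into IDs counts against a per-minute budget. The real decorator lives in danswer.connectors.cross_connector_utils.rate_limit_wrapper and is not part of this diff; the sketch below is only an assumed illustration of how such a sliding-window decorator factory could behave, not the library's implementation.

import time
from collections.abc import Callable
from functools import wraps
from typing import Any, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def rate_limit_builder(max_calls: int, period: int = 60) -> Callable[[F], F]:
    """Illustrative sketch only: allow at most `max_calls` invocations per `period` seconds."""

    def decorator(func: F) -> F:
        call_times: list[float] = []  # timestamps of recent calls (sliding window)

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            now = time.monotonic()
            # Forget calls that have aged out of the sliding window.
            while call_times and now - call_times[0] >= period:
                call_times.pop(0)
            if len(call_times) >= max_calls:
                # Budget exhausted: wait until the oldest tracked call leaves the window.
                time.sleep(period - (now - call_times[0]))
                call_times.pop(0)
            call_times.append(time.monotonic())
            return func(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    return decorator

With MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE left at its default of 0 the wrapper is never applied, so existing deployments keep the previous unthrottled behavior.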
Skipping.") return False @@ -84,25 +103,36 @@ def should_prune_cc_pair( return time_since_last_pruning.total_seconds() >= connector.prune_freq +def document_batch_to_ids(doc_batch: list[Document]) -> set[str]: + return {doc.id for doc in doc_batch} + + def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]: """ If the PruneConnector hasnt been implemented for the given connector, just pull all docs using the load_from_state and grab out the IDs """ all_connector_doc_ids: set[str] = set() + + doc_batch_generator = None if isinstance(runnable_connector, IdConnector): all_connector_doc_ids = runnable_connector.retrieve_all_source_ids() elif isinstance(runnable_connector, LoadConnector): doc_batch_generator = runnable_connector.load_from_state() - for doc_batch in doc_batch_generator: - all_connector_doc_ids.update(doc.id for doc in doc_batch) elif isinstance(runnable_connector, PollConnector): start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp() end = datetime.now(timezone.utc).timestamp() doc_batch_generator = runnable_connector.poll_source(start=start, end=end) - for doc_batch in doc_batch_generator: - all_connector_doc_ids.update(doc.id for doc in doc_batch) else: raise RuntimeError("Pruning job could not find a valid runnable_connector.") + if doc_batch_generator: + doc_batch_processing_func = document_batch_to_ids + if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE: + doc_batch_processing_func = rate_limit_builder( + max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60 + )(document_batch_to_ids) + for doc_batch in doc_batch_generator: + all_connector_doc_ids.update(doc_batch_processing_func(doc_batch)) + return all_connector_doc_ids diff --git a/backend/danswer/background/task_utils.py b/backend/danswer/background/task_utils.py index 902abdfec..6e1226788 100644 --- a/backend/danswer/background/task_utils.py +++ b/backend/danswer/background/task_utils.py @@ -22,8 +22,13 @@ def name_document_set_sync_task(document_set_id: int) -> str: return f"sync_doc_set_{document_set_id}" -def name_cc_prune_task(connector_id: int, credential_id: int) -> str: - return f"prune_connector_credential_pair_{connector_id}_{credential_id}" +def name_cc_prune_task( + connector_id: int | None = None, credential_id: int | None = None +) -> str: + task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}" + if not connector_id or not credential_id: + task_name = "prune_connector_credential_pair" + return task_name T = TypeVar("T", bound=Callable) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 92995ac44..c40c661b4 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -214,6 +214,15 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = ( DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day +PREVENT_SIMULTANEOUS_PRUNING = ( + os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true" +) + +# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation. 
diff --git a/backend/danswer/db/tasks.py b/backend/danswer/db/tasks.py
index a12f98861..23a7edc98 100644
--- a/backend/danswer/db/tasks.py
+++ b/backend/danswer/db/tasks.py
@@ -26,6 +26,23 @@ def get_latest_task(
     return latest_task
 
 
+def get_latest_task_by_type(
+    task_name: str,
+    db_session: Session,
+) -> TaskQueueState | None:
+    stmt = (
+        select(TaskQueueState)
+        .where(TaskQueueState.task_name.like(f"%{task_name}%"))
+        .order_by(desc(TaskQueueState.id))
+        .limit(1)
+    )
+
+    result = db_session.execute(stmt)
+    latest_task = result.scalars().first()
+
+    return latest_task
+
+
 def register_task(
     task_id: str,
     task_name: str,
@@ -66,7 +83,7 @@ def mark_task_finished(
     db_session.commit()
 
 
-def check_live_task_not_timed_out(
+def check_task_is_live_and_not_timed_out(
     task: TaskQueueState,
     db_session: Session,
     timeout: int = JOB_TIMEOUT,
diff --git a/backend/ee/danswer/background/celery_utils.py b/backend/ee/danswer/background/celery_utils.py
index 9ab436596..0134f6642 100644
--- a/backend/ee/danswer/background/celery_utils.py
+++ b/backend/ee/danswer/background/celery_utils.py
@@ -1,7 +1,7 @@
 from sqlalchemy.orm import Session
 
 from danswer.db.models import UserGroup
-from danswer.db.tasks import check_live_task_not_timed_out
+from danswer.db.tasks import check_task_is_live_and_not_timed_out
 from danswer.db.tasks import get_latest_task
 from danswer.utils.logger import setup_logger
 from ee.danswer.background.task_name_builders import name_chat_ttl_task
@@ -16,7 +16,7 @@ def should_sync_user_groups(user_group: UserGroup, db_session: Session) -> bool:
     task_name = name_user_group_sync_task(user_group.id)
     latest_sync = get_latest_task(task_name, db_session)
 
-    if latest_sync and check_live_task_not_timed_out(latest_sync, db_session):
+    if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
         logger.info("TTL check is already being performed. Skipping.")
         return False
     return True
@@ -34,7 +34,7 @@ def should_perform_chat_ttl_check(
     if not latest_task:
         return True
 
-    if latest_task and check_live_task_not_timed_out(latest_task, db_session):
+    if latest_task and check_task_is_live_and_not_timed_out(latest_task, db_session):
         logger.info("TTL check is already being performed. Skipping.")
         return False
     return True
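The PREVENT_SIMULTANEOUS_PRUNING guard relies on the interaction between the reworked name_cc_prune_task and the LIKE filter in get_latest_task_by_type: calling name_cc_prune_task() with no arguments yields the generic prefix, and every connector/credential-specific pruning task name contains that prefix, so the most recent pruning task of any pair is found. A plain-Python illustration of that matching (no database involved; like_contains is only a stand-in for the SQL LIKE '%...%' filter):

def name_cc_prune_task(
    connector_id: int | None = None, credential_id: int | None = None
) -> str:
    # Same logic as the updated task_utils.name_cc_prune_task above.
    task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
    if not connector_id or not credential_id:
        task_name = "prune_connector_credential_pair"
    return task_name


def like_contains(stored_task_name: str, pattern: str) -> bool:
    # Stand-in for TaskQueueState.task_name.like(f"%{pattern}%").
    return pattern in stored_task_name


type_level_name = name_cc_prune_task()  # "prune_connector_credential_pair"
specific_name = name_cc_prune_task(connector_id=7, credential_id=3)
# -> "prune_connector_credential_pair_7_3"

# Any in-flight pruning task, regardless of which pair started it,
# matches the type-level lookup used by the PREVENT_SIMULTANEOUS_PRUNING guard.
assert like_contains(specific_name, type_level_name)
assert like_contains(type_level_name, type_level_name)

Because the filter is a substring match, any future task whose name happened to contain this prefix would also match; the current task names avoid that collision.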