mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-07 19:38:19 +02:00
Disallowed simultaneous pruning jobs (#1704)
* Added TTL to EE Celery tasks * fixed alembic files * fixed frontend build issue and reworked file deletion * FileD * revert change * reworked delete chatmessage * added orphan cleanup * ensured syntax * Disallowed simultaneous pruning jobs * added rate limiting and env vars * i hope this is how you use decorators * nonsense * cleaned up names, added config * renamed other utils * Update celery_utils.py * reverted changes
This commit is contained in:
parent
3fe5313b02
commit
60dd77393d
@ -6,16 +6,23 @@ from sqlalchemy.orm import Session
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.background.task_utils import name_document_set_sync_task
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from danswer.connectors.interfaces import BaseConnector
|
||||
from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.tasks import check_live_task_not_timed_out
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.db.tasks import get_latest_task_by_type
|
||||
from danswer.server.documents.models import DeletionAttemptSnapshot
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@ -47,7 +54,7 @@ def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
|
||||
task_name = name_document_set_sync_task(document_set.id)
|
||||
latest_sync = get_latest_task(task_name, db_session)
|
||||
|
||||
if latest_sync and check_live_task_not_timed_out(latest_sync, db_session):
|
||||
if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
|
||||
logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
|
||||
return False
|
||||
|
||||
@ -73,7 +80,19 @@ def should_prune_cc_pair(
|
||||
return True
|
||||
return False
|
||||
|
||||
if check_live_task_not_timed_out(last_pruning_task, db_session):
|
||||
if PREVENT_SIMULTANEOUS_PRUNING:
|
||||
pruning_type_task_name = name_cc_prune_task()
|
||||
last_pruning_type_task = get_latest_task_by_type(
|
||||
pruning_type_task_name, db_session
|
||||
)
|
||||
|
||||
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
|
||||
last_pruning_type_task, db_session
|
||||
):
|
||||
logger.info("Another Connector is already pruning. Skipping.")
|
||||
return False
|
||||
|
||||
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
|
||||
logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
|
||||
return False
|
||||
|
||||
@ -84,25 +103,36 @@ def should_prune_cc_pair(
|
||||
return time_since_last_pruning.total_seconds() >= connector.prune_freq
|
||||
|
||||
|
||||
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
|
||||
return {doc.id for doc in doc_batch}
|
||||
|
||||
|
||||
def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
|
||||
"""
|
||||
If the PruneConnector hasnt been implemented for the given connector, just pull
|
||||
all docs using the load_from_state and grab out the IDs
|
||||
"""
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
|
||||
doc_batch_generator = None
|
||||
if isinstance(runnable_connector, IdConnector):
|
||||
all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
|
||||
elif isinstance(runnable_connector, LoadConnector):
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
for doc_batch in doc_batch_generator:
|
||||
all_connector_doc_ids.update(doc.id for doc in doc_batch)
|
||||
elif isinstance(runnable_connector, PollConnector):
|
||||
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
end = datetime.now(timezone.utc).timestamp()
|
||||
doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
|
||||
for doc_batch in doc_batch_generator:
|
||||
all_connector_doc_ids.update(doc.id for doc in doc_batch)
|
||||
else:
|
||||
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
|
||||
|
||||
if doc_batch_generator:
|
||||
doc_batch_processing_func = document_batch_to_ids
|
||||
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
|
||||
doc_batch_processing_func = rate_limit_builder(
|
||||
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
|
||||
)(document_batch_to_ids)
|
||||
for doc_batch in doc_batch_generator:
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
|
||||
|
||||
return all_connector_doc_ids
|
||||
|
@ -22,8 +22,13 @@ def name_document_set_sync_task(document_set_id: int) -> str:
|
||||
return f"sync_doc_set_{document_set_id}"
|
||||
|
||||
|
||||
def name_cc_prune_task(connector_id: int, credential_id: int) -> str:
|
||||
return f"prune_connector_credential_pair_{connector_id}_{credential_id}"
|
||||
def name_cc_prune_task(
|
||||
connector_id: int | None = None, credential_id: int | None = None
|
||||
) -> str:
|
||||
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
|
||||
if not connector_id or not credential_id:
|
||||
task_name = "prune_connector_credential_pair"
|
||||
return task_name
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable)
|
||||
|
@ -214,6 +214,15 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
|
||||
|
||||
DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day
|
||||
|
||||
PREVENT_SIMULTANEOUS_PRUNING = (
|
||||
os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
|
||||
)
|
||||
|
||||
# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
|
||||
MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
|
||||
os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
# Indexing Configs
|
||||
|
@ -26,6 +26,23 @@ def get_latest_task(
|
||||
return latest_task
|
||||
|
||||
|
||||
def get_latest_task_by_type(
|
||||
task_name: str,
|
||||
db_session: Session,
|
||||
) -> TaskQueueState | None:
|
||||
stmt = (
|
||||
select(TaskQueueState)
|
||||
.where(TaskQueueState.task_name.like(f"%{task_name}%"))
|
||||
.order_by(desc(TaskQueueState.id))
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
latest_task = result.scalars().first()
|
||||
|
||||
return latest_task
|
||||
|
||||
|
||||
def register_task(
|
||||
task_id: str,
|
||||
task_name: str,
|
||||
@ -66,7 +83,7 @@ def mark_task_finished(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def check_live_task_not_timed_out(
|
||||
def check_task_is_live_and_not_timed_out(
|
||||
task: TaskQueueState,
|
||||
db_session: Session,
|
||||
timeout: int = JOB_TIMEOUT,
|
||||
|
@ -1,7 +1,7 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import UserGroup
|
||||
from danswer.db.tasks import check_live_task_not_timed_out
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.background.task_name_builders import name_chat_ttl_task
|
||||
@ -16,7 +16,7 @@ def should_sync_user_groups(user_group: UserGroup, db_session: Session) -> bool:
|
||||
task_name = name_user_group_sync_task(user_group.id)
|
||||
latest_sync = get_latest_task(task_name, db_session)
|
||||
|
||||
if latest_sync and check_live_task_not_timed_out(latest_sync, db_session):
|
||||
if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
|
||||
logger.info("TTL check is already being performed. Skipping.")
|
||||
return False
|
||||
return True
|
||||
@ -34,7 +34,7 @@ def should_perform_chat_ttl_check(
|
||||
if not latest_task:
|
||||
return True
|
||||
|
||||
if latest_task and check_live_task_not_timed_out(latest_task, db_session):
|
||||
if latest_task and check_task_is_live_and_not_timed_out(latest_task, db_session):
|
||||
logger.info("TTL check is already being performed. Skipping.")
|
||||
return False
|
||||
return True
|
||||
|
Loading…
x
Reference in New Issue
Block a user