https://github.com/danswer-ai/danswer
Disallowed simultaneous pruning jobs (#1704)
* Added TTL to EE Celery tasks
* fixed alembic files
* fixed frontend build issue and reworked file deletion
* FileD
* revert change
* reworked delete chatmessage
* added orphan cleanup
* ensured syntax
* Disallowed simultaneous pruning jobs
* added rate limiting and env vars
* i hope this is how you use decorators
* nonsense
* cleaned up names, added config
* renamed other utils
* Update celery_utils.py
* reverted changes
backend/danswer/background/celery/celery_utils.py

@@ -6,16 +6,23 @@ from sqlalchemy.orm import Session
 from danswer.background.task_utils import name_cc_cleanup_task
 from danswer.background.task_utils import name_cc_prune_task
 from danswer.background.task_utils import name_document_set_sync_task
+from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
+from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
+from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
+    rate_limit_builder,
+)
 from danswer.connectors.interfaces import BaseConnector
 from danswer.connectors.interfaces import IdConnector
 from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.interfaces import PollConnector
+from danswer.connectors.models import Document
 from danswer.db.engine import get_db_current_time
 from danswer.db.models import Connector
 from danswer.db.models import Credential
 from danswer.db.models import DocumentSet
-from danswer.db.tasks import check_live_task_not_timed_out
+from danswer.db.tasks import check_task_is_live_and_not_timed_out
 from danswer.db.tasks import get_latest_task
+from danswer.db.tasks import get_latest_task_by_type
 from danswer.server.documents.models import DeletionAttemptSnapshot
 from danswer.utils.logger import setup_logger

@@ -47,7 +54,7 @@ def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
     task_name = name_document_set_sync_task(document_set.id)
     latest_sync = get_latest_task(task_name, db_session)

-    if latest_sync and check_live_task_not_timed_out(latest_sync, db_session):
+    if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
         logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
         return False

@@ -73,7 +80,19 @@ def should_prune_cc_pair(
             return True
         return False

-    if check_live_task_not_timed_out(last_pruning_task, db_session):
+    if PREVENT_SIMULTANEOUS_PRUNING:
+        pruning_type_task_name = name_cc_prune_task()
+        last_pruning_type_task = get_latest_task_by_type(
+            pruning_type_task_name, db_session
+        )
+
+        if last_pruning_type_task and check_task_is_live_and_not_timed_out(
+            last_pruning_type_task, db_session
+        ):
+            logger.info("Another Connector is already pruning. Skipping.")
+            return False
+
+    if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
         logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
         return False

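With PREVENT_SIMULTANEOUS_PRUNING enabled, the guard is now two-tiered: a live pruning task for any connector/credential pair blocks this pair from starting, and only then does the existing per-pair liveness check run. A condensed, runnable distillation of that decision order (the booleans stand in for the real db_session queries; should_prune here is illustrative, not the actual API):

def should_prune(
    prevent_simultaneous: bool,
    any_prune_task_live: bool,
    this_pair_task_live: bool,
    due_by_schedule: bool,
) -> bool:
    # Mirrors the order of checks in should_prune_cc_pair after this commit.
    if prevent_simultaneous and any_prune_task_live:
        return False  # another connector is already pruning
    if this_pair_task_live:
        return False  # this pair is already pruning
    return due_by_schedule  # otherwise defer to the prune_freq schedule

assert should_prune(True, True, False, True) is False   # guard blocks the job
assert should_prune(False, True, False, True) is True   # guard disabled: only the per-pair check applies
assert should_prune(True, False, False, True) is True   # no live prune task anywhere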
@@ -84,25 +103,36 @@ def should_prune_cc_pair(
     return time_since_last_pruning.total_seconds() >= connector.prune_freq


+def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
+    return {doc.id for doc in doc_batch}
+
+
 def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
     """
     If the PruneConnector hasn't been implemented for the given connector, just pull
     all docs using the load_from_state and grab out the IDs
     """
     all_connector_doc_ids: set[str] = set()
+
+    doc_batch_generator = None
     if isinstance(runnable_connector, IdConnector):
         all_connector_doc_ids = runnable_connector.retrieve_all_source_ids()
     elif isinstance(runnable_connector, LoadConnector):
         doc_batch_generator = runnable_connector.load_from_state()
-        for doc_batch in doc_batch_generator:
-            all_connector_doc_ids.update(doc.id for doc in doc_batch)
     elif isinstance(runnable_connector, PollConnector):
         start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
         end = datetime.now(timezone.utc).timestamp()
         doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
-        for doc_batch in doc_batch_generator:
-            all_connector_doc_ids.update(doc.id for doc in doc_batch)
     else:
         raise RuntimeError("Pruning job could not find a valid runnable_connector.")

+    if doc_batch_generator:
+        doc_batch_processing_func = document_batch_to_ids
+        if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
+            doc_batch_processing_func = rate_limit_builder(
+                max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
+            )(document_batch_to_ids)
+        for doc_batch in doc_batch_generator:
+            all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
+
     return all_connector_doc_ids
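rate_limit_builder comes from danswer.connectors.cross_connector_utils.rate_limit_wrapper, and its body is not part of this diff. Purely for orientation, a minimal sliding-window decorator factory with the same max_calls/period signature might look like the sketch below; the actual implementation may differ.

import time
from collections.abc import Callable
from functools import wraps
from typing import Any


def rate_limit_builder(max_calls: int, period: float = 60) -> Callable[[Callable], Callable]:
    """Sketch only: allow at most max_calls invocations of func per period seconds."""

    def decorator(func: Callable) -> Callable:
        call_times: list[float] = []

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            now = time.monotonic()
            # Forget calls that have aged out of the sliding window.
            while call_times and now - call_times[0] >= period:
                call_times.pop(0)
            if len(call_times) >= max_calls:
                # Window full: wait until the oldest recorded call expires.
                time.sleep(max(0.0, period - (now - call_times[0])))
                call_times.pop(0)
            call_times.append(time.monotonic())
            return func(*args, **kwargs)

        return wrapper

    return decorator

Since the wrapped unit is document_batch_to_ids, the cap counts batch-processing calls, which paces how quickly the pruning job drains doc_batch_generator.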
backend/danswer/background/task_utils.py

@@ -22,8 +22,13 @@ def name_document_set_sync_task(document_set_id: int) -> str:
     return f"sync_doc_set_{document_set_id}"


-def name_cc_prune_task(connector_id: int, credential_id: int) -> str:
-    return f"prune_connector_credential_pair_{connector_id}_{credential_id}"
+def name_cc_prune_task(
+    connector_id: int | None = None, credential_id: int | None = None
+) -> str:
+    task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
+    if not connector_id or not credential_id:
+        task_name = "prune_connector_credential_pair"
+    return task_name


 T = TypeVar("T", bound=Callable)
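Following the new body directly, the name builder has two modes:

# Specific: both IDs supplied -> per-pair task name.
assert name_cc_prune_task(connector_id=7, credential_id=3) == (
    "prune_connector_credential_pair_7_3"
)
# Generic: either ID omitted -> the type-level name shared by all pruning
# tasks, which the simultaneous-pruning guard uses as its match pattern.
assert name_cc_prune_task() == "prune_connector_credential_pair"

Note that the falsiness test (not connector_id) also routes an ID of 0 to the generic name.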
backend/danswer/configs/app_configs.py

@@ -214,6 +214,15 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (

 DEFAULT_PRUNING_FREQ = 60 * 60 * 24  # Once a day

+PREVENT_SIMULTANEOUS_PRUNING = (
+    os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
+)
+
+# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limitation.
+MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
+    os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
+)
+

 #####
 # Indexing Configs
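Both settings are read from the environment at import time. A standalone illustration of the parsing rules (the parse_* helpers are hypothetical, mirroring the expressions above):

def parse_prevent(raw: str | None) -> bool:
    # Only the literal string "true" (any casing) enables the guard.
    return (raw or "").lower() == "true"


def parse_max_rate(raw: str | None) -> int:
    # int() over the raw value; the default of 0 means no rate limit is applied.
    return int(raw or 0)


assert parse_prevent("TRUE") is True
assert parse_prevent("1") is False and parse_prevent(None) is False
assert parse_max_rate(None) == 0     # rate limiting disabled by default
assert parse_max_rate("30") == 30    # at most 30 batch retrievals per minute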
backend/danswer/db/tasks.py

@@ -26,6 +26,23 @@ def get_latest_task(
     return latest_task


+def get_latest_task_by_type(
+    task_name: str,
+    db_session: Session,
+) -> TaskQueueState | None:
+    stmt = (
+        select(TaskQueueState)
+        .where(TaskQueueState.task_name.like(f"%{task_name}%"))
+        .order_by(desc(TaskQueueState.id))
+        .limit(1)
+    )
+
+    result = db_session.execute(stmt)
+    latest_task = result.scalars().first()
+
+    return latest_task
+
+
 def register_task(
     task_id: str,
     task_name: str,
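Unlike get_latest_task, which looks a task name up exactly, the new helper filters with LIKE '%{task_name}%' and takes the row with the highest id, i.e. the most recently registered match. The substring semantics are what let the generic prune name find any per-pair task:

# Mirrors the LIKE filter in get_latest_task_by_type: the generic name produced
# by name_cc_prune_task() is a substring of every per-pair pruning task name.
type_pattern = "prune_connector_credential_pair"      # name_cc_prune_task()
registered = "prune_connector_credential_pair_7_3"    # name_cc_prune_task(7, 3)
assert type_pattern in registered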
@@ -66,7 +83,7 @@ def mark_task_finished(
     db_session.commit()


-def check_live_task_not_timed_out(
+def check_task_is_live_and_not_timed_out(
     task: TaskQueueState,
     db_session: Session,
     timeout: int = JOB_TIMEOUT,
backend/ee/danswer/background/celery/celery_utils.py

@@ -1,7 +1,7 @@
 from sqlalchemy.orm import Session

 from danswer.db.models import UserGroup
-from danswer.db.tasks import check_live_task_not_timed_out
+from danswer.db.tasks import check_task_is_live_and_not_timed_out
 from danswer.db.tasks import get_latest_task
 from danswer.utils.logger import setup_logger
 from ee.danswer.background.task_name_builders import name_chat_ttl_task
@@ -16,7 +16,7 @@ def should_sync_user_groups(user_group: UserGroup, db_session: Session) -> bool:
     task_name = name_user_group_sync_task(user_group.id)
     latest_sync = get_latest_task(task_name, db_session)

-    if latest_sync and check_live_task_not_timed_out(latest_sync, db_session):
+    if latest_sync and check_task_is_live_and_not_timed_out(latest_sync, db_session):
         logger.info("TTL check is already being performed. Skipping.")
         return False
     return True
@@ -34,7 +34,7 @@ def should_perform_chat_ttl_check(
     if not latest_task:
         return True

-    if latest_task and check_live_task_not_timed_out(latest_task, db_session):
+    if latest_task and check_task_is_live_and_not_timed_out(latest_task, db_session):
         logger.info("TTL check is already being performed. Skipping.")
         return False
     return True