Feature/background prune 2 (#2583)

* first cut at redis

* some new helper functions for the db

* ignore kombu tables in alembic migrations (used by celery)

* multiline commands for readability, add vespa_metadata_sync queue to worker

* typo fix

* fix returning tuple fields

* add constants

* fix _get_access_for_document

* docstrings!

* fix double function declaration and typing

* fix type hinting

* add a global redis pool

* Add get_document function

* use task_logger in various celery tasks

* add celeryconfig.py to simplify configuration. Will be used in a subsequent commit

* Add celery redis helper. used in a subsequent PR

* kombu warning was getting spammy since celery is no longer self-managing its queue in Postgres

* add last_modified and last_synced to documents

* fix task naming convention

* use celeryconfig.py

* the big one. adds queues and tasks, updates functions to use the queues with priorities, etc

* change vespa index log line to debug

* mypy fixes

* update alembic migration

* fix fence ordering, rename to "monitor", fix fetch_versioned_implementation call

* mypy

* switch to monotonic time

* fix startup dependencies on redis

* rebase alembic migration

* kombu cleanup - fail silently

* mypy

* add redis_host environment override

* update REDIS_HOST env var in docker-compose.dev.yml

* update the rest of the docker files

* in flight

* harden indexing-status endpoint against db changes happening in the background. Needs further improvement but OK for now.

* allow syncs with no tasks to run, since we create certain objects with no entries that are initially marked as out of date

* add back writing to vespa on indexing

* actually working connector deletion

* update contributing guide

* backporting fixes from background_deletion

* renaming cache to cache_volume

* add redis password to various deployments

* try setting up pr testing for helm

* fix indent

* hopefully this release version actually exists

* fix command line option to --chart-dirs

* fetch-depth 0

* edit values.yaml

* try setting ct working directory

* bypass testing only on change for now

* move files and lint them

* update helm testing

* some issues suggest using --config works

* add vespa repo

* add postgresql repo

* increase timeout

* try amd64 runner

* fix redis password reference

* add comment to helm chart testing workflow

* rename helm testing workflow to disable it

* adding clarifying comments

* address code review

* missed a file

* remove commented warning ... just not needed

* fix imports

* refactor to use update_single

* mypy fixes

* add vespa test

* multiple celery workers

* update logs as well and set prefetch multipliers appropriate to the worker intent

* add db refresh to connector deletion

* add some preliminary locking

* organize tasks into separate files

* celery auto-associates tasks created inside another task with the parent's result, which bloats the result metadata considerably. trail=False prevents this (see the sketch after this list).

* code review fixes

* move monitor_usergroup_taskset to ee, improve logging

* add multi workers to dev_run_background_jobs.py

* update supervisord with some recommended settings for celery

* name celery workers and shorten dev script prefixing

* add configurable sql alchemy engine settings on startup (needed for various intents like API server, different celery workers and tasks, etc)

* fix comments

* autoscale sqlalchemy pool size to celery concurrency (allow override later?); see the sketch after this list

* supervisord needs the percent symbols escaped

* use name as primary check, some minor refactoring and type hinting too.

* stash merge (may not function yet)

* remove dead code

* more cleanup

* remove dead file

* we shouldn't be checking for deletion attempts in the db any more

* print cc_pair_id

* print status on status mismatch again

* add logging when cc_pair isn't present

* don't index any ingestion-type connectors, and don't pause any connectors that aren't active

* add more specific check for deletion completion

* remove flaky mediawiki test site

* move is_pruning

* remove unused code

* remove old function
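Regarding the trail=False bullet above: by default Celery records every task launched from inside another task as a child of the parent task's result, so a generator task that fans out many subtasks drags all of their ids into its own stored result metadata. A minimal sketch of turning that off (illustrative only, not the exact task definition from this PR):

from celery import shared_task

@shared_task(name="example_generator_task", trail=False)
def example_generator_task() -> None:
    # with trail=False, subtasks sent from inside this task are no longer
    # appended to this task's result.children, keeping the stored result small
    ...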
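Regarding the pool-size bullet above: the idea is to size the SQLAlchemy connection pool to the concurrency of the worker that owns the engine, so every celery child has a connection available. A rough sketch under that assumption (the function name and override parameter are hypothetical, not taken from this PR):

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine

def build_engine_for_worker(
    db_url: str, celery_concurrency: int, pool_size_override: int | None = None
) -> Engine:
    # one pooled connection per worker process/thread unless explicitly overridden
    pool_size = pool_size_override or celery_concurrency
    return create_engine(db_url, pool_size=pool_size, max_overflow=0)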

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
rkuo-danswer 2024-10-07 11:16:17 -07:00 committed by GitHub
parent 64909d74f9
commit 3404c7eb1d
23 changed files with 644 additions and 455 deletions

View File

@ -0,0 +1,27 @@
"""add last_pruned to the connector_credential_pair table
Revision ID: ac5eaac849f9
Revises: 46b7a812670f
Create Date: 2024-09-10 15:04:26.437118
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ac5eaac849f9"
down_revision = "46b7a812670f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# last pruned represents the last time the connector was pruned
op.add_column(
"connector_credential_pair",
sa.Column("last_pruned", sa.DateTime(timezone=True), nullable=True),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "last_pruned")

View File

@ -19,6 +19,7 @@ from celery.utils.log import get_task_logger
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
@ -104,6 +105,13 @@ def celery_task_postrun(
r.srem(rcd.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcp = RedisConnectorPruning(cc_pair_id)
r.srem(rcp.taskset_key, task_id)
return
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
@ -236,6 +244,18 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
@ -330,7 +350,11 @@ def on_setup_logging(
class HubPeriodicTask(bootsteps.StartStopStep):
"""Regularly reacquires the primary worker lock outside of the task queue.
Use the task_logger in this class to avoid double logging."""
Use the task_logger in this class to avoid double logging.
This cannot be done inside a regular beat task because it must run on schedule and
a queue of existing work would starve the task from running.
"""
# it's unclear to me whether using the hub's timer or the bootstep timer is better
requires = {"celery.worker.components:Hub"}
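The docstring above explains why this runs as a worker bootstep rather than a beat task. For readers unfamiliar with the pattern, here is a rough sketch of a periodic bootstep along the lines of Celery's documented examples. It is illustrative only and is not the HubPeriodicTask implementation from this PR (which hooks the Hub component); the class name, 60-second interval, and renew() body are assumptions:

from celery import bootsteps

class PeriodicLockRenewal(bootsteps.StartStopStep):
    # run after the worker's Timer component is available
    requires = {"celery.worker.components:Timer"}

    def __init__(self, worker, **kwargs):
        self.tref = None

    def start(self, worker):
        # schedule renew() on the worker's own timer so a backlog of queued
        # tasks cannot starve it, unlike a regular beat task
        self.tref = worker.timer.call_repeatedly(
            60.0, self.renew, (worker,), priority=10
        )

    def stop(self, worker):
        if self.tref:
            self.tref.cancel()
            self.tref = None

    def renew(self, worker):
        pass  # reacquire the primary worker lock here

A step like this would typically be registered on the app with celery_app.steps["worker"].add(PeriodicLockRenewal).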
@ -405,6 +429,7 @@ celery_app.autodiscover_tasks(
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.periodic",
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",
]
)
@ -425,7 +450,7 @@ celery_app.conf.beat_schedule.update(
"task": "check_for_connector_deletion_task",
# don't need to check too often, since we kick off a deletion initially
# during the API call that actually marks the CC pair for deletion
"schedule": timedelta(minutes=1),
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
@ -433,8 +458,8 @@ celery_app.conf.beat_schedule.update(
celery_app.conf.beat_schedule.update(
{
"check-for-prune": {
"task": "check_for_prune_task",
"schedule": timedelta(seconds=5),
"task": "check_for_prune_task_2",
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}

View File

@ -343,6 +343,125 @@ class RedisConnectorDeletion(RedisObjectHelper):
return len(async_results)
class RedisConnectorPruning(RedisObjectHelper):
"""Celery will kick off a long running generator task to crawl the connector and
find any missing docs, which will each then get a new cleanup task. The progress of
those tasks will then be monitored to completion.
Example rough happy path order:
Check connectorpruning_fence_1
Send generator task with id connectorpruning+generator_1_{uuid}
generator runs connector with callbacks that increment connectorpruning_generator_progress_1
generator creates many subtasks with id connectorpruning+sub_1_{uuid}
in taskset connectorpruning_taskset_1
on completion, generator sets connectorpruning_generator_complete_1
celery postrun removes subtasks from taskset
monitor beat task cleans up when taskset reaches 0 items
"""
PREFIX = "connectorpruning"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire pruning process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune task ids
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, id: int) -> None:
"""id: the cc_pair_id of the connector credential pair"""
super().__init__(id)
self.documents_to_prune: set[str] = set()
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
if not cc_pair:
return None
for doc_id in self.documents_to_prune:
current_time = time.monotonic()
if lock and current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc_id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
def is_pruning(self, db_session: Session, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=self._id, db_session=db_session
)
if not cc_pair:
raise ValueError(f"cc_pair_id {self._id} does not exist.")
if redis_client.exists(self.fence_key):
return True
return False
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.
It is priority aware and knows how to count across the multiple redis lists
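The body of this helper is cut off by the hunk above. As a rough sketch of the general idea only (the separator and priority steps below are kombu's documented defaults, assumed here rather than taken from this PR): with the Redis broker, each priority bucket gets its own list, so the length of a queue is the sum of the base list and the priority-suffixed lists.

from redis import Redis

KOMBU_PRIORITY_SEP = "\x06\x16"      # kombu's default priority separator (assumed)
KOMBU_PRIORITY_STEPS = [0, 3, 6, 9]  # kombu's default priority buckets (assumed)

def approximate_celery_queue_length(queue: str, r: Redis) -> int:
    total = 0
    for step in KOMBU_PRIORITY_STEPS:
        # priority 0 uses the plain queue name; other priorities get a suffixed list
        key = queue if step == 0 else f"{queue}{KOMBU_PRIORITY_SEP}{step}"
        total += int(r.llen(key))
    return total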

View File

@ -1,3 +1,4 @@
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any
@ -5,8 +6,6 @@ from typing import Any
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.task_utils import name_cc_prune_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
@ -17,14 +16,8 @@ from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.engine import get_db_current_time
from danswer.db.enums import TaskStatus
from danswer.db.models import Connector
from danswer.db.models import Credential
from danswer.db.models import TaskQueueState
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import get_latest_task_by_type
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
@ -70,72 +63,19 @@ def get_deletion_attempt_snapshot(
)
def skip_cc_pair_pruning_by_task(
pruning_task: TaskQueueState | None, db_session: Session
) -> bool:
"""task should be the latest prune task for this cc_pair"""
if not ALLOW_SIMULTANEOUS_PRUNING:
# if only one prune is allowed at any time, then check to see if any prune
# is active
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
)
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
return True
if pruning_task and check_task_is_live_and_not_timed_out(pruning_task, db_session):
# if the last task is live right now, we shouldn't start a new one
return True
return False
def should_prune_cc_pair(
connector: Connector, credential: Credential, db_session: Session
) -> bool:
if not connector.prune_freq:
return False
pruning_task_name = name_cc_prune_task(
connector_id=connector.id, credential_id=credential.id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
if skip_cc_pair_pruning_by_task(last_pruning_task, db_session):
return False
current_db_time = get_db_current_time(db_session)
if not last_pruning_task:
# If the connector has never been pruned, then compare vs when the connector
# was created
time_since_initialization = current_db_time - connector.time_created
if time_since_initialization.total_seconds() >= connector.prune_freq:
return True
return False
if not last_pruning_task.start_time:
# if the last prune task hasn't started, we shouldn't start a new one
return False
# if the last prune task has a start time, then compare against it to determine
# if we should start
time_since_last_pruning = current_db_time - last_pruning_task.start_time
return time_since_last_pruning.total_seconds() >= connector.prune_freq
def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
return {doc.id for doc in doc_batch}
def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
progress_callback: Callable[[int], None] | None = None,
) -> set[str]:
"""
If the PruneConnector hasn't been implemented for the given connector, just pull
all docs using the load_from_state and grab out the IDs
all docs using the load_from_state and grab out the IDs.
Optionally, a callback can be passed to handle the length of each document batch.
"""
all_connector_doc_ids: set[str] = set()
@ -158,6 +98,8 @@ def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> se
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
if progress_callback:
progress_callback(len(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
return all_connector_doc_ids
@ -177,9 +119,10 @@ def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
def celery_is_worker_primary(worker: Any) -> bool:
"""There are multiple approaches that could be taken, but the way we do it is to
check the hostname set for the celery worker, either in celeryconfig.py or on the
command line."""
"""There are multiple approaches that could be taken to determine if a celery worker
is 'primary', as defined by us. But the way we do it is to check the hostname set
for the celery worker, which can be done either in celeryconfig.py or on the
command line with '--hostname'."""
hostname = worker.hostname
if hostname.startswith("light"):
return False

View File

@ -1,12 +1,12 @@
import redis
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from celery.utils.log import get_task_logger
from redis import Redis
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import ObjectDeletedError
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@ -14,17 +14,10 @@ from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import get_last_attempt
from danswer.db.models import ConnectorCredentialPair
from danswer.db.search_settings import get_current_search_settings
from danswer.redis.redis_pool import get_redis_client
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
@shared_task(
name="check_for_connector_deletion_task",
soft_time_limit=JOB_TIMEOUT,
@ -90,21 +83,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return None
search_settings = get_current_search_settings(db_session)
last_indexing = get_last_attempt(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
search_settings_id=search_settings.id,
db_session=db_session,
)
if last_indexing:
if (
last_indexing.status == IndexingStatus.IN_PROGRESS
or last_indexing.status == IndexingStatus.NOT_STARTED
):
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rcd.taskset_key)

View File

@ -7,18 +7,15 @@ from typing import Any
from celery import shared_task
from celery.contrib.abortable import AbortableTask # type: ignore
from celery.exceptions import TaskRevokedError
from celery.utils.log import get_task_logger
from sqlalchemy import inspect
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_sqlalchemy_engine # type: ignore
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
@shared_task(
name="kombu_message_cleanup_task",

View File

@ -1,61 +1,165 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4
import redis
from celery import shared_task
from celery.utils.log import get_task_logger
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.celery_utils import should_prune_cc_pair
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_cc_prune_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_pool import get_redis_client
@shared_task(
name="check_for_prune_task",
name="check_for_prune_task_2",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task() -> None:
"""Runs periodically to check if any prune tasks should be run and adds them
to the queue"""
def check_for_prune_task_2() -> None:
r = get_redis_client()
with Session(get_sqlalchemy_engine()) as db_session:
all_cc_pairs = get_connector_credential_pairs(db_session)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
for cc_pair in all_cc_pairs:
if should_prune_cc_pair(
connector=cc_pair.connector,
credential=cc_pair.credential,
db_session=db_session,
):
task_logger.info(f"Pruning the {cc_pair.connector.name} connector")
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
prune_documents_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
with Session(get_sqlalchemy_engine()) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
tasks_created = ccpair_pruning_generator_task_creation_helper(
cc_pair, db_session, r, lock_beat
)
if not tasks_created:
continue
task_logger.info(f"Pruning started: cc_pair_id={cc_pair.id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(name="prune_documents_task", soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int) -> None:
def ccpair_pruning_generator_task_creation_helper(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
) -> int | None:
"""Returns an int if pruning is triggered.
The int represents the number of prune tasks generated (in this case, only one,
because the task is a long-running generator task).
Returns None if no pruning is triggered (either because it is not needed or
for other reasons such as simultaneous pruning restrictions).
Checks for scheduling related conditions, then delegates the rest of the checks to
try_creating_prune_generator_task.
"""
lock_beat.reacquire()
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return None
# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
# if never pruned, use the connector time created as the last_pruned time
last_pruned = cc_pair.connector.time_created
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return None
return try_creating_prune_generator_task(cc_pair, db_session, r)
def try_creating_prune_generator_task(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
) -> int | None:
"""Checks for any conditions that should block the pruning generator task from being
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger prunes immediately.
"""
if not ALLOW_SIMULTANEOUS_PRUNING:
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
return None
rcp = RedisConnectorPruning(cc_pair.id)
# skip pruning if already pruning
if r.exists(rcp.fence_key):
return None
# skip pruning if the cc_pair is deleting
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
r.delete(rcp.generator_complete_key)
r.delete(rcp.taskset_key)
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
kwargs=dict(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
),
queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
# set this only after all tasks have been added
r.set(rcp.fence_key, 1)
return 1
@shared_task(name="connector_pruning_generator_task", soft_time_limit=JOB_TIMEOUT)
def connector_pruning_generator_task(connector_id: int, credential_id: int) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
r = get_redis_client()
with Session(get_sqlalchemy_engine()) as db_session:
try:
cc_pair = get_connector_credential_pair(
@ -70,6 +174,12 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
)
return
rcp = RedisConnectorPruning(cc_pair.id)
# Define the callback function
def redis_increment_callback(amount: int) -> None:
r.incrby(rcp.generator_progress_key, amount)
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
@ -78,10 +188,12 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
cc_pair.credential,
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector
runnable_connector, redis_increment_callback
)
# a list of docs in our local index
all_indexed_document_ids = {
doc.id
for doc in get_documents_for_connector_credential_pair(
@ -91,30 +203,37 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
)
}
# generate list of docs to remove (no longer in the source)
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
task_logger.info(
f"Pruning set collected: "
f"cc_pair_id={cc_pair.id} "
f"docs_to_remove={len(doc_ids_to_remove)} "
f"doc_source={cc_pair.connector.source}"
)
if len(doc_ids_to_remove) == 0:
task_logger.info(
f"No docs to prune from {cc_pair.connector.source} connector"
)
return
rcp.documents_to_prune = set(doc_ids_to_remove)
task_logger.info(
f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
delete_connector_credential_pair_batch(
document_ids=doc_ids_to_remove,
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
tasks_generated = rcp.generate_tasks(celery_app, db_session, r, None)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnectorPruning.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
r.set(rcp.generator_complete_key, tasks_generated)
except Exception as e:
task_logger.exception(
f"Failed to run pruning for connector id {connector_id}."
)
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)
r.delete(rcp.fence_key)
raise e

View File

@ -0,0 +1,113 @@
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import task_logger
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
from danswer.db.document import get_document_connector_count
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@shared_task(
name="document_by_cc_pair_cleanup_task",
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=3,
)
def document_by_cc_pair_cleanup_task(
self: Task, document_id: str, connector_id: int, credential_id: int
) -> bool:
"""A lightweight subtask used to clean up document to cc pair relationships.
Created by connector deletion and connector pruning parent tasks."""
"""
To delete a connector / credential pair:
(1) find all documents associated with the connector / credential pair where this
is the only connector / credential pair that has indexed it
(2) delete all documents from document stores
(3) delete all entries from postgres
(4) find all documents associated with connector / credential pair where there
are multiple connector / credential pairs that have indexed it
(5) update document store entries to remove access associated with the
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
task_logger.info(f"document_id={document_id}")
try:
with Session(get_sqlalchemy_engine()) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
count = get_document_connector_count(db_session, document_id)
if count == 1:
# count == 1 means this is the only remaining cc_pair reference to the doc
# delete it from vespa and the db
document_index.delete(doc_ids=[document_id])
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],
)
elif count > 1:
# count > 1 means the document still has cc_pair references
doc = get_document(document_id, db_session)
if not doc:
return False
# the below functions do not include cc_pairs being deleted.
# i.e. they will correctly omit access for the current cc_pair
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
document_index.update_single(document_id, fields=fields)
# there are still other cc_pair references to the doc, so just resync to Vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_synced(document_id, db_session)
else:
pass
db_session.commit()
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:
task_logger.exception("Unexpected exception")
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True

View File

@ -5,20 +5,22 @@ import redis
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.utils.log import get_task_logger
from redis import Redis
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import mark_ccpair_as_pruned
from danswer.db.connector_credential_pair import add_deletion_failure_message
from danswer.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
@ -49,10 +51,6 @@ from danswer.utils.variable_functionality import (
from danswer.utils.variable_functionality import noop_fallback
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
@shared_task(
@ -279,7 +277,7 @@ def monitor_document_set_taskset(
fence_key = key_bytes.decode("utf-8")
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
if document_set_id is None:
task_logger.warning("could not parse document set id from {key}")
task_logger.warning(f"could not parse document set id from {fence_key}")
return
rds = RedisDocumentSet(document_set_id)
@ -326,7 +324,7 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id is None:
task_logger.warning("could not parse document set id from {key}")
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return
rcd = RedisConnectorDeletion(cc_pair_id)
@ -351,6 +349,9 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
task_logger.warning(
f"monitor_connector_deletion_taskset - cc_pair_id not found: cc_pair_id={cc_pair_id}"
)
return
try:
@ -402,20 +403,67 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
add_deletion_failure_message(db_session, cc_pair.id, error_message)
task_logger.exception(
f"Failed to run connector_deletion. "
f"connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
f"cc_pair_id={cc_pair_id} connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Successfully deleted connector_credential_pair with connector_id: '{cc_pair.connector_id}' "
f"and credential_id: '{cc_pair.credential_id}'. "
f"Deleted {initial_count} docs."
f"Successfully deleted cc_pair: "
f"cc_pair_id={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"docs_deleted={initial_count}"
)
r.delete(rcd.taskset_key)
r.delete(rcd.fence_key)
def monitor_ccpair_pruning_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id = RedisConnectorPruning.get_id_from_fence_key(fence_key)
if cc_pair_id is None:
task_logger.warning(
f"monitor_connector_pruning_taskset: could not parse cc_pair_id from {fence_key}"
)
return
rcp = RedisConnectorPruning(cc_pair_id)
fence_value = r.get(rcp.fence_key)
if fence_value is None:
return
generator_value = r.get(rcp.generator_complete_key)
if generator_value is None:
return
try:
initial_count = int(cast(int, generator_value))
except ValueError:
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rcp.taskset_key))
task_logger.info(
f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
mark_ccpair_as_pruned(cc_pair_id, db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
)
r.delete(rcp.taskset_key)
r.delete(rcp.generator_progress_key)
r.delete(rcp.generator_complete_key)
r.delete(rcp.fence_key)
@shared_task(name="monitor_vespa_sync", soft_time_limit=300)
def monitor_vespa_sync() -> None:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
@ -457,6 +505,9 @@ def monitor_vespa_sync() -> None:
)
monitor_usergroup_taskset(key_bytes, r, db_session)
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)

View File

@ -1,211 +0,0 @@
"""
To delete a connector / credential pair:
(1) find all documents associated with the connector / credential pair where this
is the only connector / credential pair that has indexed it
(2) delete all documents from document stores
(3) delete all entries from postgres
(4) find all documents associated with connector / credential pair where there
are multiple connector / credential pairs that have indexed it
(5) update document store entries to remove access associated with the
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.utils.log import get_task_logger
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_document
from danswer.access.access import get_access_for_documents
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
from danswer.db.document import get_document_connector_count
from danswer.db.document import get_document_connector_counts
from danswer.db.document import mark_document_as_synced
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
logger = setup_logger()
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
_DELETION_BATCH_SIZE = 1000
def delete_connector_credential_pair_batch(
document_ids: list[str],
connector_id: int,
credential_id: int,
document_index: DocumentIndex,
) -> None:
"""
Removes a batch of documents ids from a cc-pair. If no other cc-pair uses a document anymore
it gets permanently deleted.
"""
with Session(get_sqlalchemy_engine()) as db_session:
# acquire lock for all documents in this batch so that indexing can't
# override the deletion
with prepare_to_modify_documents(
db_session=db_session, document_ids=document_ids
):
document_connector_counts = get_document_connector_counts(
db_session=db_session, document_ids=document_ids
)
# figure out which docs need to be completely deleted
document_ids_to_delete = [
document_id
for document_id, cnt in document_connector_counts
if cnt == 1
]
logger.debug(f"Deleting documents: {document_ids_to_delete}")
document_index.delete(doc_ids=document_ids_to_delete)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=document_ids_to_delete,
)
# figure out which docs need to be updated
document_ids_to_update = [
document_id for document_id, cnt in document_connector_counts if cnt > 1
]
# maps document id to list of document set names
new_doc_sets_for_documents: dict[str, set[str]] = {
document_id_and_document_set_names_tuple[0]: set(
document_id_and_document_set_names_tuple[1]
)
for document_id_and_document_set_names_tuple in fetch_document_sets_for_documents(
db_session=db_session,
document_ids=document_ids_to_update,
)
}
# determine future ACLs for documents in batch
access_for_documents = get_access_for_documents(
document_ids=document_ids_to_update,
db_session=db_session,
)
# update Vespa
logger.debug(f"Updating documents: {document_ids_to_update}")
update_requests = [
UpdateRequest(
document_ids=[document_id],
access=access,
document_sets=new_doc_sets_for_documents[document_id],
)
for document_id, access in access_for_documents.items()
]
document_index.update(update_requests=update_requests)
# clean up Postgres
delete_documents_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids_to_update,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
db_session.commit()
@shared_task(
name="document_by_cc_pair_cleanup_task",
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=3,
)
def document_by_cc_pair_cleanup_task(
self: Task, document_id: str, connector_id: int, credential_id: int
) -> bool:
task_logger.info(f"document_id={document_id}")
try:
with Session(get_sqlalchemy_engine()) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
count = get_document_connector_count(db_session, document_id)
if count == 1:
# count == 1 means this is the only remaining cc_pair reference to the doc
# delete it from vespa and the db
document_index.delete_single(doc_id=document_id)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],
)
elif count > 1:
# count > 1 means the document still has cc_pair references
doc = get_document(document_id, db_session)
if not doc:
return False
# the below functions do not include cc_pairs being deleted.
# i.e. they will correctly omit access for the current cc_pair
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
document_index.update_single(document_id, fields=fields)
# there are still other cc_pair references to the doc, so just resync to Vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_synced(document_id, db_session)
else:
pass
# update_docs_last_modified__no_commit(
# db_session=db_session,
# document_ids=[document_id],
# )
db_session.commit()
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:
task_logger.exception("Unexpected exception")
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True

View File

@ -187,10 +187,9 @@ class PostgresAdvisoryLocks(Enum):
class DanswerCeleryQueues:
VESPA_DOCSET_SYNC_GENERATOR = "vespa_docset_sync_generator"
VESPA_USERGROUP_SYNC_GENERATOR = "vespa_usergroup_sync_generator"
VESPA_METADATA_SYNC = "vespa_metadata_sync"
CONNECTOR_DELETION = "connector_deletion"
CONNECTOR_PRUNING = "connector_pruning"
class DanswerRedisLocks:
@ -198,7 +197,7 @@ class DanswerRedisLocks:
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
MONITOR_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:monitor_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
class DanswerCeleryPriority(int, Enum):

View File

@ -1,3 +1,5 @@
from datetime import datetime
from datetime import timezone
from typing import cast
from sqlalchemy import and_
@ -268,3 +270,15 @@ def create_initial_default_connector(db_session: Session) -> None:
)
db_session.add(connector)
db_session.commit()
def mark_ccpair_as_pruned(cc_pair_id: int, db_session: Session) -> None:
stmt = select(ConnectorCredentialPair).where(
ConnectorCredentialPair.id == cc_pair_id
)
cc_pair = db_session.scalar(stmt)
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
cc_pair.last_pruned = datetime.now(timezone.utc)
db_session.commit()

View File

@ -414,6 +414,12 @@ class ConnectorCredentialPair(Base):
last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), default=None
)
# last successful prune
last_pruned: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, index=True
)
total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0)
connector: Mapped["Connector"] = relationship(

View File

@ -10,9 +10,11 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.celery_utils import skip_cc_pair_pruning_by_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
)
from danswer.db.connector_credential_pair import add_credential_to_connector
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import remove_credential_from_connector
@ -31,6 +33,7 @@ from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from danswer.db.models import User
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CCPairFullInfo
from danswer.server.documents.models import CCStatusUpdateRequest
from danswer.server.documents.models import CeleryTaskStatus
@ -203,7 +206,7 @@ def get_cc_pair_latest_prune(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> CeleryTaskStatus:
) -> bool:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
@ -216,24 +219,8 @@ def get_cc_pair_latest_prune(
detail="Connection not found for current user's permissions",
)
# look up the last prune task for this connector (if it exists)
pruning_task_name = name_cc_prune_task(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
if not last_pruning_task:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="No pruning task found.",
)
return CeleryTaskStatus(
id=last_pruning_task.task_id,
name=last_pruning_task.task_name,
status=last_pruning_task.status,
start_time=last_pruning_task.start_time,
register_time=last_pruning_task.register_time,
)
rcp = RedisConnectorPruning(cc_pair.id)
return rcp.is_pruning(db_session, get_redis_client())
@router.post("/admin/cc-pair/{cc_pair_id}/prune")
@ -242,8 +229,7 @@ def prune_cc_pair(
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[list[int]]:
# avoiding circular refs
from danswer.background.celery.tasks.pruning.tasks import prune_documents_task
"""Triggers pruning on a particular cc_pair immediately"""
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
@ -257,26 +243,26 @@ def prune_cc_pair(
detail="Connection not found for current user's permissions",
)
pruning_task_name = name_cc_prune_task(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
)
last_pruning_task = get_latest_task(pruning_task_name, db_session)
if skip_cc_pair_pruning_by_task(
last_pruning_task,
db_session=db_session,
):
r = get_redis_client()
rcp = RedisConnectorPruning(cc_pair_id)
if rcp.is_pruning(db_session, r):
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Pruning task already in progress.",
)
logger.info(f"Pruning the {cc_pair.connector.name} connector.")
prune_documents_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
logger.info(
f"Pruning cc_pair: cc_pair_id={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
tasks_created = try_creating_prune_generator_task(cc_pair, db_session, r)
if not tasks_created:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Pruning task creation failed.",
)
return StatusResponse(
success=True,
@ -353,14 +339,6 @@ def sync_cc_pair(
status_code=HTTPStatus.CONFLICT,
detail="Sync task already in progress.",
)
if skip_cc_pair_pruning_by_task(
last_sync_task,
db_session=db_session,
):
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Sync task already in progress.",
)
logger.info(f"Syncing the {cc_pair.connector.name} connector.")
sync_external_doc_permissions_task.apply_async(

View File

@ -15,10 +15,10 @@ logger = setup_logger()
def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None:
"""This function is likely to move in the worker refactor happening next."""
key = key_bytes.decode("utf-8")
usergroup_id = RedisUserGroup.get_id_from_fence_key(key)
fence_key = key_bytes.decode("utf-8")
usergroup_id = RedisUserGroup.get_id_from_fence_key(fence_key)
if not usergroup_id:
task_logger.warning("Could not parse usergroup id from {key}")
task_logger.warning(f"Could not parse usergroup id from {fence_key}")
return
rug = RedisUserGroup(usergroup_id)

View File

@ -603,7 +603,7 @@ def delete_user_group_cc_pair_relationship__no_commit(
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
raise ValueError(
f"Connector Credential Pair '{cc_pair_id}' is not in the DELETING state"
f"Connector Credential Pair '{cc_pair_id}' is not in the DELETING state. status={cc_pair.status}"
)
delete_stmt = delete(UserGroup__ConnectorCredentialPair).where(

View File

@ -274,10 +274,10 @@ class CCPairManager:
result.raise_for_status()
@staticmethod
def get_prune_task(
def is_pruning(
cc_pair: DATestCCPair,
user_performing_action: DATestUser | None = None,
) -> CeleryTaskStatus:
) -> bool:
response = requests.get(
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
headers=user_performing_action.headers
@ -285,28 +285,21 @@ class CCPairManager:
else GENERAL_HEADERS,
)
response.raise_for_status()
return CeleryTaskStatus(**response.json())
response_bool = response.json()
return response_bool
@staticmethod
def wait_for_prune(
cc_pair: DATestCCPair,
after: datetime,
timeout: float = MAX_DELAY,
user_performing_action: DATestUser | None = None,
) -> None:
"""after: The task register time must be after this time."""
start = time.monotonic()
while True:
task = CCPairManager.get_prune_task(cc_pair, user_performing_action)
if not task:
raise ValueError("Prune task not found.")
if not task.register_time or task.register_time < after:
raise ValueError("Prune task register time is too early.")
if task.status == TaskStatus.SUCCESS:
# Pruning succeeded
return
result = CCPairManager.is_pruning(cc_pair, user_performing_action)
if not result:
break
elapsed = time.monotonic() - start
if elapsed > timeout:
@ -380,16 +373,31 @@ class CCPairManager:
@staticmethod
def wait_for_deletion_completion(
cc_pair_id: int | None = None,
user_performing_action: DATestUser | None = None,
) -> None:
"""if cc_pair_id is not specified, just waits until no connectors are in the deleting state.
if cc_pair_id is specified, checks to ensure the specific cc_pair_id is gone.
We had a bug where the connector was paused in the middle of deleting, so specifying the
cc_pair_id is good to do."""
start = time.monotonic()
while True:
fetched_cc_pairs = CCPairManager.get_all(user_performing_action)
if all(
cc_pair.cc_pair_status != ConnectorCredentialPairStatus.DELETING
for cc_pair in fetched_cc_pairs
):
return
cc_pairs = CCPairManager.get_all(user_performing_action)
if cc_pair_id:
found = False
for cc_pair in cc_pairs:
if cc_pair.cc_pair_id == cc_pair_id:
found = True
break
if not found:
return
else:
if all(
cc_pair.cc_pair_status != ConnectorCredentialPairStatus.DELETING
for cc_pair in cc_pairs
):
return
if time.monotonic() - start > MAX_DELAY:
raise TimeoutError(

View File

@ -195,9 +195,8 @@ def test_slack_prune(
)
# Prune the cc_pair
before = datetime.now(timezone.utc)
CCPairManager.prune(cc_pair, user_performing_action=admin_user)
CCPairManager.wait_for_prune(cc_pair, before, user_performing_action=admin_user)
CCPairManager.wait_for_prune(cc_pair, user_performing_action=admin_user)
# ----------------------------VERIFY THE CHANGES---------------------------
# Ensure admin user can't see deleted messages

View File

@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import create_index_attempt_error
from danswer.db.models import IndexAttempt
from danswer.db.search_settings import get_current_search_settings
@ -117,6 +118,22 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
user_performing_action=admin_user,
)
# inject an index attempt and index attempt error (exercises foreign key errors)
with Session(get_sqlalchemy_engine()) as db_session:
attempt_id = create_index_attempt(
connector_credential_pair_id=cc_pair_1.id,
search_settings_id=1,
db_session=db_session,
)
create_index_attempt_error(
index_attempt_id=attempt_id,
batch=1,
docs=[],
exception_msg="",
exception_traceback="",
db_session=db_session,
)
# Update local records to match the database for later comparison
user_group_1.cc_pair_ids = []
user_group_2.cc_pair_ids = [cc_pair_2.id]
@ -125,7 +142,9 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
cc_pair_1.groups = []
cc_pair_2.groups = [user_group_2.id]
CCPairManager.wait_for_deletion_completion(user_performing_action=admin_user)
CCPairManager.wait_for_deletion_completion(
cc_pair_id=cc_pair_1.id, user_performing_action=admin_user
)
# validate vespa documents
DocumentManager.verify(
@ -303,7 +322,9 @@ def test_connector_deletion_for_overlapping_connectors(
)
# wait for deletion to finish
CCPairManager.wait_for_deletion_completion(user_performing_action=admin_user)
CCPairManager.wait_for_deletion_completion(
cc_pair_id=cc_pair_1.id, user_performing_action=admin_user
)
print("Connector 1 deleted")

View File

@ -171,7 +171,9 @@ def test_cc_pair_permissions(reset: None) -> None:
# Test deleting the cc pair
CCPairManager.delete(valid_cc_pair, user_performing_action=curator)
CCPairManager.wait_for_deletion_completion(user_performing_action=curator)
CCPairManager.wait_for_deletion_completion(
cc_pair_id=valid_cc_pair.id, user_performing_action=curator
)
CCPairManager.verify(
cc_pair=valid_cc_pair,

View File

@ -77,7 +77,9 @@ def test_whole_curator_flow(reset: None) -> None:
# Verify that the curator can delete the CC pair
CCPairManager.delete(cc_pair=test_cc_pair, user_performing_action=curator)
CCPairManager.wait_for_deletion_completion(user_performing_action=curator)
CCPairManager.wait_for_deletion_completion(
cc_pair_id=test_cc_pair.id, user_performing_action=curator
)
# Verify that the CC pair has been deleted
CCPairManager.verify(
@ -158,7 +160,9 @@ def test_global_curator_flow(reset: None) -> None:
# Verify that the curator can delete the CC pair
CCPairManager.delete(cc_pair=test_cc_pair, user_performing_action=global_curator)
CCPairManager.wait_for_deletion_completion(user_performing_action=global_curator)
CCPairManager.wait_for_deletion_completion(
cc_pair_id=test_cc_pair.id, user_performing_action=global_curator
)
# Verify that the CC pair has been deleted
CCPairManager.verify(

View File

@ -105,12 +105,9 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
logger.info("Removing courses.html.")
os.remove(os.path.join(website_tgt, "courses.html"))
# store the time again as a reference for the pruning timestamps
now = datetime.now(timezone.utc)
CCPairManager.prune(cc_pair_1, user_performing_action=admin_user)
CCPairManager.wait_for_prune(
cc_pair_1, now, timeout=60, user_performing_action=admin_user
cc_pair_1, timeout=60, user_performing_action=admin_user
)
selected_cc_pair = CCPairManager.get_one(

View File

@ -10,7 +10,7 @@ from danswer.connectors.mediawiki import family
NON_BUILTIN_WIKIS: Final[list[tuple[str, str]]] = [
("https://fallout.fandom.com", "falloutwiki"),
("https://harrypotter.fandom.com/wiki/", "harrypotterwiki"),
("https://artofproblemsolving.com/wiki", "artofproblemsolving"),
# ("https://artofproblemsolving.com/wiki", "artofproblemsolving"), # FLAKY
("https://www.bogleheads.org/wiki/Main_Page", "bogleheadswiki"),
("https://bogleheads.org/wiki/Main_Page", "bogleheadswiki"),
("https://www.dandwiki.com/wiki/", "dungeonsanddragons"),