mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-12 04:40:09 +02:00
Feature/background prune 2 (#2583)
* first cut at redis
* some new helper functions for the db
* ignore kombu tables in alembic migrations (used by celery)
* multiline commands for readability, add vespa_metadata_sync queue to worker
* typo fix
* fix returning tuple fields
* add constants
* fix _get_access_for_document
* docstrings!
* fix double function declaration and typing
* fix type hinting
* add a global redis pool
* Add get_document function
* use task_logger in various celery tasks
* add celeryconfig.py to simplify configuration. Will be used in a subsequent commit
* Add celery redis helper. used in a subsequent PR
* kombu warning getting spammy since celery is not self managing its queue in Postgres any more
* add last_modified and last_synced to documents
* fix task naming convention
* use celeryconfig.py
* the big one. adds queues and tasks, updates functions to use the queues with priorities, etc
* change vespa index log line to debug
* mypy fixes
* update alembic migration
* fix fence ordering, rename to "monitor", fix fetch_versioned_implementation call
* mypy
* switch to monotonic time
* fix startup dependencies on redis
* rebase alembic migration
* kombu cleanup - fail silently
* mypy
* add redis_host environment override
* update REDIS_HOST env var in docker-compose.dev.yml
* update the rest of the docker files
* in flight
* harden indexing-status endpoint against db changes happening in the background. Needs further improvement but OK for now.
* allow no task syncs to run because we create certain objects with no entries but initially marked as out of date
* add back writing to vespa on indexing
* actually working connector deletion
* update contributing guide
* backporting fixes from background_deletion
* renaming cache to cache_volume
* add redis password to various deployments
* try setting up pr testing for helm
* fix indent
* hopefully this release version actually exists
* fix command line option to --chart-dirs
* fetch-depth 0
* edit values.yaml
* try setting ct working directory
* bypass testing only on change for now
* move files and lint them
* update helm testing
* some issues suggest using --config works
* add vespa repo
* add postgresql repo
* increase timeout
* try amd64 runner
* fix redis password reference
* add comment to helm chart testing workflow
* rename helm testing workflow to disable it
* adding clarifying comments
* address code review
* missed a file
* remove commented warning ... just not needed
* fix imports
* refactor to use update_single
* mypy fixes
* add vespa test
* multiple celery workers
* update logs as well and set prefetch multipliers appropriate to the worker intent
* add db refresh to connector deletion
* add some preliminary locking
* organize tasks into separate files
* celery auto associates tasks created inside another task, which bloats the result metadata considerably. trail=False prevents this (see the sketch after this list)
* code review fixes
* move monitor_usergroup_taskset to ee, improve logging
* add multi workers to dev_run_background_jobs.py
* update supervisord with some recommended settings for celery
* name celery workers and shorten dev script prefixing
* add configurable sql alchemy engine settings on startup (needed for various intents like API server, different celery workers and tasks, etc)
* fix comments
* autoscale sqlalchemy pool size to celery concurrency (allow override later?)
* supervisord needs the percent symbols escaped
* use name as primary check, some minor refactoring and type hinting too
* stash merge (may not function yet)
* remove dead code
* more cleanup
* remove dead file
* we shouldn't be checking for deletion attempts in the db any more
* print cc_pair_id
* print status on status mismatch again
* add logging when cc_pair isn't present
* don't index any ingestion type connectors, and don't pause any connectors that aren't active
* add more specific check for deletion completion
* remove flaky mediawiki test site
* move is_pruning
* remove unused code
* remove old function

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
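One item above notes that celery auto-associates child tasks with the parent task's result metadata and that trail=False prevents this. A minimal sketch of that option, using a hypothetical task name that is not part of this PR:

    from celery import shared_task

    # trail=False stops celery from recording every child task spawned inside
    # this task in the parent's result metadata, which otherwise bloats the
    # stored result when a task fans out into many subtasks.
    @shared_task(name="example_fan_out_task", trail=False)
    def example_fan_out_task() -> None:
        ...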
This commit is contained in:
parent 64909d74f9
commit 3404c7eb1d
@@ -0,0 +1,27 @@
"""add last_pruned to the connector_credential_pair table

Revision ID: ac5eaac849f9
Revises: 46b7a812670f
Create Date: 2024-09-10 15:04:26.437118

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "ac5eaac849f9"
down_revision = "46b7a812670f"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # last_pruned represents the last time the connector was pruned
    op.add_column(
        "connector_credential_pair",
        sa.Column("last_pruned", sa.DateTime(timezone=True), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("connector_credential_pair", "last_pruned")
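For context, here is how the new nullable column is meant to be read, sketched from the scheduling helper later in this diff: a NULL last_pruned is treated as "never pruned", with the connector's creation time standing in as the effective timestamp.

    from datetime import datetime, timedelta, timezone

    # assuming cc_pair is a loaded ConnectorCredentialPair row with a
    # connector whose prune_freq is expressed in seconds
    last_pruned = cc_pair.last_pruned or cc_pair.connector.time_created
    next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
    should_prune = datetime.now(timezone.utc) >= next_prune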
@@ -19,6 +19,7 @@ from celery.utils.log import get_task_logger

 from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
 from danswer.background.celery.celery_redis import RedisConnectorDeletion
+from danswer.background.celery.celery_redis import RedisConnectorPruning
 from danswer.background.celery.celery_redis import RedisDocumentSet
 from danswer.background.celery.celery_redis import RedisUserGroup
 from danswer.background.celery.celery_utils import celery_is_worker_primary
@@ -104,6 +105,13 @@ def celery_task_postrun(
             r.srem(rcd.taskset_key, task_id)
         return

+    if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
+        cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
+        if cc_pair_id is not None:
+            rcp = RedisConnectorPruning(cc_pair_id)
+            r.srem(rcp.taskset_key, task_id)
+        return
+

 @beat_init.connect
 def on_beat_init(sender: Any, **kwargs: Any) -> None:
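The branch above relies on the pruning subtask ids embedding the cc_pair id. A rough sketch of the assumed id format and parsing; get_id_from_task_id itself lives on the redis helper base class and is not shown in this diff:

    # subtask ids look like "connectorpruning+sub_1_<uuid4>"; the middle
    # segment is the cc_pair id, and a uuid4 contains dashes, not underscores
    def get_id_from_task_id(task_id: str) -> int | None:
        parts = task_id.split("_")
        if len(parts) != 3:
            return None
        try:
            return int(parts[1])
        except ValueError:
            return None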
@@ -236,6 +244,18 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
     for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
         r.delete(key)

+    for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
+        r.delete(key)
+
+    for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
+        r.delete(key)
+
+    for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
+        r.delete(key)
+
+    for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
+        r.delete(key)
+

 @worker_ready.connect
 def on_worker_ready(sender: Any, **kwargs: Any) -> None:
@@ -330,7 +350,11 @@ def on_setup_logging(

 class HubPeriodicTask(bootsteps.StartStopStep):
     """Regularly reacquires the primary worker lock outside of the task queue.
-    Use the task_logger in this class to avoid double logging."""
+    Use the task_logger in this class to avoid double logging.
+
+    This cannot be done inside a regular beat task because it must run on schedule,
+    and a queue of existing work would starve the task from running.
+    """

     # it's unclear to me whether using the hub's timer or the bootstep timer is better
     requires = {"celery.worker.components:Hub"}
@@ -405,6 +429,7 @@ celery_app.autodiscover_tasks(
         "danswer.background.celery.tasks.connector_deletion",
         "danswer.background.celery.tasks.periodic",
+        "danswer.background.celery.tasks.pruning",
         "danswer.background.celery.tasks.shared",
         "danswer.background.celery.tasks.vespa",
     ]
 )
@@ -425,7 +450,7 @@ celery_app.conf.beat_schedule.update(
         "task": "check_for_connector_deletion_task",
         # don't need to check too often, since we kick off a deletion initially
         # during the API call that actually marks the CC pair for deletion
-        "schedule": timedelta(minutes=1),
+        "schedule": timedelta(seconds=60),
         "options": {"priority": DanswerCeleryPriority.HIGH},
     },
 }
@@ -433,8 +458,8 @@ celery_app.conf.beat_schedule.update(
 celery_app.conf.beat_schedule.update(
     {
         "check-for-prune": {
-            "task": "check_for_prune_task",
-            "schedule": timedelta(seconds=5),
+            "task": "check_for_prune_task_2",
+            "schedule": timedelta(seconds=60),
             "options": {"priority": DanswerCeleryPriority.HIGH},
         },
     }
@@ -343,6 +343,125 @@ class RedisConnectorDeletion(RedisObjectHelper):
        return len(async_results)


class RedisConnectorPruning(RedisObjectHelper):
    """Celery will kick off a long running generator task to crawl the connector and
    find any missing docs, which will each then get a new cleanup task. The progress of
    those tasks will then be monitored to completion.

    Example rough happy path order:
    Check connectorpruning_fence_1
    Send generator task with id connectorpruning+generator_1_{uuid}

    generator runs connector with callbacks that increment connectorpruning_generator_progress_1
    generator creates many subtasks with id connectorpruning+sub_1_{uuid}
        in taskset connectorpruning_taskset_1
    on completion, generator sets connectorpruning_generator_complete_1

    celery postrun removes subtasks from taskset
    monitor beat task cleans up when taskset reaches 0 items
    """

    PREFIX = "connectorpruning"
    FENCE_PREFIX = PREFIX + "_fence"  # a fence for the entire pruning process
    GENERATOR_TASK_PREFIX = PREFIX + "+generator"

    TASKSET_PREFIX = PREFIX + "_taskset"  # stores a list of prune task ids
    SUBTASK_PREFIX = PREFIX + "+sub"

    GENERATOR_PROGRESS_PREFIX = (
        PREFIX + "_generator_progress"
    )  # a signal that contains generator progress
    GENERATOR_COMPLETE_PREFIX = (
        PREFIX + "_generator_complete"
    )  # a signal that the generator has finished

    def __init__(self, id: int) -> None:
        """id: the cc_pair_id of the connector credential pair"""

        super().__init__(id)
        self.documents_to_prune: set[str] = set()

    @property
    def generator_task_id_prefix(self) -> str:
        return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"

    @property
    def generator_progress_key(self) -> str:
        # example: connectorpruning_generator_progress_1
        return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"

    @property
    def generator_complete_key(self) -> str:
        # example: connectorpruning_generator_complete_1
        return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"

    @property
    def subtask_id_prefix(self) -> str:
        return f"{self.SUBTASK_PREFIX}_{self._id}"

    def generate_tasks(
        self,
        celery_app: Celery,
        db_session: Session,
        redis_client: Redis,
        lock: redis.lock.Lock | None,
    ) -> int | None:
        last_lock_time = time.monotonic()

        async_results = []
        cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
        if not cc_pair:
            return None

        for doc_id in self.documents_to_prune:
            current_time = time.monotonic()
            if lock and current_time - last_lock_time >= (
                CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
            ):
                lock.reacquire()
                last_lock_time = current_time

            # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
            # the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
            # we prefix the task id so it's easier to keep track of who created the task
            # aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
            custom_task_id = f"{self.subtask_id_prefix}_{uuid4()}"

            # add to the tracking taskset in redis BEFORE creating the celery task.
            # note that for the moment we are using a single taskset key, not differentiated by cc_pair id
            redis_client.sadd(self.taskset_key, custom_task_id)

            # Priority on syncs triggered by new indexing should be medium
            result = celery_app.send_task(
                "document_by_cc_pair_cleanup_task",
                kwargs=dict(
                    document_id=doc_id,
                    connector_id=cc_pair.connector_id,
                    credential_id=cc_pair.credential_id,
                ),
                queue=DanswerCeleryQueues.CONNECTOR_DELETION,
                task_id=custom_task_id,
                priority=DanswerCeleryPriority.MEDIUM,
            )

            async_results.append(result)

        return len(async_results)

    def is_pruning(self, db_session: Session, redis_client: Redis) -> bool:
        """A single example of a helper method being refactored into the redis helper"""
        cc_pair = get_connector_credential_pair_from_id(
            cc_pair_id=self._id, db_session=db_session
        )
        if not cc_pair:
            raise ValueError(f"cc_pair_id {self._id} does not exist.")

        if redis_client.exists(self.fence_key):
            return True

        return False


def celery_get_queue_length(queue: str, r: Redis) -> int:
    """This is a redis specific way to get the length of a celery queue.
    It is priority aware and knows how to count across the multiple redis lists
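To make the key protocol in the docstring above concrete, here is the key layout the class would produce for cc_pair_id=1. This is illustrative only; fence_key and taskset_key come from the RedisObjectHelper base class, which is assumed to append the id the same way the properties shown here do:

    rcp = RedisConnectorPruning(1)
    # rcp.fence_key                 -> "connectorpruning_fence_1"
    # rcp.taskset_key               -> "connectorpruning_taskset_1"
    # rcp.generator_progress_key    -> "connectorpruning_generator_progress_1"
    # rcp.generator_complete_key    -> "connectorpruning_generator_complete_1"
    # rcp.generator_task_id_prefix  -> "connectorpruning+generator_1"
    # rcp.subtask_id_prefix         -> "connectorpruning+sub_1"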
@@ -1,3 +1,4 @@
+from collections.abc import Callable
 from datetime import datetime
 from datetime import timezone
 from typing import Any
@@ -5,8 +6,6 @@ from typing import Any
 from sqlalchemy.orm import Session

 from danswer.background.celery.celery_redis import RedisConnectorDeletion
-from danswer.background.task_utils import name_cc_prune_task
-from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
 from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
 from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
     rate_limit_builder,
@@ -17,14 +16,8 @@ from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.engine import get_db_current_time
from danswer.db.enums import TaskStatus
from danswer.db.models import Connector
from danswer.db.models import Credential
from danswer.db.models import TaskQueueState
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import get_latest_task_by_type
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
@@ -70,72 +63,19 @@ def get_deletion_attempt_snapshot(
     )


-def skip_cc_pair_pruning_by_task(
-    pruning_task: TaskQueueState | None, db_session: Session
-) -> bool:
-    """task should be the latest prune task for this cc_pair"""
-    if not ALLOW_SIMULTANEOUS_PRUNING:
-        # if only one prune is allowed at any time, then check to see if any prune
-        # is active
-        pruning_type_task_name = name_cc_prune_task()
-        last_pruning_type_task = get_latest_task_by_type(
-            pruning_type_task_name, db_session
-        )
-
-        if last_pruning_type_task and check_task_is_live_and_not_timed_out(
-            last_pruning_type_task, db_session
-        ):
-            return True
-
-    if pruning_task and check_task_is_live_and_not_timed_out(pruning_task, db_session):
-        # if the last task is live right now, we shouldn't start a new one
-        return True
-
-    return False
-
-
-def should_prune_cc_pair(
-    connector: Connector, credential: Credential, db_session: Session
-) -> bool:
-    if not connector.prune_freq:
-        return False
-
-    pruning_task_name = name_cc_prune_task(
-        connector_id=connector.id, credential_id=credential.id
-    )
-    last_pruning_task = get_latest_task(pruning_task_name, db_session)
-
-    if skip_cc_pair_pruning_by_task(last_pruning_task, db_session):
-        return False
-
-    current_db_time = get_db_current_time(db_session)
-
-    if not last_pruning_task:
-        # If the connector has never been pruned, then compare vs when the connector
-        # was created
-        time_since_initialization = current_db_time - connector.time_created
-        if time_since_initialization.total_seconds() >= connector.prune_freq:
-            return True
-        return False
-
-    if not last_pruning_task.start_time:
-        # if the last prune task hasn't started, we shouldn't start a new one
-        return False
-
-    # if the last prune task has a start time, then compare against it to determine
-    # if we should start
-    time_since_last_pruning = current_db_time - last_pruning_task.start_time
-    return time_since_last_pruning.total_seconds() >= connector.prune_freq
-
-
 def document_batch_to_ids(doc_batch: list[Document]) -> set[str]:
     return {doc.id for doc in doc_batch}


-def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
+def extract_ids_from_runnable_connector(
+    runnable_connector: BaseConnector,
+    progress_callback: Callable[[int], None] | None = None,
+) -> set[str]:
     """
     If the PruneConnector hasn't been implemented for the given connector, just pull
-    all docs using the load_from_state and grab out the IDs
+    all docs using the load_from_state and grab out the IDs.
+
+    Optionally, a callback can be passed to handle the length of each document batch.
     """
     all_connector_doc_ids: set[str] = set()
@@ -158,6 +98,8 @@ def extract_ids_from_runnable_connector(runnable_connector: BaseConnector) -> set[str]:
         max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
     )(document_batch_to_ids)
     for doc_batch in doc_batch_generator:
+        if progress_callback:
+            progress_callback(len(doc_batch))
         all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))

     return all_connector_doc_ids
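A hedged usage sketch of the new progress_callback parameter: the caller supplies any callable taking a batch length, for example to tally documents as batches stream in. Here runnable_connector is assumed to be an already-instantiated connector:

    # tally document counts per batch as the generator yields them
    processed = {"count": 0}

    def on_batch(batch_len: int) -> None:
        processed["count"] += batch_len

    doc_ids = extract_ids_from_runnable_connector(runnable_connector, on_batch)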
@@ -177,9 +119,10 @@ def celery_is_listening_to_queue(worker: Any, name: str) -> bool:


 def celery_is_worker_primary(worker: Any) -> bool:
-    """There are multiple approaches that could be taken, but the way we do it is to
-    check the hostname set for the celery worker, either in celeryconfig.py or on the
-    command line."""
+    """There are multiple approaches that could be taken to determine if a celery worker
+    is 'primary', as defined by us. But the way we do it is to check the hostname set
+    for the celery worker, which can be done either in celeryconfig.py or on the
+    command line with '--hostname'."""
     hostname = worker.hostname
     if hostname.startswith("light"):
         return False
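A small sketch of the convention the docstring describes, with assumed worker invocations shown as comments (the exact launch commands are not part of this diff):

    # workers might be launched with explicit hostnames, e.g. (assumed):
    #   celery -A danswer.background.celery.celery_app worker --hostname=primary@%n
    #   celery -A danswer.background.celery.celery_app worker --hostname=light@%n
    def is_primary(hostname: str) -> bool:
        # any hostname that marks the process as a "light" worker is not primary
        return not hostname.startswith("light")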
@@ -1,12 +1,12 @@
 import redis
 from celery import shared_task
 from celery.exceptions import SoftTimeLimitExceeded
-from celery.utils.log import get_task_logger
 from redis import Redis
 from sqlalchemy.orm import Session
 from sqlalchemy.orm.exc import ObjectDeletedError

 from danswer.background.celery.celery_app import celery_app
+from danswer.background.celery.celery_app import task_logger
 from danswer.background.celery.celery_redis import RedisConnectorDeletion
 from danswer.configs.app_configs import JOB_TIMEOUT
 from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -14,17 +14,10 @@ from danswer.configs.constants import DanswerRedisLocks
 from danswer.db.connector_credential_pair import get_connector_credential_pairs
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.enums import ConnectorCredentialPairStatus
-from danswer.db.enums import IndexingStatus
-from danswer.db.index_attempt import get_last_attempt
 from danswer.db.models import ConnectorCredentialPair
-from danswer.db.search_settings import get_current_search_settings
 from danswer.redis.redis_pool import get_redis_client


-# use this within celery tasks to get celery task specific logging
-task_logger = get_task_logger(__name__)
-
-
 @shared_task(
     name="check_for_connector_deletion_task",
     soft_time_limit=JOB_TIMEOUT,
@@ -90,21 +83,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
     if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
         return None

-    search_settings = get_current_search_settings(db_session)
-
-    last_indexing = get_last_attempt(
-        connector_id=cc_pair.connector_id,
-        credential_id=cc_pair.credential_id,
-        search_settings_id=search_settings.id,
-        db_session=db_session,
-    )
-    if last_indexing:
-        if (
-            last_indexing.status == IndexingStatus.IN_PROGRESS
-            or last_indexing.status == IndexingStatus.NOT_STARTED
-        ):
-            return None
-
     # add tasks to celery and build up the task set to monitor in redis
     r.delete(rcd.taskset_key)
@@ -7,18 +7,15 @@ from typing import Any
 from celery import shared_task
 from celery.contrib.abortable import AbortableTask  # type: ignore
 from celery.exceptions import TaskRevokedError
-from celery.utils.log import get_task_logger
 from sqlalchemy import inspect
 from sqlalchemy import text
 from sqlalchemy.orm import Session

+from danswer.background.celery.celery_app import task_logger
 from danswer.configs.app_configs import JOB_TIMEOUT
 from danswer.configs.constants import PostgresAdvisoryLocks
 from danswer.db.engine import get_sqlalchemy_engine  # type: ignore

-# use this within celery tasks to get celery task specific logging
-task_logger = get_task_logger(__name__)
-
-
 @shared_task(
     name="kombu_message_cleanup_task",
@@ -1,61 +1,165 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4

import redis
from celery import shared_task
-from celery.utils.log import get_task_logger
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session

from danswer.background.celery.celery_app import celery_app
+from danswer.background.celery.celery_app import task_logger
+from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
-from danswer.background.celery.celery_utils import should_prune_cc_pair
-from danswer.background.connector_deletion import delete_connector_credential_pair_batch
-from danswer.background.task_utils import build_celery_task_wrapper
-from danswer.background.task_utils import name_cc_prune_task
+from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
+from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
+from danswer.configs.constants import DanswerCeleryPriority
+from danswer.configs.constants import DanswerCeleryQueues
+from danswer.configs.constants import DanswerRedisLocks
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
+from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
-
-
-# use this within celery tasks to get celery task specific logging
-task_logger = get_task_logger(__name__)
+from danswer.db.enums import ConnectorCredentialPairStatus
+from danswer.db.models import ConnectorCredentialPair
+from danswer.redis.redis_pool import get_redis_client


@shared_task(
-    name="check_for_prune_task",
+    name="check_for_prune_task_2",
    soft_time_limit=JOB_TIMEOUT,
)
-def check_for_prune_task() -> None:
-    """Runs periodically to check if any prune tasks should be run and adds them
-    to the queue"""
+def check_for_prune_task_2() -> None:
+    r = get_redis_client()
+
-    with Session(get_sqlalchemy_engine()) as db_session:
-        all_cc_pairs = get_connector_credential_pairs(db_session)
+    lock_beat = r.lock(
+        DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
+        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
+    )

-        for cc_pair in all_cc_pairs:
-            if should_prune_cc_pair(
-                connector=cc_pair.connector,
-                credential=cc_pair.credential,
-                db_session=db_session,
-            ):
-                task_logger.info(f"Pruning the {cc_pair.connector.name} connector")
+    try:
+        # these tasks should never overlap
+        if not lock_beat.acquire(blocking=False):
+            return

-                prune_documents_task.apply_async(
-                    kwargs=dict(
-                        connector_id=cc_pair.connector.id,
-                        credential_id=cc_pair.credential.id,
-                    )
-                )
+        with Session(get_sqlalchemy_engine()) as db_session:
+            cc_pairs = get_connector_credential_pairs(db_session)
+            for cc_pair in cc_pairs:
+                tasks_created = ccpair_pruning_generator_task_creation_helper(
+                    cc_pair, db_session, r, lock_beat
+                )
+                if not tasks_created:
+                    continue
+
+                task_logger.info(f"Pruning started: cc_pair_id={cc_pair.id}")
+    except SoftTimeLimitExceeded:
+        task_logger.info(
+            "Soft time limit exceeded, task is being terminated gracefully."
+        )
+    except Exception:
+        task_logger.exception("Unexpected exception")
+    finally:
+        if lock_beat.owned():
+            lock_beat.release()


-@build_celery_task_wrapper(name_cc_prune_task)
-@celery_app.task(name="prune_documents_task", soft_time_limit=JOB_TIMEOUT)
-def prune_documents_task(connector_id: int, credential_id: int) -> None:
+def ccpair_pruning_generator_task_creation_helper(
+    cc_pair: ConnectorCredentialPair,
+    db_session: Session,
+    r: Redis,
+    lock_beat: redis.lock.Lock,
+) -> int | None:
+    """Returns an int if pruning is triggered.
+    The int represents the number of prune tasks generated (in this case, only one,
+    because the task is a long running generator task).
+    Returns None if no pruning is triggered (due to not being needed or
+    other reasons such as simultaneous pruning restrictions).
+
+    Checks for scheduling related conditions, then delegates the rest of the checks to
+    try_creating_prune_generator_task.
+    """
+
+    lock_beat.reacquire()
+
+    # skip pruning if no prune frequency is set
+    # pruning can still be forced via the API, which will run a pruning task directly
+    if not cc_pair.connector.prune_freq:
+        return None
+
+    # skip pruning if the next scheduled prune time hasn't been reached yet
+    last_pruned = cc_pair.last_pruned
+    if not last_pruned:
+        # if never pruned, use the connector's creation time as the last_pruned time
+        last_pruned = cc_pair.connector.time_created
+
+    next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
+    if datetime.now(timezone.utc) < next_prune:
+        return None
+
+    return try_creating_prune_generator_task(cc_pair, db_session, r)
+
+
+def try_creating_prune_generator_task(
+    cc_pair: ConnectorCredentialPair,
+    db_session: Session,
+    r: Redis,
+) -> int | None:
+    """Checks for any conditions that should block the pruning generator task from being
+    created, then creates the task.
+
+    Does not check for scheduling related conditions, as this function
+    is used to trigger prunes immediately.
+    """
+
+    if not ALLOW_SIMULTANEOUS_PRUNING:
+        for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
+            return None
+
+    rcp = RedisConnectorPruning(cc_pair.id)
+
+    # skip pruning if already pruning
+    if r.exists(rcp.fence_key):
+        return None
+
+    # skip pruning if the cc_pair is deleting
+    db_session.refresh(cc_pair)
+    if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
+        return None
+
+    # add a long running generator task to the queue
+    r.delete(rcp.generator_complete_key)
+    r.delete(rcp.taskset_key)
+
+    custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
+
+    celery_app.send_task(
+        "connector_pruning_generator_task",
+        kwargs=dict(
+            connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
+        ),
+        queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
+        task_id=custom_task_id,
+        priority=DanswerCeleryPriority.LOW,
+    )
+
+    # set this only after all tasks have been added
+    r.set(rcp.fence_key, 1)
+    return 1
+
+
+@shared_task(name="connector_pruning_generator_task", soft_time_limit=JOB_TIMEOUT)
+def connector_pruning_generator_task(connector_id: int, credential_id: int) -> None:
+    """connector pruning task. For a cc pair, this task pulls all document IDs from the source
+    and compares those IDs to locally stored documents, deleting all locally stored IDs missing
+    from the most recently pulled document ID list"""
+
+    r = get_redis_client()

    with Session(get_sqlalchemy_engine()) as db_session:
        try:
            cc_pair = get_connector_credential_pair(
@@ -70,6 +174,12 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
             )
             return

+        rcp = RedisConnectorPruning(cc_pair.id)
+
+        # Define the callback function
+        def redis_increment_callback(amount: int) -> None:
+            r.incrby(rcp.generator_progress_key, amount)
+
         runnable_connector = instantiate_connector(
             db_session,
             cc_pair.connector.source,
@@ -78,10 +188,12 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
             cc_pair.credential,
         )

         # a list of docs in the source
         all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
-            runnable_connector
+            runnable_connector, redis_increment_callback
         )

         # a list of docs in our local index
         all_indexed_document_ids = {
             doc.id
             for doc in get_documents_for_connector_credential_pair(
@@ -91,30 +203,37 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
             )
         }

         # generate list of docs to remove (no longer in the source)
         doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)

-        curr_ind_name, sec_ind_name = get_both_index_names(db_session)
-        document_index = get_default_document_index(
-            primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
-        )
-
-        if len(doc_ids_to_remove) == 0:
-            task_logger.info(
-                f"No docs to prune from {cc_pair.connector.source} connector"
-            )
-            return
-
-        task_logger.info(
-            f"pruning {len(doc_ids_to_remove)} doc(s) from {cc_pair.connector.source} connector"
-        )
-        delete_connector_credential_pair_batch(
-            document_ids=doc_ids_to_remove,
-            connector_id=connector_id,
-            credential_id=credential_id,
-            document_index=document_index,
-        )
+        task_logger.info(
+            f"Pruning set collected: "
+            f"cc_pair_id={cc_pair.id} "
+            f"docs_to_remove={len(doc_ids_to_remove)} "
+            f"doc_source={cc_pair.connector.source}"
+        )
+
+        rcp.documents_to_prune = set(doc_ids_to_remove)
+
+        task_logger.info(
+            f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
+        )
+        tasks_generated = rcp.generate_tasks(celery_app, db_session, r, None)
+        if tasks_generated is None:
+            return None
+
+        task_logger.info(
+            f"RedisConnectorPruning.generate_tasks finished. "
+            f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
+        )
+
+        r.set(rcp.generator_complete_key, tasks_generated)
     except Exception as e:
         task_logger.exception(
             f"Failed to run pruning for connector id {connector_id}."
         )
+
+        r.delete(rcp.generator_progress_key)
+        r.delete(rcp.taskset_key)
+        r.delete(rcp.fence_key)
         raise e
backend/danswer/background/celery/tasks/shared/tasks.py (new file, 113 lines)
@@ -0,0 +1,113 @@
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from sqlalchemy.orm import Session

from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import task_logger
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
from danswer.db.document import get_document_connector_count
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.server.documents.models import ConnectorCredentialPairIdentifier


@shared_task(
    name="document_by_cc_pair_cleanup_task",
    bind=True,
    soft_time_limit=45,
    time_limit=60,
    max_retries=3,
)
def document_by_cc_pair_cleanup_task(
    self: Task, document_id: str, connector_id: int, credential_id: int
) -> bool:
    """A lightweight subtask used to clean up document to cc pair relationships.
    Created by connector deletion and connector pruning parent tasks.

    To delete a connector / credential pair:
    (1) find all documents associated with the connector / credential pair where this
        is the only connector / credential pair that has indexed it
    (2) delete all documents from document stores
    (3) delete all entries from postgres
    (4) find all documents associated with the connector / credential pair where there
        are multiple connector / credential pairs that have indexed it
    (5) update document store entries to remove access associated with the
        connector / credential pair from the access list
    (6) delete all relevant entries from postgres
    """
    task_logger.info(f"document_id={document_id}")

    try:
        with Session(get_sqlalchemy_engine()) as db_session:
            curr_ind_name, sec_ind_name = get_both_index_names(db_session)
            document_index = get_default_document_index(
                primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
            )

            count = get_document_connector_count(db_session, document_id)
            if count == 1:
                # count == 1 means this is the only remaining cc_pair reference to the doc
                # delete it from vespa and the db
                document_index.delete(doc_ids=[document_id])
                delete_documents_complete__no_commit(
                    db_session=db_session,
                    document_ids=[document_id],
                )
            elif count > 1:
                # count > 1 means the document still has cc_pair references
                doc = get_document(document_id, db_session)
                if not doc:
                    return False

                # the below functions do not include cc_pairs being deleted.
                # i.e. they will correctly omit access for the current cc_pair
                doc_access = get_access_for_document(
                    document_id=document_id, db_session=db_session
                )

                doc_sets = fetch_document_sets_for_document(document_id, db_session)
                update_doc_sets: set[str] = set(doc_sets)

                fields = VespaDocumentFields(
                    document_sets=update_doc_sets,
                    access=doc_access,
                    boost=doc.boost,
                    hidden=doc.hidden,
                )

                # update Vespa. OK if doc doesn't exist. Raises exception otherwise.
                document_index.update_single(document_id, fields=fields)

                # there are still other cc_pair references to the doc, so just resync to Vespa
                delete_document_by_connector_credential_pair__no_commit(
                    db_session=db_session,
                    document_id=document_id,
                    connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
                        connector_id=connector_id,
                        credential_id=credential_id,
                    ),
                )

                mark_document_as_synced(document_id, db_session)
            else:
                pass

            db_session.commit()
    except SoftTimeLimitExceeded:
        task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
    except Exception as e:
        task_logger.exception("Unexpected exception")

        # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
        countdown = 2 ** (self.request.retries + 4)
        self.retry(exc=e, countdown=countdown)

    return True
@@ -5,20 +5,22 @@ import redis
 from celery import shared_task
 from celery import Task
 from celery.exceptions import SoftTimeLimitExceeded
-from celery.utils.log import get_task_logger
 from redis import Redis
 from sqlalchemy.orm import Session

 from danswer.access.access import get_access_for_document
 from danswer.background.celery.celery_app import celery_app
+from danswer.background.celery.celery_app import task_logger
 from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
 from danswer.background.celery.celery_redis import RedisConnectorDeletion
+from danswer.background.celery.celery_redis import RedisConnectorPruning
 from danswer.background.celery.celery_redis import RedisDocumentSet
 from danswer.background.celery.celery_redis import RedisUserGroup
 from danswer.configs.app_configs import JOB_TIMEOUT
 from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
 from danswer.configs.constants import DanswerRedisLocks
 from danswer.db.connector import fetch_connector_by_id
+from danswer.db.connector import mark_ccpair_as_pruned
 from danswer.db.connector_credential_pair import add_deletion_failure_message
 from danswer.db.connector_credential_pair import (
     delete_connector_credential_pair__no_commit,
@@ -49,10 +51,6 @@ from danswer.utils.variable_functionality import (
 from danswer.utils.variable_functionality import noop_fallback


-# use this within celery tasks to get celery task specific logging
-task_logger = get_task_logger(__name__)
-
-
 # celery auto associates tasks created inside another task,
 # which bloats the result metadata considerably. trail=False prevents this.
 @shared_task(
@@ -279,7 +277,7 @@ def monitor_document_set_taskset(
     fence_key = key_bytes.decode("utf-8")
     document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
     if document_set_id is None:
-        task_logger.warning("could not parse document set id from {key}")
+        task_logger.warning(f"could not parse document set id from {fence_key}")
         return

     rds = RedisDocumentSet(document_set_id)
@@ -326,7 +324,7 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
     fence_key = key_bytes.decode("utf-8")
     cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
     if cc_pair_id is None:
-        task_logger.warning("could not parse document set id from {key}")
+        task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
         return

     rcd = RedisConnectorDeletion(cc_pair_id)
@@ -351,6 +349,9 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
     with Session(get_sqlalchemy_engine()) as db_session:
         cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
         if not cc_pair:
+            task_logger.warning(
+                f"monitor_connector_deletion_taskset - cc_pair_id not found: cc_pair_id={cc_pair_id}"
+            )
             return

         try:
@@ -402,20 +403,67 @@ def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
             add_deletion_failure_message(db_session, cc_pair.id, error_message)
             task_logger.exception(
                 f"Failed to run connector_deletion. "
-                f"connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
+                f"cc_pair_id={cc_pair_id} connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
             )
             raise e

     task_logger.info(
-        f"Successfully deleted connector_credential_pair with connector_id: '{cc_pair.connector_id}' "
-        f"and credential_id: '{cc_pair.credential_id}'. "
-        f"Deleted {initial_count} docs."
+        f"Successfully deleted cc_pair: "
+        f"cc_pair_id={cc_pair_id} "
+        f"connector_id={cc_pair.connector_id} "
+        f"credential_id={cc_pair.credential_id} "
+        f"docs_deleted={initial_count}"
     )

     r.delete(rcd.taskset_key)
     r.delete(rcd.fence_key)


+def monitor_ccpair_pruning_taskset(
+    key_bytes: bytes, r: Redis, db_session: Session
+) -> None:
+    fence_key = key_bytes.decode("utf-8")
+    cc_pair_id = RedisConnectorPruning.get_id_from_fence_key(fence_key)
+    if cc_pair_id is None:
+        task_logger.warning(
+            f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
+        )
+        return
+
+    rcp = RedisConnectorPruning(cc_pair_id)
+
+    fence_value = r.get(rcp.fence_key)
+    if fence_value is None:
+        return
+
+    generator_value = r.get(rcp.generator_complete_key)
+    if generator_value is None:
+        return
+
+    try:
+        initial_count = int(cast(int, generator_value))
+    except ValueError:
+        task_logger.error("The value is not an integer.")
+        return
+
+    count = cast(int, r.scard(rcp.taskset_key))
+    task_logger.info(
+        f"Connector pruning progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
+    )
+    if count > 0:
+        return
+
+    mark_ccpair_as_pruned(cc_pair_id, db_session)
+    task_logger.info(
+        f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
+    )
+
+    r.delete(rcp.taskset_key)
+    r.delete(rcp.generator_progress_key)
+    r.delete(rcp.generator_complete_key)
+    r.delete(rcp.fence_key)
+
+
 @shared_task(name="monitor_vespa_sync", soft_time_limit=300)
 def monitor_vespa_sync() -> None:
     """This is a celery beat task that monitors and finalizes metadata sync tasksets.
@@ -457,6 +505,9 @@ def monitor_vespa_sync() -> None:
             )
             monitor_usergroup_taskset(key_bytes, r, db_session)

+        for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
+            monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
+
         # uncomment for debugging if needed
         # r_celery = celery_app.broker_connection().channel().client
         # length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
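The finalization contract implemented by the monitor above, restated as a sketch (names mirror the helper's keys; this is illustrative, not additional code in the PR): pruning for a cc_pair is considered finished only when the fence is set, the generator has published its subtask total, and the taskset has drained to zero.

    def pruning_is_finished(r, rcp) -> bool:
        # fence set means a prune is in flight; generator_complete holds the
        # number of subtasks created; the taskset drains as subtasks complete
        if r.get(rcp.fence_key) is None:
            return False
        if r.get(rcp.generator_complete_key) is None:
            return False
        return r.scard(rcp.taskset_key) == 0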
@@ -1,211 +0,0 @@
"""
To delete a connector / credential pair:
(1) find all documents associated with the connector / credential pair where this
    is the only connector / credential pair that has indexed it
(2) delete all documents from document stores
(3) delete all entries from postgres
(4) find all documents associated with the connector / credential pair where there
    are multiple connector / credential pairs that have indexed it
(5) update document store entries to remove access associated with the
    connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.utils.log import get_task_logger
from sqlalchemy.orm import Session

from danswer.access.access import get_access_for_document
from danswer.access.access import get_access_for_documents
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
from danswer.db.document import get_document_connector_count
from danswer.db.document import get_document_connector_counts
from danswer.db.document import mark_document_as_synced
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger

logger = setup_logger()

# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)

_DELETION_BATCH_SIZE = 1000


def delete_connector_credential_pair_batch(
    document_ids: list[str],
    connector_id: int,
    credential_id: int,
    document_index: DocumentIndex,
) -> None:
    """
    Removes a batch of document ids from a cc-pair. If no other cc-pair uses a document
    anymore, it gets permanently deleted.
    """
    with Session(get_sqlalchemy_engine()) as db_session:
        # acquire lock for all documents in this batch so that indexing can't
        # override the deletion
        with prepare_to_modify_documents(
            db_session=db_session, document_ids=document_ids
        ):
            document_connector_counts = get_document_connector_counts(
                db_session=db_session, document_ids=document_ids
            )

            # figure out which docs need to be completely deleted
            document_ids_to_delete = [
                document_id
                for document_id, cnt in document_connector_counts
                if cnt == 1
            ]
            logger.debug(f"Deleting documents: {document_ids_to_delete}")

            document_index.delete(doc_ids=document_ids_to_delete)

            delete_documents_complete__no_commit(
                db_session=db_session,
                document_ids=document_ids_to_delete,
            )

            # figure out which docs need to be updated
            document_ids_to_update = [
                document_id for document_id, cnt in document_connector_counts if cnt > 1
            ]

            # maps document id to list of document set names
            new_doc_sets_for_documents: dict[str, set[str]] = {
                document_id_and_document_set_names_tuple[0]: set(
                    document_id_and_document_set_names_tuple[1]
                )
                for document_id_and_document_set_names_tuple in fetch_document_sets_for_documents(
                    db_session=db_session,
                    document_ids=document_ids_to_update,
                )
            }

            # determine future ACLs for documents in batch
            access_for_documents = get_access_for_documents(
                document_ids=document_ids_to_update,
                db_session=db_session,
            )

            # update Vespa
            logger.debug(f"Updating documents: {document_ids_to_update}")
            update_requests = [
                UpdateRequest(
                    document_ids=[document_id],
                    access=access,
                    document_sets=new_doc_sets_for_documents[document_id],
                )
                for document_id, access in access_for_documents.items()
            ]
            document_index.update(update_requests=update_requests)

            # clean up Postgres
            delete_documents_by_connector_credential_pair__no_commit(
                db_session=db_session,
                document_ids=document_ids_to_update,
                connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
                    connector_id=connector_id,
                    credential_id=credential_id,
                ),
            )
            db_session.commit()


@shared_task(
    name="document_by_cc_pair_cleanup_task",
    bind=True,
    soft_time_limit=45,
    time_limit=60,
    max_retries=3,
)
def document_by_cc_pair_cleanup_task(
    self: Task, document_id: str, connector_id: int, credential_id: int
) -> bool:
    task_logger.info(f"document_id={document_id}")

    try:
        with Session(get_sqlalchemy_engine()) as db_session:
            curr_ind_name, sec_ind_name = get_both_index_names(db_session)
            document_index = get_default_document_index(
                primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
            )

            count = get_document_connector_count(db_session, document_id)
            if count == 1:
                # count == 1 means this is the only remaining cc_pair reference to the doc
                # delete it from vespa and the db
                document_index.delete_single(doc_id=document_id)
                delete_documents_complete__no_commit(
                    db_session=db_session,
                    document_ids=[document_id],
                )
            elif count > 1:
                # count > 1 means the document still has cc_pair references
                doc = get_document(document_id, db_session)
                if not doc:
                    return False

                # the below functions do not include cc_pairs being deleted.
                # i.e. they will correctly omit access for the current cc_pair
                doc_access = get_access_for_document(
                    document_id=document_id, db_session=db_session
                )

                doc_sets = fetch_document_sets_for_document(document_id, db_session)
                update_doc_sets: set[str] = set(doc_sets)

                fields = VespaDocumentFields(
                    document_sets=update_doc_sets,
                    access=doc_access,
                    boost=doc.boost,
                    hidden=doc.hidden,
                )

                # update Vespa. OK if doc doesn't exist. Raises exception otherwise.
                document_index.update_single(document_id, fields=fields)

                # there are still other cc_pair references to the doc, so just resync to Vespa
                delete_document_by_connector_credential_pair__no_commit(
                    db_session=db_session,
                    document_id=document_id,
                    connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
                        connector_id=connector_id,
                        credential_id=credential_id,
                    ),
                )

                mark_document_as_synced(document_id, db_session)
            else:
                pass

            # update_docs_last_modified__no_commit(
            #     db_session=db_session,
            #     document_ids=[document_id],
            # )

            db_session.commit()
    except SoftTimeLimitExceeded:
        task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
    except Exception as e:
        task_logger.exception("Unexpected exception")

        # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
        countdown = 2 ** (self.request.retries + 4)
        self.retry(exc=e, countdown=countdown)

    return True
@@ -187,10 +187,9 @@ class PostgresAdvisoryLocks(Enum):


 class DanswerCeleryQueues:
-    VESPA_DOCSET_SYNC_GENERATOR = "vespa_docset_sync_generator"
-    VESPA_USERGROUP_SYNC_GENERATOR = "vespa_usergroup_sync_generator"
     VESPA_METADATA_SYNC = "vespa_metadata_sync"
     CONNECTOR_DELETION = "connector_deletion"
+    CONNECTOR_PRUNING = "connector_pruning"


 class DanswerRedisLocks:
@@ -198,7 +197,7 @@ class DanswerRedisLocks:
     CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
     MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
     CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
-    MONITOR_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:monitor_connector_deletion_beat"
+    CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"


 class DanswerCeleryPriority(int, Enum):
@@ -1,3 +1,5 @@
+from datetime import datetime
+from datetime import timezone
 from typing import cast

 from sqlalchemy import and_
@@ -268,3 +270,15 @@ def create_initial_default_connector(db_session: Session) -> None:
     )
     db_session.add(connector)
     db_session.commit()
+
+
+def mark_ccpair_as_pruned(cc_pair_id: int, db_session: Session) -> None:
+    stmt = select(ConnectorCredentialPair).where(
+        ConnectorCredentialPair.id == cc_pair_id
+    )
+    cc_pair = db_session.scalar(stmt)
+    if cc_pair is None:
+        raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
+
+    cc_pair.last_pruned = datetime.now(timezone.utc)
+    db_session.commit()
@@ -414,6 +414,12 @@ class ConnectorCredentialPair(Base):
     last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
         DateTime(timezone=True), default=None
     )

+    # last successful prune
+    last_pruned: Mapped[datetime.datetime | None] = mapped_column(
+        DateTime(timezone=True), nullable=True, index=True
+    )
+
     total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0)

     connector: Mapped["Connector"] = relationship(
@@ -10,9 +10,11 @@ from sqlalchemy.orm import Session

 from danswer.auth.users import current_curator_or_admin_user
 from danswer.auth.users import current_user
+from danswer.background.celery.celery_redis import RedisConnectorPruning
 from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
-from danswer.background.celery.celery_utils import skip_cc_pair_pruning_by_task
-from danswer.background.task_utils import name_cc_prune_task
+from danswer.background.celery.tasks.pruning.tasks import (
+    try_creating_prune_generator_task,
+)
 from danswer.db.connector_credential_pair import add_credential_to_connector
 from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
 from danswer.db.connector_credential_pair import remove_credential_from_connector
@@ -31,6 +33,7 @@ from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
 from danswer.db.models import User
 from danswer.db.tasks import check_task_is_live_and_not_timed_out
 from danswer.db.tasks import get_latest_task
+from danswer.redis.redis_pool import get_redis_client
 from danswer.server.documents.models import CCPairFullInfo
 from danswer.server.documents.models import CCStatusUpdateRequest
 from danswer.server.documents.models import CeleryTaskStatus
@@ -203,7 +206,7 @@ def get_cc_pair_latest_prune(
     cc_pair_id: int,
     user: User = Depends(current_curator_or_admin_user),
     db_session: Session = Depends(get_session),
-) -> CeleryTaskStatus:
+) -> bool:
     cc_pair = get_connector_credential_pair_from_id(
         cc_pair_id=cc_pair_id,
         db_session=db_session,
@ -216,24 +219,8 @@ def get_cc_pair_latest_prune(
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
# look up the last prune task for this connector (if it exists)
|
||||
pruning_task_name = name_cc_prune_task(
|
||||
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
|
||||
)
|
||||
last_pruning_task = get_latest_task(pruning_task_name, db_session)
|
||||
if not last_pruning_task:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="No pruning task found.",
|
||||
)
|
||||
|
||||
return CeleryTaskStatus(
|
||||
id=last_pruning_task.task_id,
|
||||
name=last_pruning_task.task_name,
|
||||
status=last_pruning_task.status,
|
||||
start_time=last_pruning_task.start_time,
|
||||
register_time=last_pruning_task.register_time,
|
||||
)
|
||||
rcp = RedisConnectorPruning(cc_pair.id)
|
||||
return rcp.is_pruning(db_session, get_redis_client())
|
||||
|
||||
|
||||
@router.post("/admin/cc-pair/{cc_pair_id}/prune")
|
||||
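
With this change, GET /manage/admin/cc-pair/{cc_pair_id}/prune returns a bare JSON boolean ("is a prune in progress?") instead of a CeleryTaskStatus snapshot. A minimal client sketch, with placeholder host and id:

    import requests

    API_SERVER_URL = "http://localhost:8080"  # placeholder; match your deployment
    cc_pair_id = 1  # placeholder

    resp = requests.get(f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/prune")
    resp.raise_for_status()
    is_pruning: bool = resp.json()
    print(f"pruning in progress: {is_pruning}")
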
@@ -242,8 +229,7 @@ def prune_cc_pair(
     user: User = Depends(current_curator_or_admin_user),
     db_session: Session = Depends(get_session),
 ) -> StatusResponse[list[int]]:
-    # avoiding circular refs
-    from danswer.background.celery.tasks.pruning.tasks import prune_documents_task
+    """Triggers pruning on a particular cc_pair immediately"""
 
     cc_pair = get_connector_credential_pair_from_id(
         cc_pair_id=cc_pair_id,
@@ -257,26 +243,26 @@ def prune_cc_pair(
             detail="Connection not found for current user's permissions",
         )
 
-    pruning_task_name = name_cc_prune_task(
-        connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
-    )
-    last_pruning_task = get_latest_task(pruning_task_name, db_session)
-    if skip_cc_pair_pruning_by_task(
-        last_pruning_task,
-        db_session=db_session,
-    ):
+    r = get_redis_client()
+    rcp = RedisConnectorPruning(cc_pair_id)
+    if rcp.is_pruning(db_session, r):
         raise HTTPException(
             status_code=HTTPStatus.CONFLICT,
             detail="Pruning task already in progress.",
         )
 
-    logger.info(f"Pruning the {cc_pair.connector.name} connector.")
-    prune_documents_task.apply_async(
-        kwargs=dict(
-            connector_id=cc_pair.connector.id,
-            credential_id=cc_pair.credential.id,
-        )
+    logger.info(
+        f"Pruning cc_pair: cc_pair_id={cc_pair_id} "
+        f"connector_id={cc_pair.connector_id} "
+        f"credential_id={cc_pair.credential_id} "
+        f"{cc_pair.connector.name} connector."
     )
+    tasks_created = try_creating_prune_generator_task(cc_pair, db_session, r)
+    if not tasks_created:
+        raise HTTPException(
+            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
+            detail="Pruning task creation failed.",
+        )
 
     return StatusResponse(
         success=True,
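
The pattern above replaces per-task database rows with a Redis "fence": RedisConnectorPruning marks a fence while a prune generator is in flight, and both this endpoint and celery beat consult it before creating new work. A conceptual sketch of the check, not the actual RedisConnectorPruning implementation (key name hypothetical):

    from redis import Redis


    def is_pruning_sketch(r: Redis, cc_pair_id: int) -> bool:
        # a fence key existing means a prune generator (or its taskset) is still live
        return bool(r.exists(f"connectorpruning_fence_{cc_pair_id}"))
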
@@ -353,14 +339,6 @@ def sync_cc_pair(
             status_code=HTTPStatus.CONFLICT,
             detail="Sync task already in progress.",
         )
-    if skip_cc_pair_pruning_by_task(
-        last_sync_task,
-        db_session=db_session,
-    ):
-        raise HTTPException(
-            status_code=HTTPStatus.CONFLICT,
-            detail="Sync task already in progress.",
-        )
 
     logger.info(f"Syncing the {cc_pair.connector.name} connector.")
     sync_external_doc_permissions_task.apply_async(
@@ -15,10 +15,10 @@ logger = setup_logger()
 
 def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None:
     """This function is likely to move in the worker refactor happening next."""
-    key = key_bytes.decode("utf-8")
-    usergroup_id = RedisUserGroup.get_id_from_fence_key(key)
+    fence_key = key_bytes.decode("utf-8")
+    usergroup_id = RedisUserGroup.get_id_from_fence_key(fence_key)
     if not usergroup_id:
-        task_logger.warning("Could not parse usergroup id from {key}")
+        task_logger.warning(f"Could not parse usergroup id from {fence_key}")
         return
 
     rug = RedisUserGroup(usergroup_id)
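
RedisUserGroup.get_id_from_fence_key inverts the fence naming scheme to recover the object id, and renaming the local to fence_key makes that explicit (the old warning also silently lacked its f-prefix). A conceptual sketch of such a parser, assuming a "<prefix>_fence_<id>" layout; the real key format and return type may differ:

    def get_id_from_fence_key_sketch(fence_key: str) -> str | None:
        # expected shape: "<prefix>_fence_<id>"
        parts = fence_key.split("_fence_")
        return parts[1] if len(parts) == 2 else None
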
@@ -603,7 +603,7 @@ def delete_user_group_cc_pair_relationship__no_commit(
 
     if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
         raise ValueError(
-            f"Connector Credential Pair '{cc_pair_id}' is not in the DELETING state"
+            f"Connector Credential Pair '{cc_pair_id}' is not in the DELETING state. status={cc_pair.status}"
         )
 
     delete_stmt = delete(UserGroup__ConnectorCredentialPair).where(
@@ -274,10 +274,10 @@ class CCPairManager:
         result.raise_for_status()
 
     @staticmethod
-    def get_prune_task(
+    def is_pruning(
         cc_pair: DATestCCPair,
         user_performing_action: DATestUser | None = None,
-    ) -> CeleryTaskStatus:
+    ) -> bool:
         response = requests.get(
             url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/prune",
             headers=user_performing_action.headers
@@ -285,28 +285,21 @@ class CCPairManager:
             else GENERAL_HEADERS,
         )
         response.raise_for_status()
-        return CeleryTaskStatus(**response.json())
+        response_bool = response.json()
+        return response_bool
 
     @staticmethod
     def wait_for_prune(
         cc_pair: DATestCCPair,
-        after: datetime,
         timeout: float = MAX_DELAY,
         user_performing_action: DATestUser | None = None,
     ) -> None:
-        """after: The task register time must be after this time."""
         start = time.monotonic()
         while True:
-            task = CCPairManager.get_prune_task(cc_pair, user_performing_action)
-            if not task:
-                raise ValueError("Prune task not found.")
-
-            if not task.register_time or task.register_time < after:
-                raise ValueError("Prune task register time is too early.")
-
-            if task.status == TaskStatus.SUCCESS:
-                # Pruning succeeded
-                return
+            result = CCPairManager.is_pruning(cc_pair, user_performing_action)
+            if not result:
+                break
 
             elapsed = time.monotonic() - start
             if elapsed > timeout:
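
wait_for_prune thus reduces to a generic poll-until-false loop over the boolean probe. The same shape, extracted as a standalone helper for clarity (names are illustrative):

    import time
    from typing import Callable


    def wait_until_false(
        probe: Callable[[], bool], timeout: float, poll_interval: float = 2.0
    ) -> None:
        """Poll `probe` until it returns False or `timeout` seconds elapse."""
        start = time.monotonic()
        while probe():
            if time.monotonic() - start > timeout:
                raise TimeoutError("condition did not clear in time")
            time.sleep(poll_interval)

wait_for_prune is this loop with probe bound to CCPairManager.is_pruning for the given cc_pair.
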
@@ -380,16 +373,31 @@ class CCPairManager:
 
     @staticmethod
     def wait_for_deletion_completion(
+        cc_pair_id: int | None = None,
         user_performing_action: DATestUser | None = None,
     ) -> None:
+        """if cc_pair_id is not specified, just waits until no connectors are in the deleting state.
+        if cc_pair_id is specified, checks to ensure the specific cc_pair_id is gone.
+        We had a bug where the connector was paused in the middle of deleting, so specifying the
+        cc_pair_id is good to do."""
         start = time.monotonic()
         while True:
-            fetched_cc_pairs = CCPairManager.get_all(user_performing_action)
-            if all(
-                cc_pair.cc_pair_status != ConnectorCredentialPairStatus.DELETING
-                for cc_pair in fetched_cc_pairs
-            ):
-                return
+            cc_pairs = CCPairManager.get_all(user_performing_action)
+            if cc_pair_id:
+                found = False
+                for cc_pair in cc_pairs:
+                    if cc_pair.cc_pair_id == cc_pair_id:
+                        found = True
+                        break
+
+                if not found:
+                    return
+            else:
+                if all(
+                    cc_pair.cc_pair_status != ConnectorCredentialPairStatus.DELETING
+                    for cc_pair in cc_pairs
+                ):
+                    return
 
             if time.monotonic() - start > MAX_DELAY:
                 raise TimeoutError(
@@ -195,9 +195,8 @@ def test_slack_prune(
     )
 
     # Prune the cc_pair
-    before = datetime.now(timezone.utc)
     CCPairManager.prune(cc_pair, user_performing_action=admin_user)
-    CCPairManager.wait_for_prune(cc_pair, before, user_performing_action=admin_user)
+    CCPairManager.wait_for_prune(cc_pair, user_performing_action=admin_user)
 
     # ----------------------------VERIFY THE CHANGES---------------------------
     # Ensure admin user can't see deleted messages
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
 
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.enums import IndexingStatus
 from danswer.db.index_attempt import create_index_attempt
+from danswer.db.index_attempt import create_index_attempt_error
 from danswer.db.models import IndexAttempt
 from danswer.db.search_settings import get_current_search_settings
@@ -117,6 +118,22 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
         user_performing_action=admin_user,
     )
 
+    # inject an index attempt and index attempt error (exercises foreign key errors)
+    with Session(get_sqlalchemy_engine()) as db_session:
+        attempt_id = create_index_attempt(
+            connector_credential_pair_id=cc_pair_1.id,
+            search_settings_id=1,
+            db_session=db_session,
+        )
+        create_index_attempt_error(
+            index_attempt_id=attempt_id,
+            batch=1,
+            docs=[],
+            exception_msg="",
+            exception_traceback="",
+            db_session=db_session,
+        )
+
     # Update local records to match the database for later comparison
     user_group_1.cc_pair_ids = []
     user_group_2.cc_pair_ids = [cc_pair_2.id]
@@ -125,7 +142,9 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
     cc_pair_1.groups = []
     cc_pair_2.groups = [user_group_2.id]
 
-    CCPairManager.wait_for_deletion_completion(user_performing_action=admin_user)
+    CCPairManager.wait_for_deletion_completion(
+        cc_pair_id=cc_pair_1.id, user_performing_action=admin_user
+    )
 
     # validate vespa documents
     DocumentManager.verify(
@@ -303,7 +322,9 @@ def test_connector_deletion_for_overlapping_connectors(
     )
 
     # wait for deletion to finish
-    CCPairManager.wait_for_deletion_completion(user_performing_action=admin_user)
+    CCPairManager.wait_for_deletion_completion(
+        cc_pair_id=cc_pair_1.id, user_performing_action=admin_user
+    )
 
     print("Connector 1 deleted")
@@ -171,7 +171,9 @@ def test_cc_pair_permissions(reset: None) -> None:
 
     # Test deleting the cc pair
     CCPairManager.delete(valid_cc_pair, user_performing_action=curator)
-    CCPairManager.wait_for_deletion_completion(user_performing_action=curator)
+    CCPairManager.wait_for_deletion_completion(
+        cc_pair_id=valid_cc_pair.id, user_performing_action=curator
+    )
 
     CCPairManager.verify(
         cc_pair=valid_cc_pair,
@@ -77,7 +77,9 @@ def test_whole_curator_flow(reset: None) -> None:
 
     # Verify that the curator can delete the CC pair
     CCPairManager.delete(cc_pair=test_cc_pair, user_performing_action=curator)
-    CCPairManager.wait_for_deletion_completion(user_performing_action=curator)
+    CCPairManager.wait_for_deletion_completion(
+        cc_pair_id=test_cc_pair.id, user_performing_action=curator
+    )
 
     # Verify that the CC pair has been deleted
     CCPairManager.verify(
@@ -158,7 +160,9 @@ def test_global_curator_flow(reset: None) -> None:
 
     # Verify that the curator can delete the CC pair
     CCPairManager.delete(cc_pair=test_cc_pair, user_performing_action=global_curator)
-    CCPairManager.wait_for_deletion_completion(user_performing_action=global_curator)
+    CCPairManager.wait_for_deletion_completion(
+        cc_pair_id=test_cc_pair.id, user_performing_action=global_curator
+    )
 
     # Verify that the CC pair has been deleted
     CCPairManager.verify(
@@ -105,12 +105,9 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
         logger.info("Removing courses.html.")
         os.remove(os.path.join(website_tgt, "courses.html"))
 
-        # store the time again as a reference for the pruning timestamps
-        now = datetime.now(timezone.utc)
-
         CCPairManager.prune(cc_pair_1, user_performing_action=admin_user)
         CCPairManager.wait_for_prune(
-            cc_pair_1, now, timeout=60, user_performing_action=admin_user
+            cc_pair_1, timeout=60, user_performing_action=admin_user
         )
 
         selected_cc_pair = CCPairManager.get_one(
@@ -10,7 +10,7 @@ from danswer.connectors.mediawiki import family
 NON_BUILTIN_WIKIS: Final[list[tuple[str, str]]] = [
     ("https://fallout.fandom.com", "falloutwiki"),
     ("https://harrypotter.fandom.com/wiki/", "harrypotterwiki"),
-    ("https://artofproblemsolving.com/wiki", "artofproblemsolving"),
+    # ("https://artofproblemsolving.com/wiki", "artofproblemsolving"),  # FLAKY
-    ("https://www.bogleheads.org/wiki/Main_Page", "bogleheadswiki"),
+    ("https://bogleheads.org/wiki/Main_Page", "bogleheadswiki"),
     ("https://www.dandwiki.com/wiki/", "dungeonsanddragons"),