diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py index 9413dd978..caae8be30 100644 --- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py +++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py @@ -5,7 +5,6 @@ from celery import Celery from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded -from redis import Redis from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session @@ -37,7 +36,7 @@ class TaskDependencyError(RuntimeError): def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None: r = get_redis_client(tenant_id=tenant_id) - lock_beat = r.lock( + lock_beat: RedisLock = r.lock( DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK, timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, ) @@ -60,7 +59,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N redis_connector = RedisConnector(tenant_id, cc_pair_id) try: try_generate_document_cc_pair_cleanup_tasks( - self.app, cc_pair_id, db_session, r, lock_beat, tenant_id + self.app, cc_pair_id, db_session, lock_beat, tenant_id ) except TaskDependencyError as e: # this means we wanted to start deleting but dependent tasks were running @@ -86,7 +85,6 @@ def try_generate_document_cc_pair_cleanup_tasks( app: Celery, cc_pair_id: int, db_session: Session, - r: Redis, lock_beat: RedisLock, tenant_id: str | None, ) -> int | None: diff --git a/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py b/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py index eef14e980..babf9b69b 100644 --- a/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py +++ b/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py @@ -8,6 +8,7 @@ from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded from redis import Redis +from redis.lock import Lock as RedisLock from danswer.access.models import DocExternalAccess from danswer.background.celery.apps.app_base import task_logger @@ -27,7 +28,7 @@ from danswer.db.models import ConnectorCredentialPair from danswer.db.users import batch_add_ext_perm_user_if_not_exists from danswer.redis.redis_connector import RedisConnector from danswer.redis.redis_connector_doc_perm_sync import ( - RedisConnectorPermissionSyncData, + RedisConnectorPermissionSyncPayload, ) from danswer.redis.redis_pool import get_redis_client from danswer.utils.logger import doc_permission_sync_ctx @@ -138,7 +139,7 @@ def try_creating_permissions_sync_task( LOCK_TIMEOUT = 30 - lock = r.lock( + lock: RedisLock = r.lock( DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks", timeout=LOCK_TIMEOUT, ) @@ -162,7 +163,7 @@ def try_creating_permissions_sync_task( custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}" - app.send_task( + result = app.send_task( "connector_permission_sync_generator_task", kwargs=dict( cc_pair_id=cc_pair_id, @@ -174,8 +175,8 @@ def try_creating_permissions_sync_task( ) # set a basic fence to start - payload = RedisConnectorPermissionSyncData( - started=None, + payload = RedisConnectorPermissionSyncPayload( + started=None, celery_task_id=result.id ) redis_connector.permissions.set_fence(payload) @@ -247,9 +248,11 @@ def connector_permission_sync_generator_task( logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}") - payload = RedisConnectorPermissionSyncData( - started=datetime.now(timezone.utc), - ) + payload = redis_connector.permissions.payload + if not payload: + raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}") + + payload.started = datetime.now(timezone.utc) redis_connector.permissions.set_fence(payload) document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair) diff --git a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py index c3f0f6c6f..d80b2b518 100644 --- a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py +++ b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py @@ -8,6 +8,7 @@ from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded from redis import Redis +from redis.lock import Lock as RedisLock from danswer.background.celery.apps.app_base import task_logger from danswer.configs.app_configs import JOB_TIMEOUT @@ -24,6 +25,9 @@ from danswer.db.enums import AccessType from danswer.db.enums import ConnectorCredentialPairStatus from danswer.db.models import ConnectorCredentialPair from danswer.redis.redis_connector import RedisConnector +from danswer.redis.redis_connector_ext_group_sync import ( + RedisConnectorExternalGroupSyncPayload, +) from danswer.redis.redis_pool import get_redis_client from danswer.utils.logger import setup_logger from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs @@ -107,7 +111,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None: cc_pair_ids_to_sync.append(cc_pair.id) for cc_pair_id in cc_pair_ids_to_sync: - tasks_created = try_creating_permissions_sync_task( + tasks_created = try_creating_external_group_sync_task( self.app, cc_pair_id, r, tenant_id ) if not tasks_created: @@ -125,7 +129,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None: lock_beat.release() -def try_creating_permissions_sync_task( +def try_creating_external_group_sync_task( app: Celery, cc_pair_id: int, r: Redis, @@ -156,7 +160,7 @@ def try_creating_permissions_sync_task( custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}" - _ = app.send_task( + result = app.send_task( "connector_external_group_sync_generator_task", kwargs=dict( cc_pair_id=cc_pair_id, @@ -166,8 +170,13 @@ def try_creating_permissions_sync_task( task_id=custom_task_id, priority=DanswerCeleryPriority.HIGH, ) - # set a basic fence to start - redis_connector.external_group_sync.set_fence(True) + + payload = RedisConnectorExternalGroupSyncPayload( + started=datetime.now(timezone.utc), + celery_task_id=result.id, + ) + + redis_connector.external_group_sync.set_fence(payload) except Exception: task_logger.exception( @@ -203,7 +212,7 @@ def connector_external_group_sync_generator_task( r = get_redis_client(tenant_id=tenant_id) - lock = r.lock( + lock: RedisLock = r.lock( DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX + f"_{redis_connector.id}", timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT, @@ -253,7 +262,6 @@ def connector_external_group_sync_generator_task( ) mark_cc_pair_as_external_group_synced(db_session, cc_pair.id) - except Exception as e: task_logger.exception( f"Failed to run external group sync: cc_pair={cc_pair_id}" @@ -264,6 +272,6 @@ def connector_external_group_sync_generator_task( raise e finally: # we always want to clear the fence after the task is done or failed so it doesn't get stuck - redis_connector.external_group_sync.set_fence(False) + redis_connector.external_group_sync.set_fence(None) if lock.owned(): lock.release() diff --git a/backend/danswer/background/celery/tasks/indexing/tasks.py b/backend/danswer/background/celery/tasks/indexing/tasks.py index 9ebab40d0..4525b1e94 100644 --- a/backend/danswer/background/celery/tasks/indexing/tasks.py +++ b/backend/danswer/background/celery/tasks/indexing/tasks.py @@ -39,12 +39,13 @@ from danswer.db.index_attempt import delete_index_attempt from danswer.db.index_attempt import get_all_index_attempts_by_status from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import get_last_attempt_for_cc_pair +from danswer.db.index_attempt import mark_attempt_canceled from danswer.db.index_attempt import mark_attempt_failed from danswer.db.models import ConnectorCredentialPair from danswer.db.models import IndexAttempt from danswer.db.models import SearchSettings +from danswer.db.search_settings import get_active_search_settings from danswer.db.search_settings import get_current_search_settings -from danswer.db.search_settings import get_secondary_search_settings from danswer.db.swap_index import check_index_swap from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface from danswer.natural_language_processing.search_nlp_models import EmbeddingModel @@ -209,17 +210,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: redis_connector = RedisConnector(tenant_id, cc_pair_id) with get_session_with_tenant(tenant_id) as db_session: - # Get the primary search settings - primary_search_settings = get_current_search_settings(db_session) - search_settings = [primary_search_settings] - - # Check for secondary search settings - secondary_search_settings = get_secondary_search_settings(db_session) - if secondary_search_settings is not None: - # If secondary settings exist, add them to the list - search_settings.append(secondary_search_settings) - - for search_settings_instance in search_settings: + search_settings_list: list[SearchSettings] = get_active_search_settings( + db_session + ) + for search_settings_instance in search_settings_list: redis_connector_index = redis_connector.new_index( search_settings_instance.id ) @@ -237,7 +231,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: ) search_settings_primary = False - if search_settings_instance.id == primary_search_settings.id: + if search_settings_instance.id == search_settings_list[0].id: search_settings_primary = True if not _should_index( @@ -245,13 +239,13 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: last_index=last_attempt, search_settings_instance=search_settings_instance, search_settings_primary=search_settings_primary, - secondary_index_building=len(search_settings) > 1, + secondary_index_building=len(search_settings_list) > 1, db_session=db_session, ): continue reindex = False - if search_settings_instance.id == primary_search_settings.id: + if search_settings_instance.id == search_settings_list[0].id: # the indexing trigger is only checked and cleared with the primary search settings if cc_pair.indexing_trigger is not None: if cc_pair.indexing_trigger == IndexingMode.REINDEX: @@ -284,7 +278,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None: f"Connector indexing queued: " f"index_attempt={attempt_id} " f"cc_pair={cc_pair.id} " - f"search_settings={search_settings_instance.id} " + f"search_settings={search_settings_instance.id}" ) tasks_created += 1 @@ -529,8 +523,11 @@ def try_creating_indexing_task( return index_attempt_id -@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True) +@shared_task( + name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True +) def connector_indexing_proxy_task( + self: Task, index_attempt_id: int, cc_pair_id: int, search_settings_id: int, @@ -543,6 +540,10 @@ def connector_indexing_proxy_task( f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id}" ) + + if not self.request.id: + task_logger.error("self.request.id is None!") + client = SimpleJobClient() job = client.submit( @@ -571,8 +572,30 @@ def connector_indexing_proxy_task( f"search_settings={search_settings_id}" ) + redis_connector = RedisConnector(tenant_id, cc_pair_id) + redis_connector_index = redis_connector.new_index(search_settings_id) + while True: - sleep(10) + sleep(5) + + if self.request.id and redis_connector_index.terminating(self.request.id): + task_logger.warning( + "Indexing proxy - termination signal detected: " + f"attempt={index_attempt_id} " + f"tenant={tenant_id} " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings_id}" + ) + + with get_session_with_tenant(tenant_id) as db_session: + mark_attempt_canceled( + index_attempt_id, + db_session, + "Connector termination signal detected", + ) + + job.cancel() + break # do nothing for ongoing jobs that haven't been stopped if not job.done(): diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py index ec7f52bc0..f491ff27b 100644 --- a/backend/danswer/background/celery/tasks/vespa/tasks.py +++ b/backend/danswer/background/celery/tasks/vespa/tasks.py @@ -46,6 +46,7 @@ from danswer.db.document_set import fetch_document_sets_for_document from danswer.db.document_set import get_document_set_by_id from danswer.db.document_set import mark_document_set_as_synced from danswer.db.engine import get_session_with_tenant +from danswer.db.enums import IndexingStatus from danswer.db.index_attempt import delete_index_attempts from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import mark_attempt_failed @@ -58,7 +59,7 @@ from danswer.redis.redis_connector_credential_pair import RedisConnectorCredenti from danswer.redis.redis_connector_delete import RedisConnectorDelete from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync from danswer.redis.redis_connector_doc_perm_sync import ( - RedisConnectorPermissionSyncData, + RedisConnectorPermissionSyncPayload, ) from danswer.redis.redis_connector_index import RedisConnectorIndex from danswer.redis.redis_connector_prune import RedisConnectorPrune @@ -588,7 +589,7 @@ def monitor_ccpair_permissions_taskset( if remaining > 0: return - payload: RedisConnectorPermissionSyncData | None = ( + payload: RedisConnectorPermissionSyncPayload | None = ( redis_connector.permissions.payload ) start_time: datetime | None = payload.started if payload else None @@ -596,9 +597,7 @@ def monitor_ccpair_permissions_taskset( mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time) task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}") - redis_connector.permissions.taskset_clear() - redis_connector.permissions.generator_clear() - redis_connector.permissions.set_fence(None) + redis_connector.permissions.reset() def monitor_ccpair_indexing_taskset( @@ -678,11 +677,15 @@ def monitor_ccpair_indexing_taskset( index_attempt = get_index_attempt(db_session, payload.index_attempt_id) if index_attempt: - mark_attempt_failed( - index_attempt_id=payload.index_attempt_id, - db_session=db_session, - failure_reason=msg, - ) + if ( + index_attempt.status != IndexingStatus.CANCELED + and index_attempt.status != IndexingStatus.FAILED + ): + mark_attempt_failed( + index_attempt_id=payload.index_attempt_id, + db_session=db_session, + failure_reason=msg, + ) redis_connector_index.reset() return @@ -692,6 +695,7 @@ def monitor_ccpair_indexing_taskset( task_logger.info( f"Connector indexing finished: cc_pair={cc_pair_id} " f"search_settings={search_settings_id} " + f"progress={progress} " f"status={status_enum.name} " f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}" ) @@ -724,7 +728,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool: # print current queue lengths r_celery = self.app.broker_connection().channel().client # type: ignore - n_celery = celery_get_queue_length("celery", r) + n_celery = celery_get_queue_length("celery", r_celery) n_indexing = celery_get_queue_length( DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery ) diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index 699e4682c..40ed778f0 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -19,6 +19,7 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim from danswer.db.connector_credential_pair import update_connector_credential_pair from danswer.db.engine import get_session_with_tenant from danswer.db.enums import ConnectorCredentialPairStatus +from danswer.db.index_attempt import mark_attempt_canceled from danswer.db.index_attempt import mark_attempt_failed from danswer.db.index_attempt import mark_attempt_partially_succeeded from danswer.db.index_attempt import mark_attempt_succeeded @@ -87,6 +88,10 @@ def _get_connector_runner( ) +class ConnectorStopSignal(Exception): + """A custom exception used to signal a stop in processing.""" + + def _run_indexing( db_session: Session, index_attempt: IndexAttempt, @@ -208,9 +213,7 @@ def _run_indexing( # contents still need to be initially pulled. if callback: if callback.should_stop(): - raise RuntimeError( - "_run_indexing: Connector stop signal detected" - ) + raise ConnectorStopSignal("Connector stop signal detected") # TODO: should we move this into the above callback instead? db_session.refresh(db_cc_pair) @@ -304,26 +307,16 @@ def _run_indexing( ) except Exception as e: logger.exception( - f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds" + f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds" ) - # Only mark the attempt as a complete failure if this is the first indexing window. - # Otherwise, some progress was made - the next run will not start from the beginning. - # In this case, it is not accurate to mark it as a failure. When the next run begins, - # if that fails immediately, it will be marked as a failure. - # - # NOTE: if the connector is manually disabled, we should mark it as a failure regardless - # to give better clarity in the UI, as the next run will never happen. - if ( - ind == 0 - or not db_cc_pair.status.is_active() - or index_attempt.status != IndexingStatus.IN_PROGRESS - ): - mark_attempt_failed( + + if isinstance(e, ConnectorStopSignal): + mark_attempt_canceled( index_attempt.id, db_session, - failure_reason=str(e), - full_exception_trace=traceback.format_exc(), + reason=str(e), ) + if is_primary: update_connector_credential_pair( db_session=db_session, @@ -335,6 +328,37 @@ def _run_indexing( if INDEXING_TRACER_INTERVAL > 0: tracer.stop() raise e + else: + # Only mark the attempt as a complete failure if this is the first indexing window. + # Otherwise, some progress was made - the next run will not start from the beginning. + # In this case, it is not accurate to mark it as a failure. When the next run begins, + # if that fails immediately, it will be marked as a failure. + # + # NOTE: if the connector is manually disabled, we should mark it as a failure regardless + # to give better clarity in the UI, as the next run will never happen. + if ( + ind == 0 + or not db_cc_pair.status.is_active() + or index_attempt.status != IndexingStatus.IN_PROGRESS + ): + mark_attempt_failed( + index_attempt.id, + db_session, + failure_reason=str(e), + full_exception_trace=traceback.format_exc(), + ) + + if is_primary: + update_connector_credential_pair( + db_session=db_session, + connector_id=db_connector.id, + credential_id=db_credential.id, + net_docs=net_doc_change, + ) + + if INDEXING_TRACER_INTERVAL > 0: + tracer.stop() + raise e # break => similar to success case. As mentioned above, if the next run fails for the same # reason it will then be marked as a failure diff --git a/backend/danswer/db/search_settings.py b/backend/danswer/db/search_settings.py index 4f437eaae..1134b326a 100644 --- a/backend/danswer/db/search_settings.py +++ b/backend/danswer/db/search_settings.py @@ -143,6 +143,25 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None: return latest_settings +def get_active_search_settings(db_session: Session) -> list[SearchSettings]: + """Returns active search settings. The first entry will always be the current search + settings. If there are new search settings that are being migrated to, those will be + the second entry.""" + search_settings_list: list[SearchSettings] = [] + + # Get the primary search settings + primary_search_settings = get_current_search_settings(db_session) + search_settings_list.append(primary_search_settings) + + # Check for secondary search settings + secondary_search_settings = get_secondary_search_settings(db_session) + if secondary_search_settings is not None: + # If secondary settings exist, add them to the list + search_settings_list.append(secondary_search_settings) + + return search_settings_list + + def get_all_search_settings(db_session: Session) -> list[SearchSettings]: query = select(SearchSettings).order_by(SearchSettings.id.desc()) result = db_session.execute(query) diff --git a/backend/danswer/redis/redis_connector.py b/backend/danswer/redis/redis_connector.py index 8b52a2fd8..8d82fc119 100644 --- a/backend/danswer/redis/redis_connector.py +++ b/backend/danswer/redis/redis_connector.py @@ -1,5 +1,8 @@ +import time + import redis +from danswer.db.models import SearchSettings from danswer.redis.redis_connector_delete import RedisConnectorDelete from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync @@ -31,6 +34,44 @@ class RedisConnector: self.tenant_id, self.id, search_settings_id, self.redis ) + def wait_for_indexing_termination( + self, + search_settings_list: list[SearchSettings], + timeout: float = 15.0, + ) -> bool: + """ + Returns True if all indexing for the given redis connector is finished within the given timeout. + Returns False if the timeout is exceeded + + This check does not guarantee that current indexings being terminated + won't get restarted midflight + """ + + finished = False + + start = time.monotonic() + + while True: + still_indexing = False + for search_settings in search_settings_list: + redis_connector_index = self.new_index(search_settings.id) + if redis_connector_index.fenced: + still_indexing = True + break + + if not still_indexing: + finished = True + break + + now = time.monotonic() + if now - start > timeout: + break + + time.sleep(1) + continue + + return finished + @staticmethod def get_id_from_fence_key(key: str) -> str | None: """ diff --git a/backend/danswer/redis/redis_connector_doc_perm_sync.py b/backend/danswer/redis/redis_connector_doc_perm_sync.py index d9c3cd814..7b3748fcc 100644 --- a/backend/danswer/redis/redis_connector_doc_perm_sync.py +++ b/backend/danswer/redis/redis_connector_doc_perm_sync.py @@ -14,8 +14,9 @@ from danswer.configs.constants import DanswerCeleryPriority from danswer.configs.constants import DanswerCeleryQueues -class RedisConnectorPermissionSyncData(BaseModel): +class RedisConnectorPermissionSyncPayload(BaseModel): started: datetime | None + celery_task_id: str | None class RedisConnectorPermissionSync: @@ -78,14 +79,14 @@ class RedisConnectorPermissionSync: return False @property - def payload(self) -> RedisConnectorPermissionSyncData | None: + def payload(self) -> RedisConnectorPermissionSyncPayload | None: # read related data and evaluate/print task progress fence_bytes = cast(bytes, self.redis.get(self.fence_key)) if fence_bytes is None: return None fence_str = fence_bytes.decode("utf-8") - payload = RedisConnectorPermissionSyncData.model_validate_json( + payload = RedisConnectorPermissionSyncPayload.model_validate_json( cast(str, fence_str) ) @@ -93,7 +94,7 @@ class RedisConnectorPermissionSync: def set_fence( self, - payload: RedisConnectorPermissionSyncData | None, + payload: RedisConnectorPermissionSyncPayload | None, ) -> None: if not payload: self.redis.delete(self.fence_key) @@ -162,6 +163,12 @@ class RedisConnectorPermissionSync: return len(async_results) + def reset(self) -> None: + self.redis.delete(self.generator_progress_key) + self.redis.delete(self.generator_complete_key) + self.redis.delete(self.taskset_key) + self.redis.delete(self.fence_key) + @staticmethod def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None: taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}" diff --git a/backend/danswer/redis/redis_connector_ext_group_sync.py b/backend/danswer/redis/redis_connector_ext_group_sync.py index 631845648..bbe539c39 100644 --- a/backend/danswer/redis/redis_connector_ext_group_sync.py +++ b/backend/danswer/redis/redis_connector_ext_group_sync.py @@ -1,11 +1,18 @@ +from datetime import datetime from typing import cast import redis from celery import Celery +from pydantic import BaseModel from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session +class RedisConnectorExternalGroupSyncPayload(BaseModel): + started: datetime | None + celery_task_id: str | None + + class RedisConnectorExternalGroupSync: """Manages interactions with redis for external group syncing tasks. Should only be accessed through RedisConnector.""" @@ -68,12 +75,29 @@ class RedisConnectorExternalGroupSync: return False - def set_fence(self, value: bool) -> None: - if not value: + @property + def payload(self) -> RedisConnectorExternalGroupSyncPayload | None: + # read related data and evaluate/print task progress + fence_bytes = cast(bytes, self.redis.get(self.fence_key)) + if fence_bytes is None: + return None + + fence_str = fence_bytes.decode("utf-8") + payload = RedisConnectorExternalGroupSyncPayload.model_validate_json( + cast(str, fence_str) + ) + + return payload + + def set_fence( + self, + payload: RedisConnectorExternalGroupSyncPayload | None, + ) -> None: + if not payload: self.redis.delete(self.fence_key) return - self.redis.set(self.fence_key, 0) + self.redis.set(self.fence_key, payload.model_dump_json()) @property def generator_complete(self) -> int | None: diff --git a/backend/danswer/redis/redis_connector_index.py b/backend/danswer/redis/redis_connector_index.py index 10fd3667f..40b194af0 100644 --- a/backend/danswer/redis/redis_connector_index.py +++ b/backend/danswer/redis/redis_connector_index.py @@ -29,6 +29,8 @@ class RedisConnectorIndex: GENERATOR_LOCK_PREFIX = "da_lock:indexing" + TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate + def __init__( self, tenant_id: str | None, @@ -51,6 +53,7 @@ class RedisConnectorIndex: self.generator_lock_key = ( f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}" ) + self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}" @classmethod def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str: @@ -92,6 +95,18 @@ class RedisConnectorIndex: self.redis.set(self.fence_key, payload.model_dump_json()) + def terminating(self, celery_task_id: str) -> bool: + if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"): + return True + + return False + + def set_terminate(self, celery_task_id: str) -> None: + """This sets a signal. It does not block!""" + # We shouldn't need very long to terminate the spawned task. + # 10 minute TTL is good. + self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600) + def set_generator_complete(self, payload: int | None) -> None: if not payload: self.redis.delete(self.generator_complete_key) diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py index 55808ebce..46bdb2078 100644 --- a/backend/danswer/server/documents/cc_pair.py +++ b/backend/danswer/server/documents/cc_pair.py @@ -6,6 +6,7 @@ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi import Query +from fastapi.responses import JSONResponse from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -37,7 +38,9 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model from danswer.db.index_attempt import count_index_attempts_for_connector from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id +from danswer.db.models import SearchSettings from danswer.db.models import User +from danswer.db.search_settings import get_active_search_settings from danswer.db.search_settings import get_current_search_settings from danswer.redis.redis_connector import RedisConnector from danswer.redis.redis_pool import get_redis_client @@ -158,7 +161,19 @@ def update_cc_pair_status( status_update_request: CCStatusUpdateRequest, user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), -) -> None: + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + """This method may wait up to 30 seconds if pausing the connector due to the need to + terminate tasks in progress. Tasks are not guaranteed to terminate within the + timeout. + + Returns HTTPStatus.OK if everything finished. + Returns HTTPStatus.ACCEPTED if the connector is being paused, but background tasks + did not finish within the timeout. + """ + WAIT_TIMEOUT = 15.0 + still_terminating = False + cc_pair = get_connector_credential_pair_from_id( cc_pair_id=cc_pair_id, db_session=db_session, @@ -173,10 +188,76 @@ def update_cc_pair_status( ) if status_update_request.status == ConnectorCredentialPairStatus.PAUSED: - cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session) + search_settings_list: list[SearchSettings] = get_active_search_settings( + db_session + ) + cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session) cancel_indexing_attempts_past_model(db_session) + redis_connector = RedisConnector(tenant_id, cc_pair_id) + + try: + redis_connector.stop.set_fence(True) + while True: + logger.debug( + f"Wait for indexing soft termination starting: cc_pair={cc_pair_id}" + ) + wait_succeeded = redis_connector.wait_for_indexing_termination( + search_settings_list, WAIT_TIMEOUT + ) + if wait_succeeded: + logger.debug( + f"Wait for indexing soft termination succeeded: cc_pair={cc_pair_id}" + ) + break + + logger.debug( + "Wait for indexing soft termination timed out. " + f"Moving to hard termination: cc_pair={cc_pair_id} timeout={WAIT_TIMEOUT:.2f}" + ) + + for search_settings in search_settings_list: + redis_connector_index = redis_connector.new_index( + search_settings.id + ) + if not redis_connector_index.fenced: + continue + + index_payload = redis_connector_index.payload + if not index_payload: + continue + + if not index_payload.celery_task_id: + continue + + # Revoke the task to prevent it from running + primary_app.control.revoke(index_payload.celery_task_id) + + # If it is running, then signaling for termination will get the + # watchdog thread to kill the spawned task + redis_connector_index.set_terminate(index_payload.celery_task_id) + + logger.debug( + f"Wait for indexing hard termination starting: cc_pair={cc_pair_id}" + ) + wait_succeeded = redis_connector.wait_for_indexing_termination( + search_settings_list, WAIT_TIMEOUT + ) + if wait_succeeded: + logger.debug( + f"Wait for indexing hard termination succeeded: cc_pair={cc_pair_id}" + ) + break + + logger.debug( + f"Wait for indexing hard termination timed out: cc_pair={cc_pair_id}" + ) + still_terminating = True + break + finally: + redis_connector.stop.set_fence(False) + update_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=cc_pair_id, @@ -185,6 +266,18 @@ def update_cc_pair_status( db_session.commit() + if still_terminating: + return JSONResponse( + status_code=HTTPStatus.ACCEPTED, + content={ + "message": "Request accepted, background task termination still in progress" + }, + ) + + return JSONResponse( + status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)} + ) + @router.put("/admin/cc-pair/{cc_pair_id}/name") def update_cc_pair_name( @@ -267,9 +360,9 @@ def prune_cc_pair( ) logger.info( - f"Pruning cc_pair: cc_pair_id={cc_pair_id} " - f"connector_id={cc_pair.connector_id} " - f"credential_id={cc_pair.credential_id} " + f"Pruning cc_pair: cc_pair={cc_pair_id} " + f"connector={cc_pair.connector_id} " + f"credential={cc_pair.credential_id} " f"{cc_pair.connector.name} connector." ) tasks_created = try_creating_prune_generator_task( diff --git a/backend/tests/integration/common_utils/managers/cc_pair.py b/backend/tests/integration/common_utils/managers/cc_pair.py index b37822d34..d32e10056 100644 --- a/backend/tests/integration/common_utils/managers/cc_pair.py +++ b/backend/tests/integration/common_utils/managers/cc_pair.py @@ -240,7 +240,85 @@ class CCPairManager: result.raise_for_status() @staticmethod - def wait_for_indexing( + def wait_for_indexing_inactive( + cc_pair: DATestCCPair, + timeout: float = MAX_DELAY, + user_performing_action: DATestUser | None = None, + ) -> None: + """wait for the number of docs to be indexed on the connector. + This is used to test pausing a connector in the middle of indexing and + terminating that indexing.""" + print(f"Indexing wait for inactive starting: cc_pair={cc_pair.id}") + start = time.monotonic() + while True: + fetched_cc_pairs = CCPairManager.get_indexing_statuses( + user_performing_action + ) + for fetched_cc_pair in fetched_cc_pairs: + if fetched_cc_pair.cc_pair_id != cc_pair.id: + continue + + if fetched_cc_pair.in_progress: + continue + + print(f"Indexing is inactive: cc_pair={cc_pair.id}") + return + + elapsed = time.monotonic() - start + if elapsed > timeout: + raise TimeoutError( + f"Indexing wait for inactive timed out: cc_pair={cc_pair.id} timeout={timeout}s" + ) + + print( + f"Indexing wait for inactive still waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s" + ) + time.sleep(5) + + @staticmethod + def wait_for_indexing_in_progress( + cc_pair: DATestCCPair, + timeout: float = MAX_DELAY, + num_docs: int = 16, + user_performing_action: DATestUser | None = None, + ) -> None: + """wait for the number of docs to be indexed on the connector. + This is used to test pausing a connector in the middle of indexing and + terminating that indexing.""" + start = time.monotonic() + while True: + fetched_cc_pairs = CCPairManager.get_indexing_statuses( + user_performing_action + ) + for fetched_cc_pair in fetched_cc_pairs: + if fetched_cc_pair.cc_pair_id != cc_pair.id: + continue + + if not fetched_cc_pair.in_progress: + continue + + if fetched_cc_pair.docs_indexed >= num_docs: + print( + "Indexed at least the requested number of docs: " + f"cc_pair={cc_pair.id} " + f"docs_indexed={fetched_cc_pair.docs_indexed} " + f"num_docs={num_docs}" + ) + return + + elapsed = time.monotonic() - start + if elapsed > timeout: + raise TimeoutError( + f"Indexing in progress wait timed out: cc_pair={cc_pair.id} timeout={timeout}s" + ) + + print( + f"Indexing in progress waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s" + ) + time.sleep(5) + + @staticmethod + def wait_for_indexing_completion( cc_pair: DATestCCPair, after: datetime, timeout: float = MAX_DELAY, diff --git a/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py b/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py index 3c3733254..8045501ce 100644 --- a/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py +++ b/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py @@ -78,7 +78,7 @@ def test_slack_permission_sync( access_type=AccessType.SYNC, user_performing_action=admin_user, ) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, user_performing_action=admin_user, @@ -113,7 +113,7 @@ def test_slack_permission_sync( # Run indexing before = datetime.now(timezone.utc) CCPairManager.run_once(cc_pair, admin_user) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, user_performing_action=admin_user, @@ -305,7 +305,7 @@ def test_slack_group_permission_sync( # Run indexing CCPairManager.run_once(cc_pair, admin_user) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, user_performing_action=admin_user, diff --git a/backend/tests/integration/connector_job_tests/slack/test_prune.py b/backend/tests/integration/connector_job_tests/slack/test_prune.py index 2dfc3d075..b2decb658 100644 --- a/backend/tests/integration/connector_job_tests/slack/test_prune.py +++ b/backend/tests/integration/connector_job_tests/slack/test_prune.py @@ -74,7 +74,7 @@ def test_slack_prune( access_type=AccessType.SYNC, user_performing_action=admin_user, ) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, user_performing_action=admin_user, @@ -113,7 +113,7 @@ def test_slack_prune( # Run indexing before = datetime.now(timezone.utc) CCPairManager.run_once(cc_pair, admin_user) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair=cc_pair, after=before, user_performing_action=admin_user, diff --git a/backend/tests/integration/tests/connector/test_connector_creation.py b/backend/tests/integration/tests/connector/test_connector_creation.py index acfafe943..61085c5a5 100644 --- a/backend/tests/integration/tests/connector/test_connector_creation.py +++ b/backend/tests/integration/tests/connector/test_connector_creation.py @@ -58,7 +58,7 @@ def test_overlapping_connector_creation(reset: None) -> None: user_performing_action=admin_user, ) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair_1, now, timeout=120, user_performing_action=admin_user ) @@ -71,7 +71,7 @@ def test_overlapping_connector_creation(reset: None) -> None: user_performing_action=admin_user, ) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair_2, now, timeout=120, user_performing_action=admin_user ) @@ -82,3 +82,48 @@ def test_overlapping_connector_creation(reset: None) -> None: assert info_2 assert info_1.num_docs_indexed == info_2.num_docs_indexed + + +def test_connector_pause_while_indexing(reset: None) -> None: + """Tests that we can pause a connector while indexing is in progress and that + tasks end early or abort as a result. + + TODO: This does not specifically test for soft or hard termination code paths. + Design specific tests for those use cases. + """ + admin_user: DATestUser = UserManager.create(name="admin_user") + + config = { + "wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"], + "space": "", + "is_cloud": True, + "page_id": "", + } + + credential = { + "confluence_username": os.environ["CONFLUENCE_USER_NAME"], + "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"], + } + + # store the time before we create the connector so that we know after + # when the indexing should have started + datetime.now(timezone.utc) + + # create connector + cc_pair_1 = CCPairManager.create_from_scratch( + source=DocumentSource.CONFLUENCE, + connector_specific_config=config, + credential_json=credential, + user_performing_action=admin_user, + ) + + CCPairManager.wait_for_indexing_in_progress( + cc_pair_1, timeout=60, num_docs=16, user_performing_action=admin_user + ) + + CCPairManager.pause_cc_pair(cc_pair_1, user_performing_action=admin_user) + + CCPairManager.wait_for_indexing_inactive( + cc_pair_1, timeout=60, user_performing_action=admin_user + ) + return diff --git a/backend/tests/integration/tests/pruning/test_pruning.py b/backend/tests/integration/tests/pruning/test_pruning.py index 9d9a41c70..beb1e8efb 100644 --- a/backend/tests/integration/tests/pruning/test_pruning.py +++ b/backend/tests/integration/tests/pruning/test_pruning.py @@ -135,7 +135,7 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None: user_performing_action=admin_user, ) - CCPairManager.wait_for_indexing( + CCPairManager.wait_for_indexing_completion( cc_pair_1, now, timeout=60, user_performing_action=admin_user ) diff --git a/web/src/app/admin/configuration/search/UpgradingPage.tsx b/web/src/app/admin/configuration/search/UpgradingPage.tsx index ecd7f8731..98653c4aa 100644 --- a/web/src/app/admin/configuration/search/UpgradingPage.tsx +++ b/web/src/app/admin/configuration/search/UpgradingPage.tsx @@ -161,7 +161,7 @@ export default function UpgradingPage({ reindexingProgress={sortedReindexingProgress} /> ) : ( - + )} ) : ( @@ -171,7 +171,7 @@ export default function UpgradingPage({

You're currently switching embedding models, but there - are no connectors to re-index. This means the transition will + are no connectors to reindex. This means the transition will be quick and seamless!

diff --git a/web/src/app/admin/connector/[ccPairId]/ModifyStatusButtonCluster.tsx b/web/src/app/admin/connector/[ccPairId]/ModifyStatusButtonCluster.tsx index 71d26a8eb..b5b4e7ecb 100644 --- a/web/src/app/admin/connector/[ccPairId]/ModifyStatusButtonCluster.tsx +++ b/web/src/app/admin/connector/[ccPairId]/ModifyStatusButtonCluster.tsx @@ -6,6 +6,8 @@ import { usePopup } from "@/components/admin/connectors/Popup"; import { mutate } from "swr"; import { buildCCPairInfoUrl } from "./lib"; import { setCCPairStatus } from "@/lib/ccPair"; +import { useState } from "react"; +import { LoadingAnimation } from "@/components/Loading"; export function ModifyStatusButtonCluster({ ccPair, @@ -13,44 +15,72 @@ export function ModifyStatusButtonCluster({ ccPair: CCPairFullInfo; }) { const { popup, setPopup } = usePopup(); + const [isUpdating, setIsUpdating] = useState(false); + + const handleStatusChange = async ( + newStatus: ConnectorCredentialPairStatus + ) => { + if (isUpdating) return; // Prevent double-clicks or multiple requests + setIsUpdating(true); + + try { + // Call the backend to update the status + await setCCPairStatus(ccPair.id, newStatus, setPopup); + + // Use mutate to revalidate the status on the backend + await mutate(buildCCPairInfoUrl(ccPair.id)); + } catch (error) { + console.error("Failed to update status", error); + } finally { + // Reset local updating state and button text after mutation + setIsUpdating(false); + } + }; + + // Compute the button text based on current state and backend status + const buttonText = + ccPair.status === ConnectorCredentialPairStatus.PAUSED + ? "Re-Enable" + : "Pause"; + + const tooltip = + ccPair.status === ConnectorCredentialPairStatus.PAUSED + ? "Click to start indexing again!" + : "When paused, the connector's documents will still be visible. However, no new documents will be indexed."; return ( <> {popup} - {ccPair.status === ConnectorCredentialPairStatus.PAUSED ? ( - - ) : ( - - )} + ); } diff --git a/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx b/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx index af0e2a8f4..962339e9f 100644 --- a/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx +++ b/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx @@ -121,7 +121,7 @@ export function ReIndexButton({ {popup}