diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
index 9413dd978..caae8be30 100644
--- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
+++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py
@@ -5,7 +5,6 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
-from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
@@ -37,7 +36,7 @@ class TaskDependencyError(RuntimeError):
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
- lock_beat = r.lock(
+ lock_beat: RedisLock = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -60,7 +59,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
- self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
+ self.app, cc_pair_id, db_session, lock_beat, tenant_id
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
@@ -86,7 +85,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
app: Celery,
cc_pair_id: int,
db_session: Session,
- r: Redis,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
diff --git a/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py b/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py
index eef14e980..babf9b69b 100644
--- a/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py
+++ b/backend/danswer/background/celery/tasks/doc_permission_syncing/tasks.py
@@ -8,6 +8,7 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
+from redis.lock import Lock as RedisLock
from danswer.access.models import DocExternalAccess
from danswer.background.celery.apps.app_base import task_logger
@@ -27,7 +28,7 @@ from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_doc_perm_sync import (
- RedisConnectorPermissionSyncData,
+ RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import doc_permission_sync_ctx
@@ -138,7 +139,7 @@ def try_creating_permissions_sync_task(
LOCK_TIMEOUT = 30
- lock = r.lock(
+ lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
)
@@ -162,7 +163,7 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
- app.send_task(
+ result = app.send_task(
"connector_permission_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
@@ -174,8 +175,8 @@ def try_creating_permissions_sync_task(
)
# set a basic fence to start
- payload = RedisConnectorPermissionSyncData(
- started=None,
+ payload = RedisConnectorPermissionSyncPayload(
+ started=None, celery_task_id=result.id
)
redis_connector.permissions.set_fence(payload)
@@ -247,9 +248,11 @@ def connector_permission_sync_generator_task(
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
- payload = RedisConnectorPermissionSyncData(
- started=datetime.now(timezone.utc),
- )
+ payload = redis_connector.permissions.payload
+ if not payload:
+ raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
+
+ payload.started = datetime.now(timezone.utc)
redis_connector.permissions.set_fence(payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
diff --git a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py
index c3f0f6c6f..d80b2b518 100644
--- a/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py
+++ b/backend/danswer/background/celery/tasks/external_group_syncing/tasks.py
@@ -8,6 +8,7 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
+from redis.lock import Lock as RedisLock
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
@@ -24,6 +25,9 @@ from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
+from danswer.redis.redis_connector_ext_group_sync import (
+ RedisConnectorExternalGroupSyncPayload,
+)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
@@ -107,7 +111,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
cc_pair_ids_to_sync.append(cc_pair.id)
for cc_pair_id in cc_pair_ids_to_sync:
- tasks_created = try_creating_permissions_sync_task(
+ tasks_created = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
@@ -125,7 +129,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()
-def try_creating_permissions_sync_task(
+def try_creating_external_group_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
@@ -156,7 +160,7 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
- _ = app.send_task(
+ result = app.send_task(
"connector_external_group_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
@@ -166,8 +170,13 @@ def try_creating_permissions_sync_task(
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
- # set a basic fence to start
- redis_connector.external_group_sync.set_fence(True)
+
+ payload = RedisConnectorExternalGroupSyncPayload(
+ started=datetime.now(timezone.utc),
+ celery_task_id=result.id,
+ )
+
+ redis_connector.external_group_sync.set_fence(payload)
except Exception:
task_logger.exception(
@@ -203,7 +212,7 @@ def connector_external_group_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
- lock = r.lock(
+ lock: RedisLock = r.lock(
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
@@ -253,7 +262,6 @@ def connector_external_group_sync_generator_task(
)
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
-
except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
@@ -264,6 +272,6 @@ def connector_external_group_sync_generator_task(
raise e
finally:
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
- redis_connector.external_group_sync.set_fence(False)
+ redis_connector.external_group_sync.set_fence(None)
if lock.owned():
lock.release()
diff --git a/backend/danswer/background/celery/tasks/indexing/tasks.py b/backend/danswer/background/celery/tasks/indexing/tasks.py
index 9ebab40d0..4525b1e94 100644
--- a/backend/danswer/background/celery/tasks/indexing/tasks.py
+++ b/backend/danswer/background/celery/tasks/indexing/tasks.py
@@ -39,12 +39,13 @@ from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
+from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
+from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
-from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
@@ -209,17 +210,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
- # Get the primary search settings
- primary_search_settings = get_current_search_settings(db_session)
- search_settings = [primary_search_settings]
-
- # Check for secondary search settings
- secondary_search_settings = get_secondary_search_settings(db_session)
- if secondary_search_settings is not None:
- # If secondary settings exist, add them to the list
- search_settings.append(secondary_search_settings)
-
- for search_settings_instance in search_settings:
+ search_settings_list: list[SearchSettings] = get_active_search_settings(
+ db_session
+ )
+ for search_settings_instance in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
@@ -237,7 +231,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
)
search_settings_primary = False
- if search_settings_instance.id == primary_search_settings.id:
+ if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
if not _should_index(
@@ -245,13 +239,13 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
last_index=last_attempt,
search_settings_instance=search_settings_instance,
search_settings_primary=search_settings_primary,
- secondary_index_building=len(search_settings) > 1,
+ secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
continue
reindex = False
- if search_settings_instance.id == primary_search_settings.id:
+ if search_settings_instance.id == search_settings_list[0].id:
# the indexing trigger is only checked and cleared with the primary search settings
if cc_pair.indexing_trigger is not None:
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
@@ -284,7 +278,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"Connector indexing queued: "
f"index_attempt={attempt_id} "
f"cc_pair={cc_pair.id} "
- f"search_settings={search_settings_instance.id} "
+ f"search_settings={search_settings_instance.id}"
)
tasks_created += 1
@@ -529,8 +523,11 @@ def try_creating_indexing_task(
return index_attempt_id
-@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
+@shared_task(
+ name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True
+)
def connector_indexing_proxy_task(
+ self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
@@ -543,6 +540,10 @@ def connector_indexing_proxy_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
+
+ if not self.request.id:
+ task_logger.error("self.request.id is None!")
+
client = SimpleJobClient()
job = client.submit(
@@ -571,8 +572,30 @@ def connector_indexing_proxy_task(
f"search_settings={search_settings_id}"
)
+ redis_connector = RedisConnector(tenant_id, cc_pair_id)
+ redis_connector_index = redis_connector.new_index(search_settings_id)
+
while True:
- sleep(10)
+ sleep(5)
+
+ if self.request.id and redis_connector_index.terminating(self.request.id):
+ task_logger.warning(
+ "Indexing proxy - termination signal detected: "
+ f"attempt={index_attempt_id} "
+ f"tenant={tenant_id} "
+ f"cc_pair={cc_pair_id} "
+ f"search_settings={search_settings_id}"
+ )
+
+ with get_session_with_tenant(tenant_id) as db_session:
+ mark_attempt_canceled(
+ index_attempt_id,
+ db_session,
+ "Connector termination signal detected",
+ )
+
+ job.cancel()
+ break
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py
index ec7f52bc0..f491ff27b 100644
--- a/backend/danswer/background/celery/tasks/vespa/tasks.py
+++ b/backend/danswer/background/celery/tasks/vespa/tasks.py
@@ -46,6 +46,7 @@ from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
+from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
@@ -58,7 +59,7 @@ from danswer.redis.redis_connector_credential_pair import RedisConnectorCredenti
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_doc_perm_sync import (
- RedisConnectorPermissionSyncData,
+ RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
@@ -588,7 +589,7 @@ def monitor_ccpair_permissions_taskset(
if remaining > 0:
return
- payload: RedisConnectorPermissionSyncData | None = (
+ payload: RedisConnectorPermissionSyncPayload | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None
@@ -596,9 +597,7 @@ def monitor_ccpair_permissions_taskset(
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
- redis_connector.permissions.taskset_clear()
- redis_connector.permissions.generator_clear()
- redis_connector.permissions.set_fence(None)
+ redis_connector.permissions.reset()
def monitor_ccpair_indexing_taskset(
@@ -678,11 +677,15 @@ def monitor_ccpair_indexing_taskset(
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
- mark_attempt_failed(
- index_attempt_id=payload.index_attempt_id,
- db_session=db_session,
- failure_reason=msg,
- )
+ if (
+ index_attempt.status != IndexingStatus.CANCELED
+ and index_attempt.status != IndexingStatus.FAILED
+ ):
+ mark_attempt_failed(
+ index_attempt_id=payload.index_attempt_id,
+ db_session=db_session,
+ failure_reason=msg,
+ )
redis_connector_index.reset()
return
@@ -692,6 +695,7 @@ def monitor_ccpair_indexing_taskset(
task_logger.info(
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
+ f"progress={progress} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -724,7 +728,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
- n_celery = celery_get_queue_length("celery", r)
+ n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
)
diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py
index 699e4682c..40ed778f0 100644
--- a/backend/danswer/background/indexing/run_indexing.py
+++ b/backend/danswer/background/indexing/run_indexing.py
@@ -19,6 +19,7 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
+from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_succeeded
@@ -87,6 +88,10 @@ def _get_connector_runner(
)
+class ConnectorStopSignal(Exception):
+ """A custom exception used to signal a stop in processing."""
+
+
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
@@ -208,9 +213,7 @@ def _run_indexing(
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
- raise RuntimeError(
- "_run_indexing: Connector stop signal detected"
- )
+ raise ConnectorStopSignal("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
@@ -304,26 +307,16 @@ def _run_indexing(
)
except Exception as e:
logger.exception(
- f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
+ f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
)
- # Only mark the attempt as a complete failure if this is the first indexing window.
- # Otherwise, some progress was made - the next run will not start from the beginning.
- # In this case, it is not accurate to mark it as a failure. When the next run begins,
- # if that fails immediately, it will be marked as a failure.
- #
- # NOTE: if the connector is manually disabled, we should mark it as a failure regardless
- # to give better clarity in the UI, as the next run will never happen.
- if (
- ind == 0
- or not db_cc_pair.status.is_active()
- or index_attempt.status != IndexingStatus.IN_PROGRESS
- ):
- mark_attempt_failed(
+
+ if isinstance(e, ConnectorStopSignal):
+ mark_attempt_canceled(
index_attempt.id,
db_session,
- failure_reason=str(e),
- full_exception_trace=traceback.format_exc(),
+ reason=str(e),
)
+
if is_primary:
update_connector_credential_pair(
db_session=db_session,
@@ -335,6 +328,37 @@ def _run_indexing(
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
+ else:
+ # Only mark the attempt as a complete failure if this is the first indexing window.
+ # Otherwise, some progress was made - the next run will not start from the beginning.
+ # In this case, it is not accurate to mark it as a failure. When the next run begins,
+ # if that fails immediately, it will be marked as a failure.
+ #
+ # NOTE: if the connector is manually disabled, we should mark it as a failure regardless
+ # to give better clarity in the UI, as the next run will never happen.
+ if (
+ ind == 0
+ or not db_cc_pair.status.is_active()
+ or index_attempt.status != IndexingStatus.IN_PROGRESS
+ ):
+ mark_attempt_failed(
+ index_attempt.id,
+ db_session,
+ failure_reason=str(e),
+ full_exception_trace=traceback.format_exc(),
+ )
+
+ if is_primary:
+ update_connector_credential_pair(
+ db_session=db_session,
+ connector_id=db_connector.id,
+ credential_id=db_credential.id,
+ net_docs=net_doc_change,
+ )
+
+ if INDEXING_TRACER_INTERVAL > 0:
+ tracer.stop()
+ raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
diff --git a/backend/danswer/db/search_settings.py b/backend/danswer/db/search_settings.py
index 4f437eaae..1134b326a 100644
--- a/backend/danswer/db/search_settings.py
+++ b/backend/danswer/db/search_settings.py
@@ -143,6 +143,25 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
return latest_settings
+def get_active_search_settings(db_session: Session) -> list[SearchSettings]:
+ """Returns active search settings. The first entry will always be the current search
+ settings. If there are new search settings that are being migrated to, those will be
+ the second entry."""
+ search_settings_list: list[SearchSettings] = []
+
+ # Get the primary search settings
+ primary_search_settings = get_current_search_settings(db_session)
+ search_settings_list.append(primary_search_settings)
+
+ # Check for secondary search settings
+ secondary_search_settings = get_secondary_search_settings(db_session)
+ if secondary_search_settings is not None:
+ # If secondary settings exist, add them to the list
+ search_settings_list.append(secondary_search_settings)
+
+ return search_settings_list
+
+
def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
query = select(SearchSettings).order_by(SearchSettings.id.desc())
result = db_session.execute(query)
diff --git a/backend/danswer/redis/redis_connector.py b/backend/danswer/redis/redis_connector.py
index 8b52a2fd8..8d82fc119 100644
--- a/backend/danswer/redis/redis_connector.py
+++ b/backend/danswer/redis/redis_connector.py
@@ -1,5 +1,8 @@
+import time
+
import redis
+from danswer.db.models import SearchSettings
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
@@ -31,6 +34,44 @@ class RedisConnector:
self.tenant_id, self.id, search_settings_id, self.redis
)
+ def wait_for_indexing_termination(
+ self,
+ search_settings_list: list[SearchSettings],
+ timeout: float = 15.0,
+ ) -> bool:
+ """
+ Returns True if all indexing for the given redis connector is finished within the given timeout.
+        Returns False if the timeout is exceeded.
+
+        This check does not guarantee that indexing runs currently being
+        terminated won't be restarted mid-flight.
+ """
+
+ finished = False
+
+ start = time.monotonic()
+
+ while True:
+ still_indexing = False
+ for search_settings in search_settings_list:
+ redis_connector_index = self.new_index(search_settings.id)
+ if redis_connector_index.fenced:
+ still_indexing = True
+ break
+
+ if not still_indexing:
+ finished = True
+ break
+
+ now = time.monotonic()
+ if now - start > timeout:
+ break
+
+ time.sleep(1)
+ continue
+
+ return finished
+
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""
diff --git a/backend/danswer/redis/redis_connector_doc_perm_sync.py b/backend/danswer/redis/redis_connector_doc_perm_sync.py
index d9c3cd814..7b3748fcc 100644
--- a/backend/danswer/redis/redis_connector_doc_perm_sync.py
+++ b/backend/danswer/redis/redis_connector_doc_perm_sync.py
@@ -14,8 +14,9 @@ from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
-class RedisConnectorPermissionSyncData(BaseModel):
+class RedisConnectorPermissionSyncPayload(BaseModel):
started: datetime | None
+ celery_task_id: str | None
class RedisConnectorPermissionSync:
@@ -78,14 +79,14 @@ class RedisConnectorPermissionSync:
return False
@property
- def payload(self) -> RedisConnectorPermissionSyncData | None:
+ def payload(self) -> RedisConnectorPermissionSyncPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
- payload = RedisConnectorPermissionSyncData.model_validate_json(
+ payload = RedisConnectorPermissionSyncPayload.model_validate_json(
cast(str, fence_str)
)
@@ -93,7 +94,7 @@ class RedisConnectorPermissionSync:
def set_fence(
self,
- payload: RedisConnectorPermissionSyncData | None,
+ payload: RedisConnectorPermissionSyncPayload | None,
) -> None:
if not payload:
self.redis.delete(self.fence_key)
@@ -162,6 +163,12 @@ class RedisConnectorPermissionSync:
return len(async_results)
+ def reset(self) -> None:
+ self.redis.delete(self.generator_progress_key)
+ self.redis.delete(self.generator_complete_key)
+ self.redis.delete(self.taskset_key)
+ self.redis.delete(self.fence_key)
+
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}"
diff --git a/backend/danswer/redis/redis_connector_ext_group_sync.py b/backend/danswer/redis/redis_connector_ext_group_sync.py
index 631845648..bbe539c39 100644
--- a/backend/danswer/redis/redis_connector_ext_group_sync.py
+++ b/backend/danswer/redis/redis_connector_ext_group_sync.py
@@ -1,11 +1,18 @@
+from datetime import datetime
from typing import cast
import redis
from celery import Celery
+from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
+class RedisConnectorExternalGroupSyncPayload(BaseModel):
+ started: datetime | None
+ celery_task_id: str | None
+
+
class RedisConnectorExternalGroupSync:
"""Manages interactions with redis for external group syncing tasks. Should only be accessed
through RedisConnector."""
@@ -68,12 +75,29 @@ class RedisConnectorExternalGroupSync:
return False
- def set_fence(self, value: bool) -> None:
- if not value:
+ @property
+ def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
+ # read related data and evaluate/print task progress
+ fence_bytes = cast(bytes, self.redis.get(self.fence_key))
+ if fence_bytes is None:
+ return None
+
+ fence_str = fence_bytes.decode("utf-8")
+ payload = RedisConnectorExternalGroupSyncPayload.model_validate_json(
+ cast(str, fence_str)
+ )
+
+ return payload
+
+ def set_fence(
+ self,
+ payload: RedisConnectorExternalGroupSyncPayload | None,
+ ) -> None:
+ if not payload:
self.redis.delete(self.fence_key)
return
- self.redis.set(self.fence_key, 0)
+ self.redis.set(self.fence_key, payload.model_dump_json())
@property
def generator_complete(self) -> int | None:
diff --git a/backend/danswer/redis/redis_connector_index.py b/backend/danswer/redis/redis_connector_index.py
index 10fd3667f..40b194af0 100644
--- a/backend/danswer/redis/redis_connector_index.py
+++ b/backend/danswer/redis/redis_connector_index.py
@@ -29,6 +29,8 @@ class RedisConnectorIndex:
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
+ TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
+
def __init__(
self,
tenant_id: str | None,
@@ -51,6 +53,7 @@ class RedisConnectorIndex:
self.generator_lock_key = (
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
)
+ self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@@ -92,6 +95,18 @@ class RedisConnectorIndex:
self.redis.set(self.fence_key, payload.model_dump_json())
+ def terminating(self, celery_task_id: str) -> bool:
+ if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"):
+ return True
+
+ return False
+
+ def set_terminate(self, celery_task_id: str) -> None:
+ """This sets a signal. It does not block!"""
+ # We shouldn't need very long to terminate the spawned task.
+ # 10 minute TTL is good.
+ self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)
+
def set_generator_complete(self, payload: int | None) -> None:
if not payload:
self.redis.delete(self.generator_complete_key)
diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py
index 55808ebce..46bdb2078 100644
--- a/backend/danswer/server/documents/cc_pair.py
+++ b/backend/danswer/server/documents/cc_pair.py
@@ -6,6 +6,7 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
+from fastapi.responses import JSONResponse
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@@ -37,7 +38,9 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import count_index_attempts_for_connector
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
+from danswer.db.models import SearchSettings
from danswer.db.models import User
+from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
@@ -158,7 +161,19 @@ def update_cc_pair_status(
status_update_request: CCStatusUpdateRequest,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
-) -> None:
+ tenant_id: str | None = Depends(get_current_tenant_id),
+) -> JSONResponse:
+ """This method may wait up to 30 seconds if pausing the connector due to the need to
+ terminate tasks in progress. Tasks are not guaranteed to terminate within the
+ timeout.
+
+ Returns HTTPStatus.OK if everything finished.
+ Returns HTTPStatus.ACCEPTED if the connector is being paused, but background tasks
+ did not finish within the timeout.
+ """
+ WAIT_TIMEOUT = 15.0
+ still_terminating = False
+
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
@@ -173,10 +188,76 @@ def update_cc_pair_status(
)
if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
- cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
+ search_settings_list: list[SearchSettings] = get_active_search_settings(
+ db_session
+ )
+ cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
cancel_indexing_attempts_past_model(db_session)
+ redis_connector = RedisConnector(tenant_id, cc_pair_id)
+
+ try:
+ redis_connector.stop.set_fence(True)
+ while True:
+ logger.debug(
+ f"Wait for indexing soft termination starting: cc_pair={cc_pair_id}"
+ )
+ wait_succeeded = redis_connector.wait_for_indexing_termination(
+ search_settings_list, WAIT_TIMEOUT
+ )
+ if wait_succeeded:
+ logger.debug(
+ f"Wait for indexing soft termination succeeded: cc_pair={cc_pair_id}"
+ )
+ break
+
+ logger.debug(
+ "Wait for indexing soft termination timed out. "
+ f"Moving to hard termination: cc_pair={cc_pair_id} timeout={WAIT_TIMEOUT:.2f}"
+ )
+
+ for search_settings in search_settings_list:
+ redis_connector_index = redis_connector.new_index(
+ search_settings.id
+ )
+ if not redis_connector_index.fenced:
+ continue
+
+ index_payload = redis_connector_index.payload
+ if not index_payload:
+ continue
+
+ if not index_payload.celery_task_id:
+ continue
+
+ # Revoke the task to prevent it from running
+ primary_app.control.revoke(index_payload.celery_task_id)
+
+ # If it is running, then signaling for termination will get the
+ # watchdog thread to kill the spawned task
+ redis_connector_index.set_terminate(index_payload.celery_task_id)
+
+ logger.debug(
+ f"Wait for indexing hard termination starting: cc_pair={cc_pair_id}"
+ )
+ wait_succeeded = redis_connector.wait_for_indexing_termination(
+ search_settings_list, WAIT_TIMEOUT
+ )
+ if wait_succeeded:
+ logger.debug(
+ f"Wait for indexing hard termination succeeded: cc_pair={cc_pair_id}"
+ )
+ break
+
+ logger.debug(
+ f"Wait for indexing hard termination timed out: cc_pair={cc_pair_id}"
+ )
+ still_terminating = True
+ break
+ finally:
+ redis_connector.stop.set_fence(False)
+
update_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -185,6 +266,18 @@ def update_cc_pair_status(
db_session.commit()
+ if still_terminating:
+ return JSONResponse(
+ status_code=HTTPStatus.ACCEPTED,
+ content={
+ "message": "Request accepted, background task termination still in progress"
+ },
+ )
+
+ return JSONResponse(
+ status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)}
+ )
+
@router.put("/admin/cc-pair/{cc_pair_id}/name")
def update_cc_pair_name(
@@ -267,9 +360,9 @@ def prune_cc_pair(
)
logger.info(
- f"Pruning cc_pair: cc_pair_id={cc_pair_id} "
- f"connector_id={cc_pair.connector_id} "
- f"credential_id={cc_pair.credential_id} "
+ f"Pruning cc_pair: cc_pair={cc_pair_id} "
+ f"connector={cc_pair.connector_id} "
+ f"credential={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
tasks_created = try_creating_prune_generator_task(
diff --git a/backend/tests/integration/common_utils/managers/cc_pair.py b/backend/tests/integration/common_utils/managers/cc_pair.py
index b37822d34..d32e10056 100644
--- a/backend/tests/integration/common_utils/managers/cc_pair.py
+++ b/backend/tests/integration/common_utils/managers/cc_pair.py
@@ -240,7 +240,85 @@ class CCPairManager:
result.raise_for_status()
@staticmethod
- def wait_for_indexing(
+ def wait_for_indexing_inactive(
+ cc_pair: DATestCCPair,
+ timeout: float = MAX_DELAY,
+ user_performing_action: DATestUser | None = None,
+ ) -> None:
+ """wait for the number of docs to be indexed on the connector.
+ This is used to test pausing a connector in the middle of indexing and
+ terminating that indexing."""
+ print(f"Indexing wait for inactive starting: cc_pair={cc_pair.id}")
+ start = time.monotonic()
+ while True:
+ fetched_cc_pairs = CCPairManager.get_indexing_statuses(
+ user_performing_action
+ )
+ for fetched_cc_pair in fetched_cc_pairs:
+ if fetched_cc_pair.cc_pair_id != cc_pair.id:
+ continue
+
+ if fetched_cc_pair.in_progress:
+ continue
+
+ print(f"Indexing is inactive: cc_pair={cc_pair.id}")
+ return
+
+ elapsed = time.monotonic() - start
+ if elapsed > timeout:
+ raise TimeoutError(
+ f"Indexing wait for inactive timed out: cc_pair={cc_pair.id} timeout={timeout}s"
+ )
+
+ print(
+ f"Indexing wait for inactive still waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s"
+ )
+ time.sleep(5)
+
+ @staticmethod
+ def wait_for_indexing_in_progress(
+ cc_pair: DATestCCPair,
+ timeout: float = MAX_DELAY,
+ num_docs: int = 16,
+ user_performing_action: DATestUser | None = None,
+ ) -> None:
+ """wait for the number of docs to be indexed on the connector.
+ This is used to test pausing a connector in the middle of indexing and
+ terminating that indexing."""
+ start = time.monotonic()
+ while True:
+ fetched_cc_pairs = CCPairManager.get_indexing_statuses(
+ user_performing_action
+ )
+ for fetched_cc_pair in fetched_cc_pairs:
+ if fetched_cc_pair.cc_pair_id != cc_pair.id:
+ continue
+
+ if not fetched_cc_pair.in_progress:
+ continue
+
+ if fetched_cc_pair.docs_indexed >= num_docs:
+ print(
+ "Indexed at least the requested number of docs: "
+ f"cc_pair={cc_pair.id} "
+ f"docs_indexed={fetched_cc_pair.docs_indexed} "
+ f"num_docs={num_docs}"
+ )
+ return
+
+ elapsed = time.monotonic() - start
+ if elapsed > timeout:
+ raise TimeoutError(
+ f"Indexing in progress wait timed out: cc_pair={cc_pair.id} timeout={timeout}s"
+ )
+
+ print(
+ f"Indexing in progress waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s"
+ )
+ time.sleep(5)
+
+ @staticmethod
+ def wait_for_indexing_completion(
cc_pair: DATestCCPair,
after: datetime,
timeout: float = MAX_DELAY,
diff --git a/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py b/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py
index 3c3733254..8045501ce 100644
--- a/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py
+++ b/backend/tests/integration/connector_job_tests/slack/test_permission_sync.py
@@ -78,7 +78,7 @@ def test_slack_permission_sync(
access_type=AccessType.SYNC,
user_performing_action=admin_user,
)
- CCPairManager.wait_for_indexing(
+ CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
@@ -113,7 +113,7 @@ def test_slack_permission_sync(
# Run indexing
before = datetime.now(timezone.utc)
CCPairManager.run_once(cc_pair, admin_user)
- CCPairManager.wait_for_indexing(
+ CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
@@ -305,7 +305,7 @@ def test_slack_group_permission_sync(
# Run indexing
CCPairManager.run_once(cc_pair, admin_user)
- CCPairManager.wait_for_indexing(
+ CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
diff --git a/backend/tests/integration/connector_job_tests/slack/test_prune.py b/backend/tests/integration/connector_job_tests/slack/test_prune.py
index 2dfc3d075..b2decb658 100644
--- a/backend/tests/integration/connector_job_tests/slack/test_prune.py
+++ b/backend/tests/integration/connector_job_tests/slack/test_prune.py
@@ -74,7 +74,7 @@ def test_slack_prune(
access_type=AccessType.SYNC,
user_performing_action=admin_user,
)
- CCPairManager.wait_for_indexing(
+ CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
@@ -113,7 +113,7 @@ def test_slack_prune(
# Run indexing
before = datetime.now(timezone.utc)
CCPairManager.run_once(cc_pair, admin_user)
- CCPairManager.wait_for_indexing(
+ CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
diff --git a/backend/tests/integration/tests/connector/test_connector_creation.py b/backend/tests/integration/tests/connector/test_connector_creation.py
index acfafe943..61085c5a5 100644
--- a/backend/tests/integration/tests/connector/test_connector_creation.py
+++ b/backend/tests/integration/tests/connector/test_connector_creation.py
@@ -58,7 +58,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
user_performing_action=admin_user,
)
- CCPairManager.wait_for_indexing(
+ CCPairManager.wait_for_indexing_completion(
cc_pair_1, now, timeout=120, user_performing_action=admin_user
)
@@ -71,7 +71,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
user_performing_action=admin_user,
)
- CCPairManager.wait_for_indexing(
+ CCPairManager.wait_for_indexing_completion(
cc_pair_2, now, timeout=120, user_performing_action=admin_user
)
@@ -82,3 +82,48 @@ def test_overlapping_connector_creation(reset: None) -> None:
assert info_2
assert info_1.num_docs_indexed == info_2.num_docs_indexed
+
+
+def test_connector_pause_while_indexing(reset: None) -> None:
+ """Tests that we can pause a connector while indexing is in progress and that
+ tasks end early or abort as a result.
+
+ TODO: This does not specifically test for soft or hard termination code paths.
+ Design specific tests for those use cases.
+ """
+ admin_user: DATestUser = UserManager.create(name="admin_user")
+
+ config = {
+ "wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
+ "space": "",
+ "is_cloud": True,
+ "page_id": "",
+ }
+
+ credential = {
+ "confluence_username": os.environ["CONFLUENCE_USER_NAME"],
+ "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
+ }
+
+    # NOTE: this timestamp is evaluated but never stored or used; if the test
+    # later needs to assert on when indexing started, assign it to a variable
+    datetime.now(timezone.utc)
+
+ # create connector
+ cc_pair_1 = CCPairManager.create_from_scratch(
+ source=DocumentSource.CONFLUENCE,
+ connector_specific_config=config,
+ credential_json=credential,
+ user_performing_action=admin_user,
+ )
+
+ CCPairManager.wait_for_indexing_in_progress(
+ cc_pair_1, timeout=60, num_docs=16, user_performing_action=admin_user
+ )
+
+ CCPairManager.pause_cc_pair(cc_pair_1, user_performing_action=admin_user)
+
+ CCPairManager.wait_for_indexing_inactive(
+ cc_pair_1, timeout=60, user_performing_action=admin_user
+ )
+ return
diff --git a/backend/tests/integration/tests/pruning/test_pruning.py b/backend/tests/integration/tests/pruning/test_pruning.py
index 9d9a41c70..beb1e8efb 100644
--- a/backend/tests/integration/tests/pruning/test_pruning.py
+++ b/backend/tests/integration/tests/pruning/test_pruning.py
@@ -135,7 +135,7 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
user_performing_action=admin_user,
)
- CCPairManager.wait_for_indexing(
+ CCPairManager.wait_for_indexing_completion(
cc_pair_1, now, timeout=60, user_performing_action=admin_user
)
diff --git a/web/src/app/admin/configuration/search/UpgradingPage.tsx b/web/src/app/admin/configuration/search/UpgradingPage.tsx
index ecd7f8731..98653c4aa 100644
--- a/web/src/app/admin/configuration/search/UpgradingPage.tsx
+++ b/web/src/app/admin/configuration/search/UpgradingPage.tsx
@@ -161,7 +161,7 @@ export default function UpgradingPage({
reindexingProgress={sortedReindexingProgress}
/>
) : (
-
You're currently switching embedding models, but there - are no connectors to re-index. This means the transition will + are no connectors to reindex. This means the transition will be quick and seamless!
diff --git a/web/src/app/admin/connector/[ccPairId]/ModifyStatusButtonCluster.tsx b/web/src/app/admin/connector/[ccPairId]/ModifyStatusButtonCluster.tsx index 71d26a8eb..b5b4e7ecb 100644 --- a/web/src/app/admin/connector/[ccPairId]/ModifyStatusButtonCluster.tsx +++ b/web/src/app/admin/connector/[ccPairId]/ModifyStatusButtonCluster.tsx @@ -6,6 +6,8 @@ import { usePopup } from "@/components/admin/connectors/Popup"; import { mutate } from "swr"; import { buildCCPairInfoUrl } from "./lib"; import { setCCPairStatus } from "@/lib/ccPair"; +import { useState } from "react"; +import { LoadingAnimation } from "@/components/Loading"; export function ModifyStatusButtonCluster({ ccPair, @@ -13,44 +15,72 @@ export function ModifyStatusButtonCluster({ ccPair: CCPairFullInfo; }) { const { popup, setPopup } = usePopup(); + const [isUpdating, setIsUpdating] = useState(false); + + const handleStatusChange = async ( + newStatus: ConnectorCredentialPairStatus + ) => { + if (isUpdating) return; // Prevent double-clicks or multiple requests + setIsUpdating(true); + + try { + // Call the backend to update the status + await setCCPairStatus(ccPair.id, newStatus, setPopup); + + // Use mutate to revalidate the status on the backend + await mutate(buildCCPairInfoUrl(ccPair.id)); + } catch (error) { + console.error("Failed to update status", error); + } finally { + // Reset local updating state and button text after mutation + setIsUpdating(false); + } + }; + + // Compute the button text based on current state and backend status + const buttonText = + ccPair.status === ConnectorCredentialPairStatus.PAUSED + ? "Re-Enable" + : "Pause"; + + const tooltip = + ccPair.status === ConnectorCredentialPairStatus.PAUSED + ? "Click to start indexing again!" + : "When paused, the connector's documents will still be visible. However, no new documents will be indexed."; return ( <> {popup} - {ccPair.status === ConnectorCredentialPairStatus.PAUSED ? 
( - - ) : ( - - )} + > ); } diff --git a/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx b/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx index af0e2a8f4..962339e9f 100644 --- a/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx +++ b/web/src/app/admin/connector/[ccPairId]/ReIndexButton.tsx @@ -121,7 +121,7 @@ export function ReIndexButton({ {popup}