Feature/kill indexing (#3213)

* checkpoint

* add celery termination of the task

* rename to RedisConnectorPermissionSyncPayload, add RedisLock to more places, add get_active_search_settings

* rename payload

* pretty sure these weren't named correctly

* testing in progress

* cleanup

* remove space

* merge fix

* three dots animation on Pausing

* improve messaging when connector is stopped or killed and animate buttons

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
This commit is contained in:
rkuo-danswer 2024-11-27 21:32:45 -08:00 committed by GitHub
parent 5be7d27285
commit 7f1e4a02bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 539 additions and 126 deletions

View File

@ -5,7 +5,6 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
@ -37,7 +36,7 @@ class TaskDependencyError(RuntimeError):
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@ -60,7 +59,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
self.app, cc_pair_id, db_session, lock_beat, tenant_id
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
@ -86,7 +85,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
app: Celery,
cc_pair_id: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:

View File

@ -8,6 +8,7 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.access.models import DocExternalAccess
from danswer.background.celery.apps.app_base import task_logger
@ -27,7 +28,7 @@ from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import doc_permission_sync_ctx
@ -138,7 +139,7 @@ def try_creating_permissions_sync_task(
LOCK_TIMEOUT = 30
lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
)
@ -162,7 +163,7 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
app.send_task(
result = app.send_task(
"connector_permission_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
@ -174,8 +175,8 @@ def try_creating_permissions_sync_task(
)
# set a basic fence to start
payload = RedisConnectorPermissionSyncData(
started=None,
payload = RedisConnectorPermissionSyncPayload(
started=None, celery_task_id=result.id
)
redis_connector.permissions.set_fence(payload)
@ -247,9 +248,11 @@ def connector_permission_sync_generator_task(
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
payload = RedisConnectorPermissionSyncData(
started=datetime.now(timezone.utc),
)
payload = redis_connector.permissions.payload
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
payload.started = datetime.now(timezone.utc)
redis_connector.permissions.set_fence(payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)

View File

@ -8,6 +8,7 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
@ -24,6 +25,9 @@ from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
@ -107,7 +111,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
cc_pair_ids_to_sync.append(cc_pair.id)
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_permissions_sync_task(
tasks_created = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
@ -125,7 +129,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()
def try_creating_permissions_sync_task(
def try_creating_external_group_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
@ -156,7 +160,7 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
_ = app.send_task(
result = app.send_task(
"connector_external_group_sync_generator_task",
kwargs=dict(
cc_pair_id=cc_pair_id,
@ -166,8 +170,13 @@ def try_creating_permissions_sync_task(
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
# set a basic fence to start
redis_connector.external_group_sync.set_fence(True)
payload = RedisConnectorExternalGroupSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=result.id,
)
redis_connector.external_group_sync.set_fence(payload)
except Exception:
task_logger.exception(
@ -203,7 +212,7 @@ def connector_external_group_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
lock: RedisLock = r.lock(
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
@ -253,7 +262,6 @@ def connector_external_group_sync_generator_task(
)
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
@ -264,6 +272,6 @@ def connector_external_group_sync_generator_task(
raise e
finally:
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
redis_connector.external_group_sync.set_fence(False)
redis_connector.external_group_sync.set_fence(None)
if lock.owned():
lock.release()

View File

@ -39,12 +39,13 @@ from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
@ -209,17 +210,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
for search_settings_instance in search_settings:
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
for search_settings_instance in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
@ -237,7 +231,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
)
search_settings_primary = False
if search_settings_instance.id == primary_search_settings.id:
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
if not _should_index(
@ -245,13 +239,13 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
last_index=last_attempt,
search_settings_instance=search_settings_instance,
search_settings_primary=search_settings_primary,
secondary_index_building=len(search_settings) > 1,
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
continue
reindex = False
if search_settings_instance.id == primary_search_settings.id:
if search_settings_instance.id == search_settings_list[0].id:
# the indexing trigger is only checked and cleared with the primary search settings
if cc_pair.indexing_trigger is not None:
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
@ -284,7 +278,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"Connector indexing queued: "
f"index_attempt={attempt_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"search_settings={search_settings_instance.id}"
)
tasks_created += 1
@ -529,8 +523,11 @@ def try_creating_indexing_task(
return index_attempt_id
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
@shared_task(
name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True
)
def connector_indexing_proxy_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
@ -543,6 +540,10 @@ def connector_indexing_proxy_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if not self.request.id:
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
job = client.submit(
@ -571,8 +572,30 @@ def connector_indexing_proxy_task(
f"search_settings={search_settings_id}"
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
while True:
sleep(10)
sleep(5)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing proxy - termination signal detected: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
job.cancel()
break
# do nothing for ongoing jobs that haven't been stopped
if not job.done():

View File

@ -46,6 +46,7 @@ from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
@ -58,7 +59,7 @@ from danswer.redis.redis_connector_credential_pair import RedisConnectorCredenti
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
@ -588,7 +589,7 @@ def monitor_ccpair_permissions_taskset(
if remaining > 0:
return
payload: RedisConnectorPermissionSyncData | None = (
payload: RedisConnectorPermissionSyncPayload | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None
@ -596,9 +597,7 @@ def monitor_ccpair_permissions_taskset(
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
redis_connector.permissions.taskset_clear()
redis_connector.permissions.generator_clear()
redis_connector.permissions.set_fence(None)
redis_connector.permissions.reset()
def monitor_ccpair_indexing_taskset(
@ -678,11 +677,15 @@ def monitor_ccpair_indexing_taskset(
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
redis_connector_index.reset()
return
@ -692,6 +695,7 @@ def monitor_ccpair_indexing_taskset(
task_logger.info(
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@ -724,7 +728,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r)
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
)

View File

@ -19,6 +19,7 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_succeeded
@ -87,6 +88,10 @@ def _get_connector_runner(
)
class ConnectorStopSignal(Exception):
    """A custom exception used to signal a stop in processing.

    Raised when a stop callback fires mid-run so callers can distinguish a
    deliberate stop from a genuine failure and mark the index attempt as
    canceled rather than failed.
    """
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
@ -208,9 +213,7 @@ def _run_indexing(
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise RuntimeError(
"_run_indexing: Connector stop signal detected"
)
raise ConnectorStopSignal("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
@ -304,26 +307,16 @@ def _run_indexing(
)
except Exception as e:
logger.exception(
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
)
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
if isinstance(e, ConnectorStopSignal):
mark_attempt_canceled(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
reason=str(e),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
@ -335,6 +328,37 @@ def _run_indexing(
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
else:
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure

View File

@ -143,6 +143,25 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
return latest_settings
def get_active_search_settings(db_session: Session) -> list[SearchSettings]:
    """Return the search settings currently in use.

    Index 0 is always the current (primary) search settings. If a migration
    to new search settings is underway, the secondary settings follow as
    index 1.
    """
    active: list[SearchSettings] = [get_current_search_settings(db_session)]

    secondary = get_secondary_search_settings(db_session)
    if secondary is not None:
        active.append(secondary)

    return active
def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
query = select(SearchSettings).order_by(SearchSettings.id.desc())
result = db_session.execute(query)

View File

@ -1,5 +1,8 @@
import time
import redis
from danswer.db.models import SearchSettings
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
@ -31,6 +34,44 @@ class RedisConnector:
self.tenant_id, self.id, search_settings_id, self.redis
)
def wait_for_indexing_termination(
    self,
    search_settings_list: list[SearchSettings],
    timeout: float = 15.0,
) -> bool:
    """Poll until no indexing fence is set for any of the given search settings.

    Returns True if all indexing for this connector finished within *timeout*
    seconds, False if the timeout was exceeded.

    NOTE: this check does not guarantee that an indexing run currently being
    terminated won't get restarted mid-flight.
    """
    deadline = time.monotonic() + timeout
    while True:
        # any fenced index means indexing is still active
        if not any(
            self.new_index(settings.id).fenced
            for settings in search_settings_list
        ):
            return True

        if time.monotonic() > deadline:
            return False

        time.sleep(1)
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""

View File

@ -14,8 +14,9 @@ from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
class RedisConnectorPermissionSyncData(BaseModel):
class RedisConnectorPermissionSyncPayload(BaseModel):
    # None until the permission sync generator actually starts running;
    # set to the start timestamp once it does
    started: datetime | None
    # id of the celery task driving the sync, if one was dispatched
    celery_task_id: str | None
class RedisConnectorPermissionSync:
@ -78,14 +79,14 @@ class RedisConnectorPermissionSync:
return False
@property
def payload(self) -> RedisConnectorPermissionSyncData | None:
def payload(self) -> RedisConnectorPermissionSyncPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorPermissionSyncData.model_validate_json(
payload = RedisConnectorPermissionSyncPayload.model_validate_json(
cast(str, fence_str)
)
@ -93,7 +94,7 @@ class RedisConnectorPermissionSync:
def set_fence(
self,
payload: RedisConnectorPermissionSyncData | None,
payload: RedisConnectorPermissionSyncPayload | None,
) -> None:
if not payload:
self.redis.delete(self.fence_key)
@ -162,6 +163,12 @@ class RedisConnectorPermissionSync:
return len(async_results)
def reset(self) -> None:
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
self.redis.delete(self.fence_key)
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}"

View File

@ -1,11 +1,18 @@
from datetime import datetime
from typing import cast
import redis
from celery import Celery
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
class RedisConnectorExternalGroupSyncPayload(BaseModel):
    # timestamp when the external group sync was kicked off; may be None
    started: datetime | None
    # id of the celery task performing the sync, if one was dispatched
    celery_task_id: str | None
class RedisConnectorExternalGroupSync:
"""Manages interactions with redis for external group syncing tasks. Should only be accessed
through RedisConnector."""
@ -68,12 +75,29 @@ class RedisConnectorExternalGroupSync:
return False
def set_fence(self, value: bool) -> None:
if not value:
@property
def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
    """Deserialize the fence value into a payload, or None when no fence is set."""
    raw = self.redis.get(self.fence_key)
    if raw is None:
        return None

    json_str = cast(bytes, raw).decode("utf-8")
    return RedisConnectorExternalGroupSyncPayload.model_validate_json(json_str)
def set_fence(
    self,
    payload: RedisConnectorExternalGroupSyncPayload | None,
) -> None:
    """Persist *payload* as the sync fence; a None payload clears the fence."""
    if payload:
        self.redis.set(self.fence_key, payload.model_dump_json())
    else:
        self.redis.delete(self.fence_key)
@property
def generator_complete(self) -> int | None:

View File

@ -29,6 +29,8 @@ class RedisConnectorIndex:
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
def __init__(
self,
tenant_id: str | None,
@ -51,6 +53,7 @@ class RedisConnectorIndex:
self.generator_lock_key = (
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
)
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@ -92,6 +95,18 @@ class RedisConnectorIndex:
self.redis.set(self.fence_key, payload.model_dump_json())
def terminating(self, celery_task_id: str) -> bool:
    """Return True if a termination signal has been posted for this celery task."""
    signal_key = f"{self.terminate_key}_{celery_task_id}"
    return bool(self.redis.exists(signal_key))
def set_terminate(self, celery_task_id: str) -> None:
    """Post a termination signal for the given celery task. Non-blocking.

    The key carries a 10 minute TTL; terminating the spawned task is expected
    to complete well within that window.
    """
    signal_key = f"{self.terminate_key}_{celery_task_id}"
    # the value is irrelevant — only the key's existence is checked
    self.redis.set(signal_key, 0, ex=600)
def set_generator_complete(self, payload: int | None) -> None:
if not payload:
self.redis.delete(self.generator_complete_key)

View File

@ -6,6 +6,7 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi.responses import JSONResponse
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@ -37,7 +38,9 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import count_index_attempts_for_connector
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from danswer.db.models import SearchSettings
from danswer.db.models import User
from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
@ -158,7 +161,19 @@ def update_cc_pair_status(
status_update_request: CCStatusUpdateRequest,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""This method may wait up to 30 seconds if pausing the connector due to the need to
terminate tasks in progress. Tasks are not guaranteed to terminate within the
timeout.
Returns HTTPStatus.OK if everything finished.
Returns HTTPStatus.ACCEPTED if the connector is being paused, but background tasks
did not finish within the timeout.
"""
WAIT_TIMEOUT = 15.0
still_terminating = False
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
@ -173,10 +188,76 @@ def update_cc_pair_status(
)
if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
cancel_indexing_attempts_past_model(db_session)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
redis_connector.stop.set_fence(True)
while True:
logger.debug(
f"Wait for indexing soft termination starting: cc_pair={cc_pair_id}"
)
wait_succeeded = redis_connector.wait_for_indexing_termination(
search_settings_list, WAIT_TIMEOUT
)
if wait_succeeded:
logger.debug(
f"Wait for indexing soft termination succeeded: cc_pair={cc_pair_id}"
)
break
logger.debug(
"Wait for indexing soft termination timed out. "
f"Moving to hard termination: cc_pair={cc_pair_id} timeout={WAIT_TIMEOUT:.2f}"
)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings.id
)
if not redis_connector_index.fenced:
continue
index_payload = redis_connector_index.payload
if not index_payload:
continue
if not index_payload.celery_task_id:
continue
# Revoke the task to prevent it from running
primary_app.control.revoke(index_payload.celery_task_id)
# If it is running, then signaling for termination will get the
# watchdog thread to kill the spawned task
redis_connector_index.set_terminate(index_payload.celery_task_id)
logger.debug(
f"Wait for indexing hard termination starting: cc_pair={cc_pair_id}"
)
wait_succeeded = redis_connector.wait_for_indexing_termination(
search_settings_list, WAIT_TIMEOUT
)
if wait_succeeded:
logger.debug(
f"Wait for indexing hard termination succeeded: cc_pair={cc_pair_id}"
)
break
logger.debug(
f"Wait for indexing hard termination timed out: cc_pair={cc_pair_id}"
)
still_terminating = True
break
finally:
redis_connector.stop.set_fence(False)
update_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@ -185,6 +266,18 @@ def update_cc_pair_status(
db_session.commit()
if still_terminating:
return JSONResponse(
status_code=HTTPStatus.ACCEPTED,
content={
"message": "Request accepted, background task termination still in progress"
},
)
return JSONResponse(
status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)}
)
@router.put("/admin/cc-pair/{cc_pair_id}/name")
def update_cc_pair_name(
@ -267,9 +360,9 @@ def prune_cc_pair(
)
logger.info(
f"Pruning cc_pair: cc_pair_id={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"Pruning cc_pair: cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
tasks_created = try_creating_prune_generator_task(

View File

@ -240,7 +240,85 @@ class CCPairManager:
result.raise_for_status()
@staticmethod
def wait_for_indexing(
def wait_for_indexing_inactive(
    cc_pair: DATestCCPair,
    timeout: float = MAX_DELAY,
    user_performing_action: DATestUser | None = None,
) -> None:
    """Wait until no indexing is in progress for the given cc_pair.

    This is used to test pausing a connector in the middle of indexing and
    terminating that indexing."""
    print(f"Indexing wait for inactive starting: cc_pair={cc_pair.id}")
    start = time.monotonic()
    while True:
        fetched_cc_pairs = CCPairManager.get_indexing_statuses(
            user_performing_action
        )
        for fetched_cc_pair in fetched_cc_pairs:
            if fetched_cc_pair.cc_pair_id != cc_pair.id:
                continue
            # still indexing — keep polling
            if fetched_cc_pair.in_progress:
                continue
            print(f"Indexing is inactive: cc_pair={cc_pair.id}")
            return

        elapsed = time.monotonic() - start
        if elapsed > timeout:
            raise TimeoutError(
                f"Indexing wait for inactive timed out: cc_pair={cc_pair.id} timeout={timeout}s"
            )

        print(
            f"Indexing wait for inactive still waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s"
        )
        time.sleep(5)
@staticmethod
def wait_for_indexing_in_progress(
    cc_pair: DATestCCPair,
    timeout: float = MAX_DELAY,
    num_docs: int = 16,
    user_performing_action: DATestUser | None = None,
) -> None:
    """Wait until an indexing attempt is in progress for the given cc_pair and
    at least *num_docs* docs have been indexed by it.

    This is used to test pausing a connector in the middle of indexing and
    terminating that indexing."""
    start = time.monotonic()
    while True:
        fetched_cc_pairs = CCPairManager.get_indexing_statuses(
            user_performing_action
        )
        for fetched_cc_pair in fetched_cc_pairs:
            if fetched_cc_pair.cc_pair_id != cc_pair.id:
                continue
            # only count progress from an active indexing attempt
            if not fetched_cc_pair.in_progress:
                continue
            if fetched_cc_pair.docs_indexed >= num_docs:
                print(
                    "Indexed at least the requested number of docs: "
                    f"cc_pair={cc_pair.id} "
                    f"docs_indexed={fetched_cc_pair.docs_indexed} "
                    f"num_docs={num_docs}"
                )
                return

        elapsed = time.monotonic() - start
        if elapsed > timeout:
            raise TimeoutError(
                f"Indexing in progress wait timed out: cc_pair={cc_pair.id} timeout={timeout}s"
            )

        print(
            f"Indexing in progress waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s"
        )
        time.sleep(5)
@staticmethod
def wait_for_indexing_completion(
cc_pair: DATestCCPair,
after: datetime,
timeout: float = MAX_DELAY,

View File

@ -78,7 +78,7 @@ def test_slack_permission_sync(
access_type=AccessType.SYNC,
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing(
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
@ -113,7 +113,7 @@ def test_slack_permission_sync(
# Run indexing
before = datetime.now(timezone.utc)
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.wait_for_indexing(
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
@ -305,7 +305,7 @@ def test_slack_group_permission_sync(
# Run indexing
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.wait_for_indexing(
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,

View File

@ -74,7 +74,7 @@ def test_slack_prune(
access_type=AccessType.SYNC,
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing(
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
@ -113,7 +113,7 @@ def test_slack_prune(
# Run indexing
before = datetime.now(timezone.utc)
CCPairManager.run_once(cc_pair, admin_user)
CCPairManager.wait_for_indexing(
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,

View File

@ -58,7 +58,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing(
CCPairManager.wait_for_indexing_completion(
cc_pair_1, now, timeout=120, user_performing_action=admin_user
)
@ -71,7 +71,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing(
CCPairManager.wait_for_indexing_completion(
cc_pair_2, now, timeout=120, user_performing_action=admin_user
)
@ -82,3 +82,48 @@ def test_overlapping_connector_creation(reset: None) -> None:
assert info_2
assert info_1.num_docs_indexed == info_2.num_docs_indexed
def test_connector_pause_while_indexing(reset: None) -> None:
    """Pause a connector while indexing is in progress and verify it stops.

    Flow: create a Confluence connector, wait until indexing is visibly in
    progress (at least 16 docs indexed), pause the cc_pair, then wait for
    indexing activity to go inactive — i.e. the in-flight indexing tasks
    ended early or aborted as a result of the pause.

    TODO: This does not specifically test for soft or hard termination code
    paths. Design specific tests for those use cases.

    Raises:
        KeyError: if the required CONFLUENCE_* environment variables are
            not set in the test environment.
        TimeoutError: if indexing never reaches the in-progress/inactive
            states within the configured timeouts.
    """
    admin_user: DATestUser = UserManager.create(name="admin_user")

    config = {
        "wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
        "space": "",
        "is_cloud": True,
        "page_id": "",
    }

    credential = {
        "confluence_username": os.environ["CONFLUENCE_USER_NAME"],
        "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
    }

    # create connector
    cc_pair_1 = CCPairManager.create_from_scratch(
        source=DocumentSource.CONFLUENCE,
        connector_specific_config=config,
        credential_json=credential,
        user_performing_action=admin_user,
    )

    # Wait until indexing is actually underway (num_docs acts as the
    # "in progress" threshold) so the pause exercises the mid-indexing
    # termination path rather than pausing before work has started.
    CCPairManager.wait_for_indexing_in_progress(
        cc_pair_1, timeout=60, num_docs=16, user_performing_action=admin_user
    )

    CCPairManager.pause_cc_pair(cc_pair_1, user_performing_action=admin_user)

    # Pausing should cause the in-flight indexing attempt to end early or
    # abort; wait for the cc_pair to report no indexing activity.
    CCPairManager.wait_for_indexing_inactive(
        cc_pair_1, timeout=60, user_performing_action=admin_user
    )

View File

@ -135,7 +135,7 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing(
CCPairManager.wait_for_indexing_completion(
cc_pair_1, now, timeout=60, user_performing_action=admin_user
)

View File

@ -161,7 +161,7 @@ export default function UpgradingPage({
reindexingProgress={sortedReindexingProgress}
/>
) : (
<ErrorCallout errorTitle="Failed to fetch re-indexing progress" />
<ErrorCallout errorTitle="Failed to fetch reindexing progress" />
)}
</>
) : (
@ -171,7 +171,7 @@ export default function UpgradingPage({
</h3>
<p className="mb-4 text-text-800">
You&apos;re currently switching embedding models, but there
are no connectors to re-index. This means the transition will
are no connectors to reindex. This means the transition will
be quick and seamless!
</p>
<p className="text-text-600">

View File

@ -6,6 +6,8 @@ import { usePopup } from "@/components/admin/connectors/Popup";
import { mutate } from "swr";
import { buildCCPairInfoUrl } from "./lib";
import { setCCPairStatus } from "@/lib/ccPair";
import { useState } from "react";
import { LoadingAnimation } from "@/components/Loading";
export function ModifyStatusButtonCluster({
ccPair,
@ -13,44 +15,72 @@ export function ModifyStatusButtonCluster({
ccPair: CCPairFullInfo;
}) {
const { popup, setPopup } = usePopup();
const [isUpdating, setIsUpdating] = useState(false);
const handleStatusChange = async (
newStatus: ConnectorCredentialPairStatus
) => {
if (isUpdating) return; // Prevent double-clicks or multiple requests
setIsUpdating(true);
try {
// Call the backend to update the status
await setCCPairStatus(ccPair.id, newStatus, setPopup);
// Use mutate to revalidate the status on the backend
await mutate(buildCCPairInfoUrl(ccPair.id));
} catch (error) {
console.error("Failed to update status", error);
} finally {
// Reset local updating state and button text after mutation
setIsUpdating(false);
}
};
// Compute the button text based on current state and backend status
const buttonText =
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "Re-Enable"
: "Pause";
const tooltip =
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "Click to start indexing again!"
: "When paused, the connector's documents will still be visible. However, no new documents will be indexed.";
return (
<>
{popup}
{ccPair.status === ConnectorCredentialPairStatus.PAUSED ? (
<Button
variant="success-reverse"
onClick={() =>
setCCPairStatus(
ccPair.id,
ConnectorCredentialPairStatus.ACTIVE,
setPopup,
() => mutate(buildCCPairInfoUrl(ccPair.id))
)
}
tooltip="Click to start indexing again!"
>
Re-Enable
</Button>
) : (
<Button
variant="default"
onClick={() =>
setCCPairStatus(
ccPair.id,
ConnectorCredentialPairStatus.PAUSED,
setPopup,
() => mutate(buildCCPairInfoUrl(ccPair.id))
)
}
tooltip={
"When paused, the connectors documents will still" +
" be visible. However, no new documents will be indexed."
}
>
Pause
</Button>
)}
<Button
className="flex items-center justify-center w-auto min-w-[100px] px-4 py-2"
variant={
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "success-reverse"
: "default"
}
disabled={isUpdating}
onClick={() =>
handleStatusChange(
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? ConnectorCredentialPairStatus.ACTIVE
: ConnectorCredentialPairStatus.PAUSED
)
}
tooltip={tooltip}
>
{isUpdating ? (
<LoadingAnimation
text={
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "Resuming"
: "Pausing"
}
size="text-md"
/>
) : (
buttonText
)}
</Button>
</>
);
}

View File

@ -121,7 +121,7 @@ export function ReIndexButton({
{popup}
<Button
variant="success-reverse"
className="ml-auto"
className="ml-auto min-w-[100px]"
onClick={() => {
setReIndexPopupVisible(true);
}}

View File

@ -25,6 +25,7 @@ import { ReIndexButton } from "./ReIndexButton";
import { buildCCPairInfoUrl } from "./lib";
import { CCPairFullInfo, ConnectorCredentialPairStatus } from "./types";
import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay";
import { Button } from "@/components/ui/button";
// since the uploaded files are cleaned up after some period of time
// re-indexing will not work for the file connector. Also, it would not