Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-07 19:38:19 +02:00)
Feature/kill indexing (#3213)
* checkpoint
* add celery termination of the task
* rename to RedisConnectorPermissionSyncPayload, add RedisLock to more places, add get_active_search_settings
* rename payload
* pretty sure these weren't named correctly
* testing in progress
* cleanup
* remove space
* merge fix
* three dots animation on Pausing
* improve messaging when connector is stopped or killed and animate buttons

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
This commit is contained in:
parent
5be7d27285
commit
7f1e4a02bf
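The core of the change is a two-stage kill: revoke the Celery task so it never starts if it is still queued, then set a short-lived Redis "terminate" key that a watchdog task polls. A minimal sketch of that handshake, with illustrative names rather than code from this commit:

import redis
from celery import Celery

r = redis.Redis()
app = Celery("example")

def request_termination(celery_task_id: str, terminate_key: str) -> None:
    # Stage 1: revoke, so the task never starts if it is still queued.
    app.control.revoke(celery_task_id)
    # Stage 2: set a short-lived signal key; the running watchdog polls it.
    # The 10 minute TTL mirrors set_terminate() in the diff below.
    r.set(f"{terminate_key}_{celery_task_id}", 0, ex=600)

def is_terminating(celery_task_id: str, terminate_key: str) -> bool:
    return bool(r.exists(f"{terminate_key}_{celery_task_id}"))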
@@ -5,7 +5,6 @@ from celery import Celery
 from celery import shared_task
 from celery import Task
 from celery.exceptions import SoftTimeLimitExceeded
-from redis import Redis
 from redis.lock import Lock as RedisLock
 from sqlalchemy.orm import Session

@@ -37,7 +36,7 @@ class TaskDependencyError(RuntimeError):
 def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
     r = get_redis_client(tenant_id=tenant_id)

-    lock_beat = r.lock(
+    lock_beat: RedisLock = r.lock(
         DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
         timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
     )
@@ -60,7 +59,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
             redis_connector = RedisConnector(tenant_id, cc_pair_id)
             try:
                 try_generate_document_cc_pair_cleanup_tasks(
-                    self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
+                    self.app, cc_pair_id, db_session, lock_beat, tenant_id
                 )
             except TaskDependencyError as e:
                 # this means we wanted to start deleting but dependent tasks were running
@@ -86,7 +85,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
     app: Celery,
     cc_pair_id: int,
     db_session: Session,
-    r: Redis,
     lock_beat: RedisLock,
     tenant_id: str | None,
 ) -> int | None:
@@ -8,6 +8,7 @@ from celery import shared_task
 from celery import Task
 from celery.exceptions import SoftTimeLimitExceeded
 from redis import Redis
+from redis.lock import Lock as RedisLock

 from danswer.access.models import DocExternalAccess
 from danswer.background.celery.apps.app_base import task_logger
@@ -27,7 +28,7 @@ from danswer.db.models import ConnectorCredentialPair
 from danswer.db.users import batch_add_ext_perm_user_if_not_exists
 from danswer.redis.redis_connector import RedisConnector
 from danswer.redis.redis_connector_doc_perm_sync import (
-    RedisConnectorPermissionSyncData,
+    RedisConnectorPermissionSyncPayload,
 )
 from danswer.redis.redis_pool import get_redis_client
 from danswer.utils.logger import doc_permission_sync_ctx
@@ -138,7 +139,7 @@ def try_creating_permissions_sync_task(
     LOCK_TIMEOUT = 30

-    lock = r.lock(
+    lock: RedisLock = r.lock(
         DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
         timeout=LOCK_TIMEOUT,
     )
@@ -162,7 +163,7 @@ def try_creating_permissions_sync_task(
     custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"

-    app.send_task(
+    result = app.send_task(
         "connector_permission_sync_generator_task",
         kwargs=dict(
             cc_pair_id=cc_pair_id,
@@ -174,8 +175,8 @@ def try_creating_permissions_sync_task(
     )

     # set a basic fence to start
-    payload = RedisConnectorPermissionSyncData(
-        started=None,
-    )
+    payload = RedisConnectorPermissionSyncPayload(
+        started=None, celery_task_id=result.id
+    )

     redis_connector.permissions.set_fence(payload)
@@ -247,9 +248,11 @@ def connector_permission_sync_generator_task(
         logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")

-        payload = RedisConnectorPermissionSyncData(
-            started=datetime.now(timezone.utc),
-        )
+        payload = redis_connector.permissions.payload
+        if not payload:
+            raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
+
+        payload.started = datetime.now(timezone.utc)
         redis_connector.permissions.set_fence(payload)

         document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
@@ -8,6 +8,7 @@ from celery import shared_task
 from celery import Task
 from celery.exceptions import SoftTimeLimitExceeded
 from redis import Redis
+from redis.lock import Lock as RedisLock

 from danswer.background.celery.apps.app_base import task_logger
 from danswer.configs.app_configs import JOB_TIMEOUT
@@ -24,6 +25,9 @@ from danswer.db.enums import AccessType
 from danswer.db.enums import ConnectorCredentialPairStatus
 from danswer.db.models import ConnectorCredentialPair
 from danswer.redis.redis_connector import RedisConnector
+from danswer.redis.redis_connector_ext_group_sync import (
+    RedisConnectorExternalGroupSyncPayload,
+)
 from danswer.redis.redis_pool import get_redis_client
 from danswer.utils.logger import setup_logger
 from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
@@ -107,7 +111,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
                 cc_pair_ids_to_sync.append(cc_pair.id)

         for cc_pair_id in cc_pair_ids_to_sync:
-            tasks_created = try_creating_permissions_sync_task(
+            tasks_created = try_creating_external_group_sync_task(
                 self.app, cc_pair_id, r, tenant_id
             )
             if not tasks_created:
@@ -125,7 +129,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
         lock_beat.release()


-def try_creating_permissions_sync_task(
+def try_creating_external_group_sync_task(
     app: Celery,
     cc_pair_id: int,
     r: Redis,
@@ -156,7 +160,7 @@ def try_creating_permissions_sync_task(
     custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"

-        _ = app.send_task(
+        result = app.send_task(
            "connector_external_group_sync_generator_task",
            kwargs=dict(
                cc_pair_id=cc_pair_id,
@@ -166,8 +170,13 @@
            task_id=custom_task_id,
            priority=DanswerCeleryPriority.HIGH,
        )
-        # set a basic fence to start
-        redis_connector.external_group_sync.set_fence(True)
+
+        payload = RedisConnectorExternalGroupSyncPayload(
+            started=datetime.now(timezone.utc),
+            celery_task_id=result.id,
+        )
+
+        redis_connector.external_group_sync.set_fence(payload)

    except Exception:
        task_logger.exception(
@@ -203,7 +212,7 @@ def connector_external_group_sync_generator_task(
     r = get_redis_client(tenant_id=tenant_id)

-    lock = r.lock(
+    lock: RedisLock = r.lock(
         DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
         + f"_{redis_connector.id}",
         timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
@@ -253,7 +262,6 @@ def connector_external_group_sync_generator_task(
             )

             mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
-
     except Exception as e:
         task_logger.exception(
             f"Failed to run external group sync: cc_pair={cc_pair_id}"
@@ -264,6 +272,6 @@ def connector_external_group_sync_generator_task(
         raise e
     finally:
         # we always want to clear the fence after the task is done or failed so it doesn't get stuck
-        redis_connector.external_group_sync.set_fence(False)
+        redis_connector.external_group_sync.set_fence(None)
         if lock.owned():
             lock.release()
@@ -39,12 +39,13 @@ from danswer.db.index_attempt import delete_index_attempt
 from danswer.db.index_attempt import get_all_index_attempts_by_status
 from danswer.db.index_attempt import get_index_attempt
 from danswer.db.index_attempt import get_last_attempt_for_cc_pair
+from danswer.db.index_attempt import mark_attempt_canceled
 from danswer.db.index_attempt import mark_attempt_failed
 from danswer.db.models import ConnectorCredentialPair
 from danswer.db.models import IndexAttempt
 from danswer.db.models import SearchSettings
+from danswer.db.search_settings import get_active_search_settings
 from danswer.db.search_settings import get_current_search_settings
 from danswer.db.search_settings import get_secondary_search_settings
 from danswer.db.swap_index import check_index_swap
 from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
@@ -209,17 +210,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
             redis_connector = RedisConnector(tenant_id, cc_pair_id)
             with get_session_with_tenant(tenant_id) as db_session:
-                # Get the primary search settings
-                primary_search_settings = get_current_search_settings(db_session)
-                search_settings = [primary_search_settings]
-
-                # Check for secondary search settings
-                secondary_search_settings = get_secondary_search_settings(db_session)
-                if secondary_search_settings is not None:
-                    # If secondary settings exist, add them to the list
-                    search_settings.append(secondary_search_settings)
-
-                for search_settings_instance in search_settings:
+                search_settings_list: list[SearchSettings] = get_active_search_settings(
+                    db_session
+                )
+                for search_settings_instance in search_settings_list:
                     redis_connector_index = redis_connector.new_index(
                         search_settings_instance.id
                     )
@@ -237,7 +231,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                     )

                     search_settings_primary = False
-                    if search_settings_instance.id == primary_search_settings.id:
+                    if search_settings_instance.id == search_settings_list[0].id:
                         search_settings_primary = True

                     if not _should_index(
@@ -245,13 +239,13 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                         last_index=last_attempt,
                         search_settings_instance=search_settings_instance,
                         search_settings_primary=search_settings_primary,
-                        secondary_index_building=len(search_settings) > 1,
+                        secondary_index_building=len(search_settings_list) > 1,
                         db_session=db_session,
                     ):
                         continue

                     reindex = False
-                    if search_settings_instance.id == primary_search_settings.id:
+                    if search_settings_instance.id == search_settings_list[0].id:
                         # the indexing trigger is only checked and cleared with the primary search settings
                         if cc_pair.indexing_trigger is not None:
                             if cc_pair.indexing_trigger == IndexingMode.REINDEX:
@@ -284,7 +278,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                             f"Connector indexing queued: "
                             f"index_attempt={attempt_id} "
                             f"cc_pair={cc_pair.id} "
-                            f"search_settings={search_settings_instance.id} "
+                            f"search_settings={search_settings_instance.id}"
                         )
                         tasks_created += 1

@@ -529,8 +523,11 @@ def try_creating_indexing_task(
     return index_attempt_id


-@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
+@shared_task(
+    name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True
+)
 def connector_indexing_proxy_task(
+    self: Task,
     index_attempt_id: int,
     cc_pair_id: int,
     search_settings_id: int,
@@ -543,6 +540,10 @@ def connector_indexing_proxy_task(
         f"cc_pair={cc_pair_id} "
         f"search_settings={search_settings_id}"
     )
+
+    if not self.request.id:
+        task_logger.error("self.request.id is None!")
+
     client = SimpleJobClient()

     job = client.submit(
@@ -571,8 +572,30 @@ def connector_indexing_proxy_task(
         f"search_settings={search_settings_id}"
     )

+    redis_connector = RedisConnector(tenant_id, cc_pair_id)
+    redis_connector_index = redis_connector.new_index(search_settings_id)
+
     while True:
-        sleep(10)
+        sleep(5)
+
+        if self.request.id and redis_connector_index.terminating(self.request.id):
+            task_logger.warning(
+                "Indexing proxy - termination signal detected: "
+                f"attempt={index_attempt_id} "
+                f"tenant={tenant_id} "
+                f"cc_pair={cc_pair_id} "
+                f"search_settings={search_settings_id}"
+            )
+
+            with get_session_with_tenant(tenant_id) as db_session:
+                mark_attempt_canceled(
+                    index_attempt_id,
+                    db_session,
+                    "Connector termination signal detected",
+                )
+
+            job.cancel()
+            break

         # do nothing for ongoing jobs that haven't been stopped
         if not job.done():
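The proxy task above is the watchdog half of the handshake: it polls both the terminate signal and the spawned job. A condensed sketch of that loop, assuming a `job` object exposing `done()`/`cancel()` like the SimpleJobClient jobs used here:

import time

def watchdog_loop(job, redis_connector_index, celery_task_id: str) -> None:
    while True:
        time.sleep(5)

        # Cooperative kill: someone set the terminate key for this task id.
        if celery_task_id and redis_connector_index.terminating(celery_task_id):
            job.cancel()  # hard-stop the spawned indexing process
            break

        # Normal exit: the spawned job finished on its own.
        if job.done():
            break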
@@ -46,6 +46,7 @@ from danswer.db.document_set import fetch_document_sets_for_document
 from danswer.db.document_set import get_document_set_by_id
 from danswer.db.document_set import mark_document_set_as_synced
 from danswer.db.engine import get_session_with_tenant
+from danswer.db.enums import IndexingStatus
 from danswer.db.index_attempt import delete_index_attempts
 from danswer.db.index_attempt import get_index_attempt
 from danswer.db.index_attempt import mark_attempt_failed
@@ -58,7 +59,7 @@ from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
 from danswer.redis.redis_connector_delete import RedisConnectorDelete
 from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
 from danswer.redis.redis_connector_doc_perm_sync import (
-    RedisConnectorPermissionSyncData,
+    RedisConnectorPermissionSyncPayload,
 )
 from danswer.redis.redis_connector_index import RedisConnectorIndex
 from danswer.redis.redis_connector_prune import RedisConnectorPrune
@@ -588,7 +589,7 @@ def monitor_ccpair_permissions_taskset(
     if remaining > 0:
         return

-    payload: RedisConnectorPermissionSyncData | None = (
+    payload: RedisConnectorPermissionSyncPayload | None = (
         redis_connector.permissions.payload
     )
     start_time: datetime | None = payload.started if payload else None
@@ -596,9 +597,7 @@ def monitor_ccpair_permissions_taskset(
     mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
     task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")

-    redis_connector.permissions.taskset_clear()
-    redis_connector.permissions.generator_clear()
-    redis_connector.permissions.set_fence(None)
+    redis_connector.permissions.reset()


 def monitor_ccpair_indexing_taskset(
@@ -678,11 +677,15 @@ def monitor_ccpair_indexing_taskset(
             index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
             if index_attempt:
-                mark_attempt_failed(
-                    index_attempt_id=payload.index_attempt_id,
-                    db_session=db_session,
-                    failure_reason=msg,
-                )
+                if (
+                    index_attempt.status != IndexingStatus.CANCELED
+                    and index_attempt.status != IndexingStatus.FAILED
+                ):
+                    mark_attempt_failed(
+                        index_attempt_id=payload.index_attempt_id,
+                        db_session=db_session,
+                        failure_reason=msg,
+                    )

         redis_connector_index.reset()
         return
@@ -692,6 +695,7 @@ def monitor_ccpair_indexing_taskset(
     task_logger.info(
         f"Connector indexing finished: cc_pair={cc_pair_id} "
         f"search_settings={search_settings_id} "
+        f"progress={progress} "
         f"status={status_enum.name} "
         f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
     )
@@ -724,7 +728,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
     # print current queue lengths
     r_celery = self.app.broker_connection().channel().client  # type: ignore
-    n_celery = celery_get_queue_length("celery", r)
+    n_celery = celery_get_queue_length("celery", r_celery)
     n_indexing = celery_get_queue_length(
         DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
     )
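The last hunk above is a bug fix: the queue length for "celery" was being read from the tenant Redis client `r` instead of the broker connection's client `r_celery`. With Celery's Redis broker, a queue is (to a first approximation) a Redis list, so a helper along these lines is plausible; this is a sketch under that assumption, not the repo's implementation, and it ignores the extra suffixed keys that priority queues create:

import redis

def get_queue_length(queue_name: str, r_broker: redis.Redis) -> int:
    # Assumption: default Redis-broker settings, where each Celery queue
    # is backed by a Redis list of the same name.
    return int(r_broker.llen(queue_name))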
@@ -19,6 +19,7 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_time
 from danswer.db.connector_credential_pair import update_connector_credential_pair
 from danswer.db.engine import get_session_with_tenant
 from danswer.db.enums import ConnectorCredentialPairStatus
+from danswer.db.index_attempt import mark_attempt_canceled
 from danswer.db.index_attempt import mark_attempt_failed
 from danswer.db.index_attempt import mark_attempt_partially_succeeded
 from danswer.db.index_attempt import mark_attempt_succeeded
@@ -87,6 +88,10 @@ def _get_connector_runner(
     )


+class ConnectorStopSignal(Exception):
+    """A custom exception used to signal a stop in processing."""
+
+
 def _run_indexing(
     db_session: Session,
     index_attempt: IndexAttempt,
@@ -208,9 +213,7 @@ def _run_indexing(
             # contents still need to be initially pulled.
             if callback:
                 if callback.should_stop():
-                    raise RuntimeError(
-                        "_run_indexing: Connector stop signal detected"
-                    )
+                    raise ConnectorStopSignal("Connector stop signal detected")

             # TODO: should we move this into the above callback instead?
             db_session.refresh(db_cc_pair)
@@ -304,26 +307,16 @@ def _run_indexing(
         )
     except Exception as e:
         logger.exception(
-            f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
+            f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
         )
-        # Only mark the attempt as a complete failure if this is the first indexing window.
-        # Otherwise, some progress was made - the next run will not start from the beginning.
-        # In this case, it is not accurate to mark it as a failure. When the next run begins,
-        # if that fails immediately, it will be marked as a failure.
-        #
-        # NOTE: if the connector is manually disabled, we should mark it as a failure regardless
-        # to give better clarity in the UI, as the next run will never happen.
-        if (
-            ind == 0
-            or not db_cc_pair.status.is_active()
-            or index_attempt.status != IndexingStatus.IN_PROGRESS
-        ):
-            mark_attempt_failed(
+
+        if isinstance(e, ConnectorStopSignal):
+            mark_attempt_canceled(
                 index_attempt.id,
                 db_session,
-                failure_reason=str(e),
-                full_exception_trace=traceback.format_exc(),
+                reason=str(e),
             )

             if is_primary:
                 update_connector_credential_pair(
                     db_session=db_session,
@@ -335,6 +328,37 @@ def _run_indexing(
             if INDEXING_TRACER_INTERVAL > 0:
                 tracer.stop()
             raise e
+        else:
+            # Only mark the attempt as a complete failure if this is the first indexing window.
+            # Otherwise, some progress was made - the next run will not start from the beginning.
+            # In this case, it is not accurate to mark it as a failure. When the next run begins,
+            # if that fails immediately, it will be marked as a failure.
+            #
+            # NOTE: if the connector is manually disabled, we should mark it as a failure regardless
+            # to give better clarity in the UI, as the next run will never happen.
+            if (
+                ind == 0
+                or not db_cc_pair.status.is_active()
+                or index_attempt.status != IndexingStatus.IN_PROGRESS
+            ):
+                mark_attempt_failed(
+                    index_attempt.id,
+                    db_session,
+                    failure_reason=str(e),
+                    full_exception_trace=traceback.format_exc(),
+                )
+
+                if is_primary:
+                    update_connector_credential_pair(
+                        db_session=db_session,
+                        connector_id=db_connector.id,
+                        credential_id=db_credential.id,
+                        net_docs=net_doc_change,
+                    )
+
+                if INDEXING_TRACER_INTERVAL > 0:
+                    tracer.stop()
+                raise e

         # break => similar to success case. As mentioned above, if the next run fails for the same
         # reason it will then be marked as a failure
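Raising a dedicated `ConnectorStopSignal` instead of a bare `RuntimeError` lets the except block distinguish a requested stop (recorded as canceled) from a real failure (recorded as failed). A condensed sketch of that control flow, with the mark_* callables standing in for the db helpers:

class ConnectorStopSignal(Exception):
    """A custom exception used to signal a stop in processing."""

def run_window(callback, mark_canceled, mark_failed) -> None:
    try:
        if callback and callback.should_stop():
            raise ConnectorStopSignal("Connector stop signal detected")
        # ... normal indexing work for this window ...
    except Exception as e:
        if isinstance(e, ConnectorStopSignal):
            mark_canceled(reason=str(e))  # a requested stop, not a failure
        else:
            mark_failed(failure_reason=str(e))
        raise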
@@ -143,6 +143,25 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
     return latest_settings


+def get_active_search_settings(db_session: Session) -> list[SearchSettings]:
+    """Returns active search settings. The first entry will always be the current search
+    settings. If there are new search settings that are being migrated to, those will be
+    the second entry."""
+    search_settings_list: list[SearchSettings] = []
+
+    # Get the primary search settings
+    primary_search_settings = get_current_search_settings(db_session)
+    search_settings_list.append(primary_search_settings)
+
+    # Check for secondary search settings
+    secondary_search_settings = get_secondary_search_settings(db_session)
+    if secondary_search_settings is not None:
+        # If secondary settings exist, add them to the list
+        search_settings_list.append(secondary_search_settings)
+
+    return search_settings_list
+
+
 def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
     query = select(SearchSettings).order_by(SearchSettings.id.desc())
     result = db_session.execute(query)
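A usage sketch of the new helper (session helpers as imported in the diff); the "first entry is primary" convention is exactly what check_for_indexing above relies on:

# Entry 0 is always the current/primary settings, so "is this primary?"
# reduces to an index comparison.
with get_session_with_tenant(tenant_id) as db_session:
    search_settings_list = get_active_search_settings(db_session)
    for settings in search_settings_list:
        is_primary = settings.id == search_settings_list[0].id
        print(f"settings={settings.id} primary={is_primary}")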
@@ -1,5 +1,8 @@
+import time
+
 import redis

+from danswer.db.models import SearchSettings
 from danswer.redis.redis_connector_delete import RedisConnectorDelete
 from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
 from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
@@ -31,6 +34,44 @@ class RedisConnector:
             self.tenant_id, self.id, search_settings_id, self.redis
         )

+    def wait_for_indexing_termination(
+        self,
+        search_settings_list: list[SearchSettings],
+        timeout: float = 15.0,
+    ) -> bool:
+        """
+        Returns True if all indexing for the given redis connector is finished within the
+        given timeout. Returns False if the timeout is exceeded.
+
+        This check does not guarantee that indexing runs being terminated
+        won't get restarted midflight.
+        """
+
+        finished = False
+
+        start = time.monotonic()
+
+        while True:
+            still_indexing = False
+            for search_settings in search_settings_list:
+                redis_connector_index = self.new_index(search_settings.id)
+                if redis_connector_index.fenced:
+                    still_indexing = True
+                    break
+
+            if not still_indexing:
+                finished = True
+                break
+
+            now = time.monotonic()
+            if now - start > timeout:
+                break
+
+            time.sleep(1)
+            continue
+
+        return finished
+
     @staticmethod
     def get_id_from_fence_key(key: str) -> str | None:
         """
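Typical call site, assuming `tenant_id`, `cc_pair_id`, and `search_settings_list` are in scope as in the API handler further below:

redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.wait_for_indexing_termination(search_settings_list, 15.0):
    # Timed out: escalate to hard termination (revoke + set_terminate), then re-wait.
    ...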
@@ -14,8 +14,9 @@ from danswer.configs.constants import DanswerCeleryPriority
 from danswer.configs.constants import DanswerCeleryQueues


-class RedisConnectorPermissionSyncData(BaseModel):
+class RedisConnectorPermissionSyncPayload(BaseModel):
     started: datetime | None
+    celery_task_id: str | None


 class RedisConnectorPermissionSync:
@@ -78,14 +79,14 @@ class RedisConnectorPermissionSync:
         return False

     @property
-    def payload(self) -> RedisConnectorPermissionSyncData | None:
+    def payload(self) -> RedisConnectorPermissionSyncPayload | None:
         # read related data and evaluate/print task progress
         fence_bytes = cast(bytes, self.redis.get(self.fence_key))
         if fence_bytes is None:
             return None

         fence_str = fence_bytes.decode("utf-8")
-        payload = RedisConnectorPermissionSyncData.model_validate_json(
+        payload = RedisConnectorPermissionSyncPayload.model_validate_json(
             cast(str, fence_str)
         )
@@ -93,7 +94,7 @@ class RedisConnectorPermissionSync:

     def set_fence(
         self,
-        payload: RedisConnectorPermissionSyncData | None,
+        payload: RedisConnectorPermissionSyncPayload | None,
     ) -> None:
         if not payload:
             self.redis.delete(self.fence_key)
@@ -162,6 +163,12 @@ class RedisConnectorPermissionSync:

         return len(async_results)

+    def reset(self) -> None:
+        self.redis.delete(self.generator_progress_key)
+        self.redis.delete(self.generator_complete_key)
+        self.redis.delete(self.taskset_key)
+        self.redis.delete(self.fence_key)
+
     @staticmethod
     def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
         taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}"
@@ -1,11 +1,18 @@
+from datetime import datetime
 from typing import cast

 import redis
 from celery import Celery
+from pydantic import BaseModel
 from redis.lock import Lock as RedisLock
 from sqlalchemy.orm import Session


+class RedisConnectorExternalGroupSyncPayload(BaseModel):
+    started: datetime | None
+    celery_task_id: str | None
+
+
 class RedisConnectorExternalGroupSync:
     """Manages interactions with redis for external group syncing tasks. Should only be accessed
     through RedisConnector."""
@@ -68,12 +75,29 @@ class RedisConnectorExternalGroupSync:

         return False

-    def set_fence(self, value: bool) -> None:
-        if not value:
+    @property
+    def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
+        # read related data and evaluate/print task progress
+        fence_bytes = cast(bytes, self.redis.get(self.fence_key))
+        if fence_bytes is None:
+            return None
+
+        fence_str = fence_bytes.decode("utf-8")
+        payload = RedisConnectorExternalGroupSyncPayload.model_validate_json(
+            cast(str, fence_str)
+        )
+
+        return payload
+
+    def set_fence(
+        self,
+        payload: RedisConnectorExternalGroupSyncPayload | None,
+    ) -> None:
+        if not payload:
             self.redis.delete(self.fence_key)
             return

-        self.redis.set(self.fence_key, 0)
+        self.redis.set(self.fence_key, payload.model_dump_json())

     @property
     def generator_complete(self) -> int | None:
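The fence for external group sync changes from a bare flag to a serialized pydantic payload, so readers can recover when the sync started and which Celery task owns it. A standalone round-trip sketch with a hypothetical key name:

from datetime import datetime, timezone

import redis
from pydantic import BaseModel

class Payload(BaseModel):
    started: datetime | None
    celery_task_id: str | None

r = redis.Redis()
fence_key = "example_fence"  # hypothetical key for illustration

payload = Payload(started=datetime.now(timezone.utc), celery_task_id="abc123")
r.set(fence_key, payload.model_dump_json())  # write: model -> JSON string

raw = r.get(fence_key)                       # read: bytes -> model
restored = Payload.model_validate_json(raw.decode("utf-8"))
assert restored.celery_task_id == "abc123"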
@@ -29,6 +29,8 @@ class RedisConnectorIndex:

     GENERATOR_LOCK_PREFIX = "da_lock:indexing"

+    TERMINATE_PREFIX = PREFIX + "_terminate"  # connectorindexing_terminate
+
     def __init__(
         self,
         tenant_id: str | None,
@@ -51,6 +53,7 @@ class RedisConnectorIndex:
         self.generator_lock_key = (
             f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
         )
+        self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"

     @classmethod
     def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@@ -92,6 +95,18 @@ class RedisConnectorIndex:

         self.redis.set(self.fence_key, payload.model_dump_json())

+    def terminating(self, celery_task_id: str) -> bool:
+        if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"):
+            return True
+
+        return False
+
+    def set_terminate(self, celery_task_id: str) -> None:
+        """This sets a signal. It does not block!"""
+        # We shouldn't need very long to terminate the spawned task.
+        # 10 minute TTL is good.
+        self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)
+
     def set_generator_complete(self, payload: int | None) -> None:
         if not payload:
             self.redis.delete(self.generator_complete_key)
@@ -6,6 +6,7 @@ from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import HTTPException
 from fastapi import Query
+from fastapi.responses import JSONResponse
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.orm import Session

@@ -37,7 +38,9 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model
 from danswer.db.index_attempt import count_index_attempts_for_connector
 from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
 from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
+from danswer.db.models import SearchSettings
 from danswer.db.models import User
+from danswer.db.search_settings import get_active_search_settings
 from danswer.db.search_settings import get_current_search_settings
 from danswer.redis.redis_connector import RedisConnector
 from danswer.redis.redis_pool import get_redis_client
@@ -158,7 +161,19 @@ def update_cc_pair_status(
     status_update_request: CCStatusUpdateRequest,
     user: User | None = Depends(current_curator_or_admin_user),
     db_session: Session = Depends(get_session),
-) -> None:
+    tenant_id: str | None = Depends(get_current_tenant_id),
+) -> JSONResponse:
+    """This method may wait up to 30 seconds if pausing the connector due to the need to
+    terminate tasks in progress. Tasks are not guaranteed to terminate within the
+    timeout.
+
+    Returns HTTPStatus.OK if everything finished.
+    Returns HTTPStatus.ACCEPTED if the connector is being paused, but background tasks
+    did not finish within the timeout.
+    """
+    WAIT_TIMEOUT = 15.0
+    still_terminating = False
+
     cc_pair = get_connector_credential_pair_from_id(
         cc_pair_id=cc_pair_id,
         db_session=db_session,
@@ -173,10 +188,76 @@ def update_cc_pair_status(
     )

     if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
-        cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
+        search_settings_list: list[SearchSettings] = get_active_search_settings(
+            db_session
+        )
+
+        cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
         cancel_indexing_attempts_past_model(db_session)

+        redis_connector = RedisConnector(tenant_id, cc_pair_id)
+
+        try:
+            redis_connector.stop.set_fence(True)
+            while True:
+                logger.debug(
+                    f"Wait for indexing soft termination starting: cc_pair={cc_pair_id}"
+                )
+                wait_succeeded = redis_connector.wait_for_indexing_termination(
+                    search_settings_list, WAIT_TIMEOUT
+                )
+                if wait_succeeded:
+                    logger.debug(
+                        f"Wait for indexing soft termination succeeded: cc_pair={cc_pair_id}"
+                    )
+                    break
+
+                logger.debug(
+                    "Wait for indexing soft termination timed out. "
+                    f"Moving to hard termination: cc_pair={cc_pair_id} timeout={WAIT_TIMEOUT:.2f}"
+                )
+
+                for search_settings in search_settings_list:
+                    redis_connector_index = redis_connector.new_index(
+                        search_settings.id
+                    )
+                    if not redis_connector_index.fenced:
+                        continue
+
+                    index_payload = redis_connector_index.payload
+                    if not index_payload:
+                        continue
+
+                    if not index_payload.celery_task_id:
+                        continue
+
+                    # Revoke the task to prevent it from running
+                    primary_app.control.revoke(index_payload.celery_task_id)
+
+                    # If it is running, then signaling for termination will get the
+                    # watchdog thread to kill the spawned task
+                    redis_connector_index.set_terminate(index_payload.celery_task_id)
+
+                logger.debug(
+                    f"Wait for indexing hard termination starting: cc_pair={cc_pair_id}"
+                )
+                wait_succeeded = redis_connector.wait_for_indexing_termination(
+                    search_settings_list, WAIT_TIMEOUT
+                )
+                if wait_succeeded:
+                    logger.debug(
+                        f"Wait for indexing hard termination succeeded: cc_pair={cc_pair_id}"
+                    )
+                    break
+
+                logger.debug(
+                    f"Wait for indexing hard termination timed out: cc_pair={cc_pair_id}"
+                )
+                still_terminating = True
+                break
+        finally:
+            redis_connector.stop.set_fence(False)
+
     update_connector_credential_pair_from_id(
         db_session=db_session,
         cc_pair_id=cc_pair_id,
@@ -185,6 +266,18 @@ def update_cc_pair_status(

     db_session.commit()

+    if still_terminating:
+        return JSONResponse(
+            status_code=HTTPStatus.ACCEPTED,
+            content={
+                "message": "Request accepted, background task termination still in progress"
+            },
+        )
+
+    return JSONResponse(
+        status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)}
+    )
+

 @router.put("/admin/cc-pair/{cc_pair_id}/name")
 def update_cc_pair_name(
@@ -267,9 +360,9 @@ def prune_cc_pair(
     )

     logger.info(
-        f"Pruning cc_pair: cc_pair_id={cc_pair_id} "
-        f"connector_id={cc_pair.connector_id} "
-        f"credential_id={cc_pair.credential_id} "
+        f"Pruning cc_pair: cc_pair={cc_pair_id} "
+        f"connector={cc_pair.connector_id} "
+        f"credential={cc_pair.credential_id} "
         f"{cc_pair.connector.name} connector."
     )
     tasks_created = try_creating_prune_generator_task(
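From a client's perspective, pausing now returns 200 when all background work stopped inside the timeout and 202 when termination is still in progress. A hedged sketch of handling that; the URL prefix, enum string, and auth are assumptions, not confirmed by this diff:

import requests

# base_url, cc_pair_id, and auth_headers are placeholders for a real deployment.
resp = requests.put(
    f"{base_url}/manage/admin/cc-pair/{cc_pair_id}/status",
    json={"status": "PAUSED"},
    headers=auth_headers,
)
if resp.status_code == 202:
    print("Pause accepted; background tasks are still terminating")
else:
    resp.raise_for_status()
    print("Connector paused; all indexing stopped within the timeout")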
@@ -240,7 +240,85 @@ class CCPairManager:
         result.raise_for_status()

     @staticmethod
-    def wait_for_indexing(
+    def wait_for_indexing_inactive(
+        cc_pair: DATestCCPair,
+        timeout: float = MAX_DELAY,
+        user_performing_action: DATestUser | None = None,
+    ) -> None:
+        """wait for the connector to stop indexing.
+        This is used to test pausing a connector in the middle of indexing and
+        terminating that indexing."""
+        print(f"Indexing wait for inactive starting: cc_pair={cc_pair.id}")
+        start = time.monotonic()
+        while True:
+            fetched_cc_pairs = CCPairManager.get_indexing_statuses(
+                user_performing_action
+            )
+            for fetched_cc_pair in fetched_cc_pairs:
+                if fetched_cc_pair.cc_pair_id != cc_pair.id:
+                    continue
+
+                if fetched_cc_pair.in_progress:
+                    continue
+
+                print(f"Indexing is inactive: cc_pair={cc_pair.id}")
+                return
+
+            elapsed = time.monotonic() - start
+            if elapsed > timeout:
+                raise TimeoutError(
+                    f"Indexing wait for inactive timed out: cc_pair={cc_pair.id} timeout={timeout}s"
+                )
+
+            print(
+                f"Indexing wait for inactive still waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s"
+            )
+            time.sleep(5)
+
+    @staticmethod
+    def wait_for_indexing_in_progress(
+        cc_pair: DATestCCPair,
+        timeout: float = MAX_DELAY,
+        num_docs: int = 16,
+        user_performing_action: DATestUser | None = None,
+    ) -> None:
+        """wait for the number of docs to be indexed on the connector.
+        This is used to test pausing a connector in the middle of indexing and
+        terminating that indexing."""
+        start = time.monotonic()
+        while True:
+            fetched_cc_pairs = CCPairManager.get_indexing_statuses(
+                user_performing_action
+            )
+            for fetched_cc_pair in fetched_cc_pairs:
+                if fetched_cc_pair.cc_pair_id != cc_pair.id:
+                    continue
+
+                if not fetched_cc_pair.in_progress:
+                    continue
+
+                if fetched_cc_pair.docs_indexed >= num_docs:
+                    print(
+                        "Indexed at least the requested number of docs: "
+                        f"cc_pair={cc_pair.id} "
+                        f"docs_indexed={fetched_cc_pair.docs_indexed} "
+                        f"num_docs={num_docs}"
+                    )
+                    return
+
+            elapsed = time.monotonic() - start
+            if elapsed > timeout:
+                raise TimeoutError(
+                    f"Indexing in progress wait timed out: cc_pair={cc_pair.id} timeout={timeout}s"
+                )
+
+            print(
+                f"Indexing in progress waiting: cc_pair={cc_pair.id} elapsed={elapsed:.2f} timeout={timeout}s"
+            )
+            time.sleep(5)
+
+    @staticmethod
+    def wait_for_indexing_completion(
         cc_pair: DATestCCPair,
         after: datetime,
         timeout: float = MAX_DELAY,
@@ -78,7 +78,7 @@ def test_slack_permission_sync(
         access_type=AccessType.SYNC,
         user_performing_action=admin_user,
     )
-    CCPairManager.wait_for_indexing(
+    CCPairManager.wait_for_indexing_completion(
         cc_pair=cc_pair,
         after=before,
         user_performing_action=admin_user,
@@ -113,7 +113,7 @@ def test_slack_permission_sync(
     # Run indexing
     before = datetime.now(timezone.utc)
     CCPairManager.run_once(cc_pair, admin_user)
-    CCPairManager.wait_for_indexing(
+    CCPairManager.wait_for_indexing_completion(
         cc_pair=cc_pair,
         after=before,
         user_performing_action=admin_user,
@@ -305,7 +305,7 @@ def test_slack_group_permission_sync(

     # Run indexing
     CCPairManager.run_once(cc_pair, admin_user)
-    CCPairManager.wait_for_indexing(
+    CCPairManager.wait_for_indexing_completion(
         cc_pair=cc_pair,
         after=before,
         user_performing_action=admin_user,
@@ -74,7 +74,7 @@ def test_slack_prune(
         access_type=AccessType.SYNC,
         user_performing_action=admin_user,
     )
-    CCPairManager.wait_for_indexing(
+    CCPairManager.wait_for_indexing_completion(
         cc_pair=cc_pair,
         after=before,
         user_performing_action=admin_user,
@@ -113,7 +113,7 @@ def test_slack_prune(
     # Run indexing
     before = datetime.now(timezone.utc)
     CCPairManager.run_once(cc_pair, admin_user)
-    CCPairManager.wait_for_indexing(
+    CCPairManager.wait_for_indexing_completion(
         cc_pair=cc_pair,
         after=before,
         user_performing_action=admin_user,
@@ -58,7 +58,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
         user_performing_action=admin_user,
     )

-    CCPairManager.wait_for_indexing(
+    CCPairManager.wait_for_indexing_completion(
         cc_pair_1, now, timeout=120, user_performing_action=admin_user
     )

@@ -71,7 +71,7 @@ def test_overlapping_connector_creation(reset: None) -> None:
         user_performing_action=admin_user,
     )

-    CCPairManager.wait_for_indexing(
+    CCPairManager.wait_for_indexing_completion(
         cc_pair_2, now, timeout=120, user_performing_action=admin_user
     )

@@ -82,3 +82,48 @@ def test_overlapping_connector_creation(reset: None) -> None:
     assert info_2

     assert info_1.num_docs_indexed == info_2.num_docs_indexed
+
+
+def test_connector_pause_while_indexing(reset: None) -> None:
+    """Tests that we can pause a connector while indexing is in progress and that
+    tasks end early or abort as a result.
+
+    TODO: This does not specifically test for soft or hard termination code paths.
+    Design specific tests for those use cases.
+    """
+    admin_user: DATestUser = UserManager.create(name="admin_user")
+
+    config = {
+        "wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
+        "space": "",
+        "is_cloud": True,
+        "page_id": "",
+    }
+
+    credential = {
+        "confluence_username": os.environ["CONFLUENCE_USER_NAME"],
+        "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
+    }
+
+    # store the time before we create the connector so that we know after
+    # when the indexing should have started
+    datetime.now(timezone.utc)
+
+    # create connector
+    cc_pair_1 = CCPairManager.create_from_scratch(
+        source=DocumentSource.CONFLUENCE,
+        connector_specific_config=config,
+        credential_json=credential,
+        user_performing_action=admin_user,
+    )
+
+    CCPairManager.wait_for_indexing_in_progress(
+        cc_pair_1, timeout=60, num_docs=16, user_performing_action=admin_user
+    )
+
+    CCPairManager.pause_cc_pair(cc_pair_1, user_performing_action=admin_user)
+
+    CCPairManager.wait_for_indexing_inactive(
+        cc_pair_1, timeout=60, user_performing_action=admin_user
+    )
+    return
@@ -135,7 +135,7 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
         user_performing_action=admin_user,
     )

-    CCPairManager.wait_for_indexing(
+    CCPairManager.wait_for_indexing_completion(
         cc_pair_1, now, timeout=60, user_performing_action=admin_user
     )

@@ -161,7 +161,7 @@ export default function UpgradingPage({
             reindexingProgress={sortedReindexingProgress}
           />
         ) : (
-          <ErrorCallout errorTitle="Failed to fetch re-indexing progress" />
+          <ErrorCallout errorTitle="Failed to fetch reindexing progress" />
         )}
       </>
     ) : (
@@ -171,7 +171,7 @@ export default function UpgradingPage({
         </h3>
         <p className="mb-4 text-text-800">
           You're currently switching embedding models, but there
-          are no connectors to re-index. This means the transition will
+          are no connectors to reindex. This means the transition will
           be quick and seamless!
         </p>
         <p className="text-text-600">
@@ -6,6 +6,8 @@ import { usePopup } from "@/components/admin/connectors/Popup";
 import { mutate } from "swr";
 import { buildCCPairInfoUrl } from "./lib";
 import { setCCPairStatus } from "@/lib/ccPair";
+import { useState } from "react";
+import { LoadingAnimation } from "@/components/Loading";

 export function ModifyStatusButtonCluster({
   ccPair,
@@ -13,44 +15,72 @@ export function ModifyStatusButtonCluster({
   ccPair: CCPairFullInfo;
 }) {
   const { popup, setPopup } = usePopup();
+  const [isUpdating, setIsUpdating] = useState(false);
+
+  const handleStatusChange = async (
+    newStatus: ConnectorCredentialPairStatus
+  ) => {
+    if (isUpdating) return; // Prevent double-clicks or multiple requests
+    setIsUpdating(true);
+
+    try {
+      // Call the backend to update the status
+      await setCCPairStatus(ccPair.id, newStatus, setPopup);
+
+      // Use mutate to revalidate the status on the backend
+      await mutate(buildCCPairInfoUrl(ccPair.id));
+    } catch (error) {
+      console.error("Failed to update status", error);
+    } finally {
+      // Reset local updating state and button text after mutation
+      setIsUpdating(false);
+    }
+  };
+
+  // Compute the button text based on current state and backend status
+  const buttonText =
+    ccPair.status === ConnectorCredentialPairStatus.PAUSED
+      ? "Re-Enable"
+      : "Pause";
+
+  const tooltip =
+    ccPair.status === ConnectorCredentialPairStatus.PAUSED
+      ? "Click to start indexing again!"
+      : "When paused, the connector's documents will still be visible. However, no new documents will be indexed.";

   return (
     <>
       {popup}
-      {ccPair.status === ConnectorCredentialPairStatus.PAUSED ? (
-        <Button
-          variant="success-reverse"
-          onClick={() =>
-            setCCPairStatus(
-              ccPair.id,
-              ConnectorCredentialPairStatus.ACTIVE,
-              setPopup,
-              () => mutate(buildCCPairInfoUrl(ccPair.id))
-            )
-          }
-          tooltip="Click to start indexing again!"
-        >
-          Re-Enable
-        </Button>
-      ) : (
-        <Button
-          variant="default"
-          onClick={() =>
-            setCCPairStatus(
-              ccPair.id,
-              ConnectorCredentialPairStatus.PAUSED,
-              setPopup,
-              () => mutate(buildCCPairInfoUrl(ccPair.id))
-            )
-          }
-          tooltip={
-            "When paused, the connectors documents will still" +
-            " be visible. However, no new documents will be indexed."
-          }
-        >
-          Pause
-        </Button>
-      )}
+      <Button
+        className="flex items-center justify-center w-auto min-w-[100px] px-4 py-2"
+        variant={
+          ccPair.status === ConnectorCredentialPairStatus.PAUSED
+            ? "success-reverse"
+            : "default"
+        }
+        disabled={isUpdating}
+        onClick={() =>
+          handleStatusChange(
+            ccPair.status === ConnectorCredentialPairStatus.PAUSED
+              ? ConnectorCredentialPairStatus.ACTIVE
+              : ConnectorCredentialPairStatus.PAUSED
+          )
+        }
+        tooltip={tooltip}
+      >
+        {isUpdating ? (
+          <LoadingAnimation
+            text={
+              ccPair.status === ConnectorCredentialPairStatus.PAUSED
+                ? "Resuming"
+                : "Pausing"
+            }
+            size="text-md"
+          />
+        ) : (
+          buttonText
+        )}
+      </Button>
     </>
   );
 }
@@ -121,7 +121,7 @@ export function ReIndexButton({
       {popup}
       <Button
         variant="success-reverse"
-        className="ml-auto"
+        className="ml-auto min-w-[100px]"
         onClick={() => {
           setReIndexPopupVisible(true);
         }}
@@ -25,6 +25,7 @@ import { ReIndexButton } from "./ReIndexButton";
 import { buildCCPairInfoUrl } from "./lib";
 import { CCPairFullInfo, ConnectorCredentialPairStatus } from "./types";
 import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay";
+import { Button } from "@/components/ui/button";

 // since the uploaded files are cleaned up after some period of time
 // re-indexing will not work for the file connector. Also, it would not