diff --git a/backend/danswer/background/celery/apps/primary.py b/backend/danswer/background/celery/apps/primary.py index 23e25fa92b..983b76773e 100644 --- a/backend/danswer/background/celery/apps/primary.py +++ b/backend/danswer/background/celery/apps/primary.py @@ -17,6 +17,7 @@ from danswer.background.celery.celery_redis import RedisConnectorCredentialPair from danswer.background.celery.celery_redis import RedisConnectorDeletion from danswer.background.celery.celery_redis import RedisConnectorIndexing from danswer.background.celery.celery_redis import RedisConnectorPruning +from danswer.background.celery.celery_redis import RedisConnectorStop from danswer.background.celery.celery_redis import RedisDocumentSet from danswer.background.celery.celery_redis import RedisUserGroup from danswer.background.celery.celery_utils import celery_is_worker_primary @@ -161,6 +162,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None: for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"): r.delete(key) + for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"): + r.delete(key) + # @worker_process_init.connect # def on_worker_process_init(sender: Any, **kwargs: Any) -> None: diff --git a/backend/danswer/background/celery/celery_redis.py b/backend/danswer/background/celery/celery_redis.py index 1ea5e3b176..e412b0bf73 100644 --- a/backend/danswer/background/celery/celery_redis.py +++ b/backend/danswer/background/celery/celery_redis.py @@ -313,6 +313,8 @@ class RedisConnectorDeletion(RedisObjectHelper): lock: redis.lock.Lock, tenant_id: str | None, ) -> int | None: + """Returns None if the cc_pair doesn't exist. + Otherwise, returns an int with the number of generated tasks.""" last_lock_time = time.monotonic() async_results = [] @@ -540,6 +542,29 @@ class RedisConnectorIndexing(RedisObjectHelper): return False +class RedisConnectorStop(RedisObjectHelper): + """Used to signal any running tasks for a connector to stop. We should refactor + connector related redis helpers into a single class. + """ + + PREFIX = "connectorstop" + FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process + TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's + + def __init__(self, id: int) -> None: + super().__init__(str(id)) + + def generate_tasks( + self, + celery_app: Celery, + db_session: Session, + redis_client: Redis, + lock: redis.lock.Lock | None, + tenant_id: str | None, + ) -> int | None: + return None + + def celery_get_queue_length(queue: str, r: Redis) -> int: """This is a redis specific way to get the length of a celery queue. It is priority aware and knows how to count across the multiple redis lists diff --git a/backend/danswer/background/celery/celery_utils.py b/backend/danswer/background/celery/celery_utils.py index b1e9c2113e..18038a349d 100644 --- a/backend/danswer/background/celery/celery_utils.py +++ b/backend/danswer/background/celery/celery_utils.py @@ -1,4 +1,3 @@ -from collections.abc import Callable from datetime import datetime from datetime import timezone from typing import Any @@ -6,6 +5,7 @@ from typing import Any from sqlalchemy.orm import Session from danswer.background.celery.celery_redis import RedisConnectorDeletion +from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE from danswer.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, @@ -79,7 +79,7 @@ def document_batch_to_ids( def extract_ids_from_runnable_connector( runnable_connector: BaseConnector, - progress_callback: Callable[[int], None] | None = None, + callback: RunIndexingCallbackInterface | None = None, ) -> set[str]: """ If the PruneConnector hasnt been implemented for the given connector, just pull @@ -110,8 +110,10 @@ def extract_ids_from_runnable_connector( max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60 )(document_batch_to_ids) for doc_batch in doc_batch_generator: - if progress_callback: - progress_callback(len(doc_batch)) + if callback: + if callback.should_stop(): + raise RuntimeError("Stop signal received") + callback.progress(len(doc_batch)) all_connector_doc_ids.update(doc_batch_processing_func(doc_batch)) return all_connector_doc_ids diff --git a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py index f6a59d03e3..59d236cde3 100644 --- a/backend/danswer/background/celery/tasks/connector_deletion/tasks.py +++ b/backend/danswer/background/celery/tasks/connector_deletion/tasks.py @@ -1,3 +1,6 @@ +from datetime import datetime +from datetime import timezone + import redis from celery import Celery from celery import shared_task @@ -8,6 +11,12 @@ from sqlalchemy.orm import Session from danswer.background.celery.apps.app_base import task_logger from danswer.background.celery.celery_redis import RedisConnectorDeletion +from danswer.background.celery.celery_redis import RedisConnectorIndexing +from danswer.background.celery.celery_redis import RedisConnectorPruning +from danswer.background.celery.celery_redis import RedisConnectorStop +from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import ( + RedisConnectorDeletionFenceData, +) from danswer.configs.app_configs import JOB_TIMEOUT from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from danswer.configs.constants import DanswerRedisLocks @@ -15,9 +24,15 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_f from danswer.db.connector_credential_pair import get_connector_credential_pairs from danswer.db.engine import get_session_with_tenant from danswer.db.enums import ConnectorCredentialPairStatus +from danswer.db.search_settings import get_all_search_settings from danswer.redis.redis_pool import get_redis_client +class TaskDependencyError(RuntimeError): + """Raised to the caller to indicate dependent tasks are running that would interfere + with connector deletion.""" + + @shared_task( name="check_for_connector_deletion_task", soft_time_limit=JOB_TIMEOUT, @@ -37,17 +52,30 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N if not lock_beat.acquire(blocking=False): return + # collect cc_pair_ids cc_pair_ids: list[int] = [] with get_session_with_tenant(tenant_id) as db_session: cc_pairs = get_connector_credential_pairs(db_session) for cc_pair in cc_pairs: cc_pair_ids.append(cc_pair.id) + # try running cleanup on the cc_pair_ids for cc_pair_id in cc_pair_ids: with get_session_with_tenant(tenant_id) as db_session: - try_generate_document_cc_pair_cleanup_tasks( - self.app, cc_pair_id, db_session, r, lock_beat, tenant_id - ) + rcs = RedisConnectorStop(cc_pair_id) + try: + try_generate_document_cc_pair_cleanup_tasks( + self.app, cc_pair_id, db_session, r, lock_beat, tenant_id + ) + except TaskDependencyError as e: + # this means we wanted to start deleting but dependent tasks were running + # Leave a stop signal to clear indexing and pruning tasks more quickly + task_logger.info(str(e)) + r.set(rcs.fence_key, cc_pair_id) + else: + # clear the stop signal if it exists ... no longer needed + r.delete(rcs.fence_key) + except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." @@ -70,6 +98,10 @@ def try_generate_document_cc_pair_cleanup_tasks( """Returns an int if syncing is needed. The int represents the number of sync tasks generated. Note that syncing can still be required even if the number of sync tasks generated is zero. Returns None if no syncing is required. + + Will raise TaskDependencyError if dependent tasks such as indexing and pruning are + still running. In our case, the caller reacts by setting a stop signal in Redis to + exit those tasks as quickly as possible. """ lock_beat.reacquire() @@ -90,28 +122,63 @@ def try_generate_document_cc_pair_cleanup_tasks( if cc_pair.status != ConnectorCredentialPairStatus.DELETING: return None - # add tasks to celery and build up the task set to monitor in redis - r.delete(rcd.taskset_key) - - # Add all documents that need to be updated into the queue - task_logger.info( - f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}" + # set a basic fence to start + fence_value = RedisConnectorDeletionFenceData( + num_tasks=None, + submitted=datetime.now(timezone.utc), ) - tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id) - if tasks_generated is None: + r.set(rcd.fence_key, fence_value.model_dump_json()) + + try: + # do not proceed if connector indexing or connector pruning are running + search_settings_list = get_all_search_settings(db_session) + for search_settings in search_settings_list: + rci = RedisConnectorIndexing(cc_pair_id, search_settings.id) + if r.get(rci.fence_key): + raise TaskDependencyError( + f"Connector deletion - Delayed (indexing in progress): " + f"cc_pair={cc_pair_id} " + f"search_settings={search_settings.id}" + ) + + rcp = RedisConnectorPruning(cc_pair_id) + if r.get(rcp.fence_key): + raise TaskDependencyError( + f"Connector deletion - Delayed (pruning in progress): " + f"cc_pair={cc_pair_id}" + ) + + # add tasks to celery and build up the task set to monitor in redis + r.delete(rcd.taskset_key) + + # Add all documents that need to be updated into the queue + task_logger.info( + f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}" + ) + tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id) + if tasks_generated is None: + raise ValueError("RedisConnectorDeletion.generate_tasks returned None") + except TaskDependencyError: + r.delete(rcd.fence_key) + raise + except Exception: + task_logger.exception("Unexpected exception") + r.delete(rcd.fence_key) return None + else: + # Currently we are allowing the sync to proceed with 0 tasks. + # It's possible for sets/groups to be generated initially with no entries + # and they still need to be marked as up to date. + # if tasks_generated == 0: + # return 0 - # Currently we are allowing the sync to proceed with 0 tasks. - # It's possible for sets/groups to be generated initially with no entries - # and they still need to be marked as up to date. - # if tasks_generated == 0: - # return 0 + task_logger.info( + f"RedisConnectorDeletion.generate_tasks finished. " + f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}" + ) - task_logger.info( - f"RedisConnectorDeletion.generate_tasks finished. " - f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}" - ) + # set this only after all tasks have been added + fence_value.num_tasks = tasks_generated + r.set(rcd.fence_key, fence_value.model_dump_json()) - # set this only after all tasks have been added - r.set(rcd.fence_key, tasks_generated) return tasks_generated diff --git a/backend/danswer/background/celery/tasks/indexing/tasks.py b/backend/danswer/background/celery/tasks/indexing/tasks.py index bdd55f77f3..980266ec87 100644 --- a/backend/danswer/background/celery/tasks/indexing/tasks.py +++ b/backend/danswer/background/celery/tasks/indexing/tasks.py @@ -5,6 +5,7 @@ from time import sleep from typing import cast from uuid import uuid4 +import redis from celery import Celery from celery import shared_task from celery import Task @@ -13,12 +14,15 @@ from redis import Redis from sqlalchemy.orm import Session from danswer.background.celery.apps.app_base import task_logger +from danswer.background.celery.celery_redis import RedisConnectorDeletion from danswer.background.celery.celery_redis import RedisConnectorIndexing +from danswer.background.celery.celery_redis import RedisConnectorStop from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import ( RedisConnectorIndexingFenceData, ) from danswer.background.indexing.job_client import SimpleJobClient from danswer.background.indexing.run_indexing import run_indexing_entrypoint +from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT @@ -50,6 +54,30 @@ from danswer.utils.variable_functionality import global_version logger = setup_logger() +class RunIndexingCallback(RunIndexingCallbackInterface): + def __init__( + self, + stop_key: str, + generator_progress_key: str, + redis_lock: redis.lock.Lock, + redis_client: Redis, + ): + super().__init__() + self.redis_lock: redis.lock.Lock = redis_lock + self.stop_key: str = stop_key + self.generator_progress_key: str = generator_progress_key + self.redis_client = redis_client + + def should_stop(self) -> bool: + if self.redis_client.exists(self.stop_key): + return True + return False + + def progress(self, amount: int) -> None: + self.redis_lock.reacquire() + self.redis_client.incrby(self.generator_progress_key, amount) + + @shared_task( name="check_for_indexing", soft_time_limit=300, @@ -262,6 +290,10 @@ def try_creating_indexing_task( return None # skip indexing if the cc_pair is deleting + rcd = RedisConnectorDeletion(cc_pair.id) + if r.exists(rcd.fence_key): + return None + db_session.refresh(cc_pair) if cc_pair.status == ConnectorCredentialPairStatus.DELETING: return None @@ -308,13 +340,8 @@ def try_creating_indexing_task( raise RuntimeError("send_task for connector_indexing_proxy_task failed.") # now fill out the fence with the rest of the data - fence_value = RedisConnectorIndexingFenceData( - index_attempt_id=index_attempt_id, - started=None, - submitted=datetime.now(timezone.utc), - celery_task_id=result.id, - ) - + fence_value.index_attempt_id = index_attempt_id + fence_value.celery_task_id = result.id r.set(rci.fence_key, fence_value.model_dump_json()) except Exception: r.delete(rci.fence_key) @@ -400,6 +427,22 @@ def connector_indexing_task( r = get_redis_client(tenant_id=tenant_id) + rcd = RedisConnectorDeletion(cc_pair_id) + if r.exists(rcd.fence_key): + raise RuntimeError( + f"Indexing will not start because connector deletion is in progress: " + f"cc_pair={cc_pair_id} " + f"fence={rcd.fence_key}" + ) + + rcs = RedisConnectorStop(cc_pair_id) + if r.exists(rcs.fence_key): + raise RuntimeError( + f"Indexing will not start because a connector stop signal was detected: " + f"cc_pair={cc_pair_id} " + f"fence={rcs.fence_key}" + ) + rci = RedisConnectorIndexing(cc_pair_id, search_settings_id) while True: @@ -409,7 +452,7 @@ def connector_indexing_task( task_logger.info( f"connector_indexing_task: fence_value not found: fence={rci.fence_key}" ) - raise + raise RuntimeError(f"Fence not found: fence={rci.fence_key}") try: fence_json = fence_value.decode("utf-8") @@ -443,17 +486,20 @@ def connector_indexing_task( if not acquired: task_logger.warning( f"Indexing task already running, exiting...: " - f"cc_pair_id={cc_pair_id} search_settings_id={search_settings_id}" + f"cc_pair={cc_pair_id} search_settings={search_settings_id}" ) # r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value) return None + fence_data.started = datetime.now(timezone.utc) + r.set(rci.fence_key, fence_data.model_dump_json()) + try: with get_session_with_tenant(tenant_id) as db_session: attempt = get_index_attempt(db_session, index_attempt_id) if not attempt: raise ValueError( - f"Index attempt not found: index_attempt_id={index_attempt_id}" + f"Index attempt not found: index_attempt={index_attempt_id}" ) cc_pair = get_connector_credential_pair_from_id( @@ -462,31 +508,31 @@ def connector_indexing_task( ) if not cc_pair: - raise ValueError(f"cc_pair not found: cc_pair_id={cc_pair_id}") + raise ValueError(f"cc_pair not found: cc_pair={cc_pair_id}") if not cc_pair.connector: raise ValueError( - f"Connector not found: connector_id={cc_pair.connector_id}" + f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}" ) if not cc_pair.credential: raise ValueError( - f"Credential not found: credential_id={cc_pair.credential_id}" + f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}" ) rci = RedisConnectorIndexing(cc_pair_id, search_settings_id) - # Define the callback function - def redis_increment_callback(amount: int) -> None: - lock.reacquire() - r.incrby(rci.generator_progress_key, amount) + # define a callback class + callback = RunIndexingCallback( + rcs.fence_key, rci.generator_progress_key, lock, r + ) run_indexing_entrypoint( index_attempt_id, tenant_id, cc_pair_id, is_ee, - progress_callback=redis_increment_callback, + callback=callback, ) # get back the total number of indexed docs and return it @@ -499,9 +545,10 @@ def connector_indexing_task( r.set(rci.generator_complete_key, HTTPStatus.OK.value) except Exception as e: - task_logger.exception(f"Failed to run indexing for cc_pair_id={cc_pair_id}.") + task_logger.exception(f"Indexing failed: cc_pair={cc_pair_id}") if attempt: - mark_attempt_failed(attempt, db_session, failure_reason=str(e)) + with get_session_with_tenant(tenant_id) as db_session: + mark_attempt_failed(attempt, db_session, failure_reason=str(e)) r.delete(rci.generator_lock_key) r.delete(rci.generator_progress_key) diff --git a/backend/danswer/background/celery/tasks/pruning/tasks.py b/backend/danswer/background/celery/tasks/pruning/tasks.py index 9f290d6f23..2e68986e83 100644 --- a/backend/danswer/background/celery/tasks/pruning/tasks.py +++ b/backend/danswer/background/celery/tasks/pruning/tasks.py @@ -11,8 +11,11 @@ from redis import Redis from sqlalchemy.orm import Session from danswer.background.celery.apps.app_base import task_logger +from danswer.background.celery.celery_redis import RedisConnectorDeletion from danswer.background.celery.celery_redis import RedisConnectorPruning +from danswer.background.celery.celery_redis import RedisConnectorStop from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector +from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING from danswer.configs.app_configs import JOB_TIMEOUT from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT @@ -168,6 +171,10 @@ def try_creating_prune_generator_task( return None # skip pruning if the cc_pair is deleting + rcd = RedisConnectorDeletion(cc_pair.id) + if r.exists(rcd.fence_key): + return None + db_session.refresh(cc_pair) if cc_pair.status == ConnectorCredentialPairStatus.DELETING: return None @@ -234,7 +241,7 @@ def connector_pruning_generator_task( acquired = lock.acquire(blocking=False) if not acquired: task_logger.warning( - f"Pruning task already running, exiting...: cc_pair_id={cc_pair_id}" + f"Pruning task already running, exiting...: cc_pair={cc_pair_id}" ) return None @@ -252,11 +259,6 @@ def connector_pruning_generator_task( ) return - # Define the callback function - def redis_increment_callback(amount: int) -> None: - lock.reacquire() - r.incrby(rcp.generator_progress_key, amount) - runnable_connector = instantiate_connector( db_session, cc_pair.connector.source, @@ -265,9 +267,14 @@ def connector_pruning_generator_task( cc_pair.credential, ) + rcs = RedisConnectorStop(cc_pair_id) + + callback = RunIndexingCallback( + rcs.fence_key, rcp.generator_progress_key, lock, r + ) # a list of docs in the source all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector( - runnable_connector, redis_increment_callback + runnable_connector, callback ) # a list of docs in our local index @@ -285,7 +292,7 @@ def connector_pruning_generator_task( task_logger.info( f"Pruning set collected: " - f"cc_pair_id={cc_pair.id} " + f"cc_pair={cc_pair_id} " f"docs_to_remove={len(doc_ids_to_remove)} " f"doc_source={cc_pair.connector.source}" ) @@ -293,7 +300,7 @@ def connector_pruning_generator_task( rcp.documents_to_prune = set(doc_ids_to_remove) task_logger.info( - f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}" + f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair_id}" ) tasks_generated = rcp.generate_tasks( self.app, db_session, r, None, tenant_id @@ -303,12 +310,14 @@ def connector_pruning_generator_task( task_logger.info( f"RedisConnectorPruning.generate_tasks finished. " - f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}" + f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}" ) r.set(rcp.generator_complete_key, tasks_generated) except Exception as e: - task_logger.exception(f"Failed to run pruning for connector id {connector_id}.") + task_logger.exception( + f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}" + ) r.delete(rcp.generator_progress_key) r.delete(rcp.taskset_key) diff --git a/backend/danswer/background/celery/tasks/shared/RedisConnectorDeletionFenceData.py b/backend/danswer/background/celery/tasks/shared/RedisConnectorDeletionFenceData.py new file mode 100644 index 0000000000..1c664d14b4 --- /dev/null +++ b/backend/danswer/background/celery/tasks/shared/RedisConnectorDeletionFenceData.py @@ -0,0 +1,8 @@ +from datetime import datetime + +from pydantic import BaseModel + + +class RedisConnectorDeletionFenceData(BaseModel): + num_tasks: int | None + submitted: datetime diff --git a/backend/danswer/background/celery/tasks/shared/tasks.py b/backend/danswer/background/celery/tasks/shared/tasks.py index 7ce43454aa..52a49d467e 100644 --- a/backend/danswer/background/celery/tasks/shared/tasks.py +++ b/backend/danswer/background/celery/tasks/shared/tasks.py @@ -1,9 +1,6 @@ -from datetime import datetime - from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded -from pydantic import BaseModel from danswer.access.access import get_access_for_document from danswer.background.celery.apps.app_base import task_logger @@ -23,13 +20,6 @@ from danswer.server.documents.models import ConnectorCredentialPairIdentifier DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3 -class RedisConnectorIndexingFenceData(BaseModel): - index_attempt_id: int | None - started: datetime | None - submitted: datetime - celery_task_id: str | None - - @shared_task( name="document_by_cc_pair_cleanup_task", bind=True, diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py index 812074b91e..fcc4d2aa5b 100644 --- a/backend/danswer/background/celery/tasks/vespa/tasks.py +++ b/backend/danswer/background/celery/tasks/vespa/tasks.py @@ -23,6 +23,9 @@ from danswer.background.celery.celery_redis import RedisConnectorIndexing from danswer.background.celery.celery_redis import RedisConnectorPruning from danswer.background.celery.celery_redis import RedisDocumentSet from danswer.background.celery.celery_redis import RedisUserGroup +from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import ( + RedisConnectorDeletionFenceData, +) from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import ( RedisConnectorIndexingFenceData, ) @@ -368,7 +371,7 @@ def monitor_document_set_taskset( count = cast(int, r.scard(rds.taskset_key)) task_logger.info( - f"Document set sync progress: document_set_id={document_set_id} " + f"Document set sync progress: document_set={document_set_id} " f"remaining={count} initial={initial_count}" ) if count > 0: @@ -383,12 +386,12 @@ def monitor_document_set_taskset( # if there are no connectors, then delete the document set. delete_document_set(document_set_row=document_set, db_session=db_session) task_logger.info( - f"Successfully deleted document set with ID: '{document_set_id}'!" + f"Successfully deleted document set: document_set={document_set_id}" ) else: mark_document_set_as_synced(document_set_id, db_session) task_logger.info( - f"Successfully synced document set with ID: '{document_set_id}'!" + f"Successfully synced document set: document_set={document_set_id}" ) r.delete(rds.taskset_key) @@ -408,19 +411,29 @@ def monitor_connector_deletion_taskset( rcd = RedisConnectorDeletion(cc_pair_id) - fence_value = r.get(rcd.fence_key) + # read related data and evaluate/print task progress + fence_value = cast(bytes, r.get(rcd.fence_key)) if fence_value is None: return try: - initial_count = int(cast(int, fence_value)) + fence_json = fence_value.decode("utf-8") + fence_data = RedisConnectorDeletionFenceData.model_validate_json( + cast(str, fence_json) + ) except ValueError: - task_logger.error("The value is not an integer.") + task_logger.exception( + "monitor_ccpair_indexing_taskset: fence_data not decodeable." + ) + raise + + # the fence is setting up but isn't ready yet + if fence_data.num_tasks is None: return count = cast(int, r.scard(rcd.taskset_key)) task_logger.info( - f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={initial_count}" + f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}" ) if count > 0: return @@ -483,7 +496,7 @@ def monitor_connector_deletion_taskset( ) if not connector or not len(connector.credentials): task_logger.info( - "Found no credentials left for connector, deleting connector" + "Connector deletion - Found no credentials left for connector, deleting connector" ) db_session.delete(connector) db_session.commit() @@ -493,17 +506,17 @@ def monitor_connector_deletion_taskset( error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}" add_deletion_failure_message(db_session, cc_pair_id, error_message) task_logger.exception( - f"Failed to run connector_deletion. " + f"Connector deletion exceptioned: " f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}" ) raise e task_logger.info( - f"Successfully deleted cc_pair: " + f"Connector deletion succeeded: " f"cc_pair={cc_pair_id} " f"connector={cc_pair.connector_id} " f"credential={cc_pair.credential_id} " - f"docs_deleted={initial_count}" + f"docs_deleted={fence_data.num_tasks}" ) r.delete(rcd.taskset_key) @@ -618,6 +631,7 @@ def monitor_ccpair_indexing_taskset( return # Read result state BEFORE generator_complete_key to avoid a race condition + # never use any blocking methods on the result from inside a task! result: AsyncResult = AsyncResult(fence_data.celery_task_id) result_state = result.state diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index cb50739045..d95a6a70d5 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -1,6 +1,7 @@ import time import traceback -from collections.abc import Callable +from abc import ABC +from abc import abstractmethod from datetime import datetime from datetime import timedelta from datetime import timezone @@ -41,6 +42,19 @@ logger = setup_logger() INDEXING_TRACER_NUM_PRINT_ENTRIES = 5 +class RunIndexingCallbackInterface(ABC): + """Defines a callback interface to be passed to + to run_indexing_entrypoint.""" + + @abstractmethod + def should_stop(self) -> bool: + """Signal to stop the looping function in flight.""" + + @abstractmethod + def progress(self, amount: int) -> None: + """Send progress updates to the caller.""" + + def _get_connector_runner( db_session: Session, attempt: IndexAttempt, @@ -92,7 +106,7 @@ def _run_indexing( db_session: Session, index_attempt: IndexAttempt, tenant_id: str | None, - progress_callback: Callable[[int], None] | None = None, + callback: RunIndexingCallbackInterface | None = None, ) -> None: """ 1. Get documents which are either new or updated from specified application @@ -206,6 +220,11 @@ def _run_indexing( # index being built. We want to populate it even for paused connectors # Often paused connectors are sources that aren't updated frequently but the # contents still need to be initially pulled. + if callback: + if callback.should_stop(): + raise RuntimeError("Connector stop signal detected") + + # TODO: should we move this into the above callback instead? db_session.refresh(db_cc_pair) if ( ( @@ -263,8 +282,8 @@ def _run_indexing( # be inaccurate db_session.commit() - if progress_callback: - progress_callback(len(doc_batch)) + if callback: + callback.progress(len(doc_batch)) # This new value is updated every batch, so UI can refresh per batch update update_docs_indexed( @@ -394,7 +413,7 @@ def run_indexing_entrypoint( tenant_id: str | None, connector_credential_pair_id: int, is_ee: bool = False, - progress_callback: Callable[[int], None] | None = None, + callback: RunIndexingCallbackInterface | None = None, ) -> None: try: if is_ee: @@ -417,7 +436,7 @@ def run_indexing_entrypoint( f"credentials='{attempt.connector_credential_pair.connector_id}'" ) - _run_indexing(db_session, attempt, tenant_id, progress_callback) + _run_indexing(db_session, attempt, tenant_id, callback) logger.info( f"Indexing finished for tenant {tenant_id}: " diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py index d16aa59c4c..1ceeb776ab 100644 --- a/backend/danswer/server/manage/administrative.py +++ b/backend/danswer/server/manage/administrative.py @@ -19,7 +19,6 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair from danswer.db.connector_credential_pair import ( update_connector_credential_pair_from_id, ) -from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed from danswer.db.engine import get_current_tenant_id from danswer.db.engine import get_session from danswer.db.enums import ConnectorCredentialPairStatus @@ -175,15 +174,19 @@ def create_deletion_attempt_for_connector_id( cc_pair_id=cc_pair.id, db_session=db_session, include_secondary_index=True ) + # TODO(rkuo): 2024-10-24 - check_deletion_attempt_is_allowed shouldn't be necessary + # any more due to background locking improvements. + # Remove the below permanently if everything is behaving for 30 days. + # Check if the deletion attempt should be allowed - deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed( - connector_credential_pair=cc_pair, db_session=db_session - ) - if deletion_attempt_disallowed_reason: - raise HTTPException( - status_code=400, - detail=deletion_attempt_disallowed_reason, - ) + # deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed( + # connector_credential_pair=cc_pair, db_session=db_session + # ) + # if deletion_attempt_disallowed_reason: + # raise HTTPException( + # status_code=400, + # detail=deletion_attempt_disallowed_reason, + # ) # mark as deleting update_connector_credential_pair_from_id(