Bugfix/connector deletion lockout (#2901)

* first cut at deletion hardening

* clean up logging

* remove commented code
rkuo-danswer committed on 2024-10-24 19:43:57 -07:00 (committed by GitHub)
parent b9781c43fb
commit eaa8ae7399

11 changed files with 281 additions and 93 deletions

View File

@@ -17,6 +17,7 @@ from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
 from danswer.background.celery.celery_redis import RedisConnectorDeletion
 from danswer.background.celery.celery_redis import RedisConnectorIndexing
 from danswer.background.celery.celery_redis import RedisConnectorPruning
+from danswer.background.celery.celery_redis import RedisConnectorStop
 from danswer.background.celery.celery_redis import RedisDocumentSet
 from danswer.background.celery.celery_redis import RedisUserGroup
 from danswer.background.celery.celery_utils import celery_is_worker_primary
@@ -161,6 +162,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
     for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
         r.delete(key)
 
+    for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
+        r.delete(key)
+
 
 # @worker_process_init.connect
 # def on_worker_process_init(sender: Any, **kwargs: Any) -> None:

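Note: clearing stop fences in on_worker_init matters because a worker that dies mid-deletion would otherwise leave a stale stop signal that keeps suppressing indexing and pruning after a restart. A standalone sketch of the same startup sweep (the prefix mirrors RedisConnectorStop.FENCE_PREFIX; the Redis client setup is illustrative):

```python
# Illustrative startup sweep, assuming a local Redis; mirrors the diff above.
import redis

r = redis.Redis()
STOP_FENCE_PREFIX = "connectorstop_fence"  # RedisConnectorStop.FENCE_PREFIX

# drop any stop signals left over from a previous (possibly crashed) run
for key in r.scan_iter(STOP_FENCE_PREFIX + "*"):
    r.delete(key)
```
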
View File

@@ -313,6 +313,8 @@ class RedisConnectorDeletion(RedisObjectHelper):
         lock: redis.lock.Lock,
         tenant_id: str | None,
     ) -> int | None:
+        """Returns None if the cc_pair doesn't exist.
+        Otherwise, returns an int with the number of generated tasks."""
         last_lock_time = time.monotonic()
 
         async_results = []
@@ -540,6 +542,29 @@ class RedisConnectorIndexing(RedisObjectHelper):
         return False
 
 
+class RedisConnectorStop(RedisObjectHelper):
+    """Used to signal any running tasks for a connector to stop. We should refactor
+    connector-related redis helpers into a single class.
+    """
+
+    PREFIX = "connectorstop"
+    FENCE_PREFIX = PREFIX + "_fence"  # signals that tasks for this connector should stop
+    TASKSET_PREFIX = PREFIX + "_taskset"  # stores task ids (unused by the stop signal)
+
+    def __init__(self, id: int) -> None:
+        super().__init__(str(id))
+
+    def generate_tasks(
+        self,
+        celery_app: Celery,
+        db_session: Session,
+        redis_client: Redis,
+        lock: redis.lock.Lock | None,
+        tenant_id: str | None,
+    ) -> int | None:
+        return None
+
+
 def celery_get_queue_length(queue: str, r: Redis) -> int:
     """This is a redis specific way to get the length of a celery queue.
     It is priority aware and knows how to count across the multiple redis lists

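Note: RedisConnectorStop reuses the fence idiom purely as a boolean signal; generate_tasks is a stub because nothing is enqueued for a stop. A minimal sketch of the intended handshake, assuming fence keys take the form connectorstop_fence_<cc_pair_id> (the exact key shape is built by RedisObjectHelper and is an assumption here):

```python
# Hypothetical handshake, not code from this commit.
import redis

r = redis.Redis()
cc_pair_id = 42
stop_key = f"connectorstop_fence_{cc_pair_id}"  # assumed key shape

# the beat task raises the signal when deletion is blocked by running work
r.set(stop_key, cc_pair_id)

# indexing/pruning loops poll it between document batches and bail out early
if r.exists(stop_key):
    raise RuntimeError("Stop signal received")

# once deletion can proceed (or the connector work exits), clear the signal
r.delete(stop_key)
```
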
View File

@@ -1,4 +1,3 @@
-from collections.abc import Callable
 from datetime import datetime
 from datetime import timezone
 from typing import Any
@@ -6,6 +5,7 @@ from typing import Any
 from sqlalchemy.orm import Session
 
 from danswer.background.celery.celery_redis import RedisConnectorDeletion
+from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
 from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
 from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
     rate_limit_builder,
@@ -79,7 +79,7 @@ def document_batch_to_ids(
 
 def extract_ids_from_runnable_connector(
     runnable_connector: BaseConnector,
-    progress_callback: Callable[[int], None] | None = None,
+    callback: RunIndexingCallbackInterface | None = None,
 ) -> set[str]:
     """
     If the PruneConnector hasn't been implemented for the given connector, just pull
@@ -110,8 +110,10 @@ def extract_ids_from_runnable_connector(
         max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
     )(document_batch_to_ids)
     for doc_batch in doc_batch_generator:
-        if progress_callback:
-            progress_callback(len(doc_batch))
+        if callback:
+            if callback.should_stop():
+                raise RuntimeError("Stop signal received")
+            callback.progress(len(doc_batch))
         all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
 
     return all_connector_doc_ids

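Note: the callback is now a two-way channel: the worker reports progress upward while the caller can request a stop. A minimal in-memory stand-in that satisfies the interface (illustrative, for tests or local runs; not part of this commit):

```python
# Illustrative stand-in for RunIndexingCallbackInterface, not from this commit.
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface


class InMemoryCallback(RunIndexingCallbackInterface):
    def __init__(self) -> None:
        self.total = 0
        self.stop_requested = False

    def should_stop(self) -> bool:
        # extract_ids_from_runnable_connector checks this between batches and
        # raises RuntimeError("Stop signal received") when it returns True
        return self.stop_requested

    def progress(self, amount: int) -> None:
        self.total += amount
```
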
View File

@@ -1,3 +1,6 @@
+from datetime import datetime
+from datetime import timezone
+
 import redis
 from celery import Celery
 from celery import shared_task
@@ -8,6 +11,12 @@ from sqlalchemy.orm import Session
 
 from danswer.background.celery.apps.app_base import task_logger
 from danswer.background.celery.celery_redis import RedisConnectorDeletion
+from danswer.background.celery.celery_redis import RedisConnectorIndexing
+from danswer.background.celery.celery_redis import RedisConnectorPruning
+from danswer.background.celery.celery_redis import RedisConnectorStop
+from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
+    RedisConnectorDeletionFenceData,
+)
 from danswer.configs.app_configs import JOB_TIMEOUT
 from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
 from danswer.configs.constants import DanswerRedisLocks
@@ -15,9 +24,15 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
 from danswer.db.connector_credential_pair import get_connector_credential_pairs
 from danswer.db.engine import get_session_with_tenant
 from danswer.db.enums import ConnectorCredentialPairStatus
+from danswer.db.search_settings import get_all_search_settings
 from danswer.redis.redis_pool import get_redis_client
 
 
+class TaskDependencyError(RuntimeError):
+    """Raised to the caller to indicate dependent tasks are running that would interfere
+    with connector deletion."""
+
+
 @shared_task(
     name="check_for_connector_deletion_task",
     soft_time_limit=JOB_TIMEOUT,
@@ -37,17 +52,30 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
         if not lock_beat.acquire(blocking=False):
             return
 
+        # collect cc_pair_ids
         cc_pair_ids: list[int] = []
         with get_session_with_tenant(tenant_id) as db_session:
             cc_pairs = get_connector_credential_pairs(db_session)
             for cc_pair in cc_pairs:
                 cc_pair_ids.append(cc_pair.id)
 
+        # try running cleanup on the cc_pair_ids
         for cc_pair_id in cc_pair_ids:
             with get_session_with_tenant(tenant_id) as db_session:
-                try_generate_document_cc_pair_cleanup_tasks(
-                    self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
-                )
+                rcs = RedisConnectorStop(cc_pair_id)
+                try:
+                    try_generate_document_cc_pair_cleanup_tasks(
+                        self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
+                    )
+                except TaskDependencyError as e:
+                    # this means we wanted to start deleting but dependent tasks were running.
+                    # Leave a stop signal to clear indexing and pruning tasks more quickly
+                    task_logger.info(str(e))
+                    r.set(rcs.fence_key, cc_pair_id)
+                else:
+                    # clear the stop signal if it exists ... no longer needed
+                    r.delete(rcs.fence_key)
     except SoftTimeLimitExceeded:
         task_logger.info(
             "Soft time limit exceeded, task is being terminated gracefully."
@@ -70,6 +98,10 @@ def try_generate_document_cc_pair_cleanup_tasks(
     """Returns an int if syncing is needed. The int represents the number of sync tasks generated.
     Note that syncing can still be required even if the number of sync tasks generated is zero.
     Returns None if no syncing is required.
+
+    Will raise TaskDependencyError if dependent tasks such as indexing and pruning are
+    still running. In our case, the caller reacts by setting a stop signal in Redis to
+    exit those tasks as quickly as possible.
     """
 
     lock_beat.reacquire()
@@ -90,28 +122,63 @@ def try_generate_document_cc_pair_cleanup_tasks(
     if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
         return None
 
-    # add tasks to celery and build up the task set to monitor in redis
-    r.delete(rcd.taskset_key)
-
-    # Add all documents that need to be updated into the queue
-    task_logger.info(
-        f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
-    )
-    tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
-    if tasks_generated is None:
-        return None
-
-    # Currently we are allowing the sync to proceed with 0 tasks.
-    # It's possible for sets/groups to be generated initially with no entries
-    # and they still need to be marked as up to date.
-    # if tasks_generated == 0:
-    #     return 0
-
-    task_logger.info(
-        f"RedisConnectorDeletion.generate_tasks finished. "
-        f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
-    )
-
-    # set this only after all tasks have been added
-    r.set(rcd.fence_key, tasks_generated)
+    # set a basic fence to start
+    fence_value = RedisConnectorDeletionFenceData(
+        num_tasks=None,
+        submitted=datetime.now(timezone.utc),
+    )
+    r.set(rcd.fence_key, fence_value.model_dump_json())
+
+    try:
+        # do not proceed if connector indexing or connector pruning are running
+        search_settings_list = get_all_search_settings(db_session)
+        for search_settings in search_settings_list:
+            rci = RedisConnectorIndexing(cc_pair_id, search_settings.id)
+            if r.get(rci.fence_key):
+                raise TaskDependencyError(
+                    f"Connector deletion - Delayed (indexing in progress): "
+                    f"cc_pair={cc_pair_id} "
+                    f"search_settings={search_settings.id}"
+                )
+
+        rcp = RedisConnectorPruning(cc_pair_id)
+        if r.get(rcp.fence_key):
+            raise TaskDependencyError(
+                f"Connector deletion - Delayed (pruning in progress): "
+                f"cc_pair={cc_pair_id}"
+            )
+
+        # add tasks to celery and build up the task set to monitor in redis
+        r.delete(rcd.taskset_key)
+
+        # Add all documents that need to be updated into the queue
+        task_logger.info(
+            f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}"
+        )
+        tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
+        if tasks_generated is None:
+            raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
+    except TaskDependencyError:
+        r.delete(rcd.fence_key)
+        raise
+    except Exception:
+        task_logger.exception("Unexpected exception")
+        r.delete(rcd.fence_key)
+        return None
+    else:
+        # Currently we are allowing the sync to proceed with 0 tasks.
+        # It's possible for sets/groups to be generated initially with no entries
+        # and they still need to be marked as up to date.
+        # if tasks_generated == 0:
+        #     return 0
+
+        task_logger.info(
+            f"RedisConnectorDeletion.generate_tasks finished. "
+            f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
+        )
+
+        # set this only after all tasks have been added
+        fence_value.num_tasks = tasks_generated
+        r.set(rcd.fence_key, fence_value.model_dump_json())
 
     return tasks_generated

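Note: the deletion path now follows a raise-early/publish-late fence protocol: the fence goes up with num_tasks=None before any dependency checks, is torn down on every failure path, and only receives its real task count once everything is queued. A condensed, self-contained illustration (names are made up; not a verbatim excerpt):

```python
# Condensed illustration of the fence lifecycle, assuming a local Redis.
import json
from collections.abc import Callable

import redis


class TaskDependencyError(RuntimeError):
    """Stand-in for the exception added by this commit."""


def cleanup_with_fence(
    r: redis.Redis,
    fence_key: str,
    dependencies_running: Callable[[], bool],
    generate_tasks: Callable[[], int | None],
) -> int | None:
    # phase 1: raise the fence with num_tasks unset ("still setting up")
    r.set(fence_key, json.dumps({"num_tasks": None}))
    try:
        if dependencies_running():  # indexing or pruning still active
            raise TaskDependencyError("delayed: dependent tasks running")
        tasks = generate_tasks()
        if tasks is None:
            raise ValueError("generate_tasks returned None")
    except TaskDependencyError:
        r.delete(fence_key)  # caller reacts by setting the stop signal
        raise
    except Exception:
        r.delete(fence_key)  # never leave a half-built fence behind
        return None
    # phase 2: publish the real count only after all tasks are queued
    r.set(fence_key, json.dumps({"num_tasks": tasks}))
    return tasks
```
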
View File

@@ -5,6 +5,7 @@ from time import sleep
 from typing import cast
 from uuid import uuid4
 
+import redis
 from celery import Celery
 from celery import shared_task
 from celery import Task
@@ -13,12 +14,15 @@ from redis import Redis
 from sqlalchemy.orm import Session
 
 from danswer.background.celery.apps.app_base import task_logger
+from danswer.background.celery.celery_redis import RedisConnectorDeletion
 from danswer.background.celery.celery_redis import RedisConnectorIndexing
+from danswer.background.celery.celery_redis import RedisConnectorStop
 from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
     RedisConnectorIndexingFenceData,
 )
 from danswer.background.indexing.job_client import SimpleJobClient
 from danswer.background.indexing.run_indexing import run_indexing_entrypoint
+from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
 from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
 from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
 from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -50,6 +54,30 @@ from danswer.utils.variable_functionality import global_version
 logger = setup_logger()
 
 
+class RunIndexingCallback(RunIndexingCallbackInterface):
+    def __init__(
+        self,
+        stop_key: str,
+        generator_progress_key: str,
+        redis_lock: redis.lock.Lock,
+        redis_client: Redis,
+    ):
+        super().__init__()
+        self.redis_lock: redis.lock.Lock = redis_lock
+        self.stop_key: str = stop_key
+        self.generator_progress_key: str = generator_progress_key
+        self.redis_client = redis_client
+
+    def should_stop(self) -> bool:
+        if self.redis_client.exists(self.stop_key):
+            return True
+        return False
+
+    def progress(self, amount: int) -> None:
+        self.redis_lock.reacquire()
+        self.redis_client.incrby(self.generator_progress_key, amount)
+
+
 @shared_task(
     name="check_for_indexing",
     soft_time_limit=300,
@@ -262,6 +290,10 @@ def try_creating_indexing_task(
         return None
 
     # skip indexing if the cc_pair is deleting
+    rcd = RedisConnectorDeletion(cc_pair.id)
+    if r.exists(rcd.fence_key):
+        return None
+
     db_session.refresh(cc_pair)
     if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
         return None
@@ -308,13 +340,8 @@ def try_creating_indexing_task(
             raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
 
         # now fill out the fence with the rest of the data
-        fence_value = RedisConnectorIndexingFenceData(
-            index_attempt_id=index_attempt_id,
-            started=None,
-            submitted=datetime.now(timezone.utc),
-            celery_task_id=result.id,
-        )
+        fence_value.index_attempt_id = index_attempt_id
+        fence_value.celery_task_id = result.id
         r.set(rci.fence_key, fence_value.model_dump_json())
     except Exception:
         r.delete(rci.fence_key)
@@ -400,6 +427,22 @@ def connector_indexing_task(
     r = get_redis_client(tenant_id=tenant_id)
 
+    rcd = RedisConnectorDeletion(cc_pair_id)
+    if r.exists(rcd.fence_key):
+        raise RuntimeError(
+            f"Indexing will not start because connector deletion is in progress: "
+            f"cc_pair={cc_pair_id} "
+            f"fence={rcd.fence_key}"
+        )
+
+    rcs = RedisConnectorStop(cc_pair_id)
+    if r.exists(rcs.fence_key):
+        raise RuntimeError(
+            f"Indexing will not start because a connector stop signal was detected: "
+            f"cc_pair={cc_pair_id} "
+            f"fence={rcs.fence_key}"
+        )
+
     rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
 
     while True:
@@ -409,7 +452,7 @@ def connector_indexing_task(
             task_logger.info(
                 f"connector_indexing_task: fence_value not found: fence={rci.fence_key}"
             )
-            raise
+            raise RuntimeError(f"Fence not found: fence={rci.fence_key}")
 
         try:
             fence_json = fence_value.decode("utf-8")
@@ -443,17 +486,20 @@ def connector_indexing_task(
     if not acquired:
         task_logger.warning(
             f"Indexing task already running, exiting...: "
-            f"cc_pair_id={cc_pair_id} search_settings_id={search_settings_id}"
+            f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
         )
         # r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
        return None
 
+    fence_data.started = datetime.now(timezone.utc)
+    r.set(rci.fence_key, fence_data.model_dump_json())
+
     try:
         with get_session_with_tenant(tenant_id) as db_session:
             attempt = get_index_attempt(db_session, index_attempt_id)
             if not attempt:
                 raise ValueError(
-                    f"Index attempt not found: index_attempt_id={index_attempt_id}"
+                    f"Index attempt not found: index_attempt={index_attempt_id}"
                 )
 
             cc_pair = get_connector_credential_pair_from_id(
@@ -462,31 +508,31 @@ def connector_indexing_task(
             )
 
             if not cc_pair:
-                raise ValueError(f"cc_pair not found: cc_pair_id={cc_pair_id}")
+                raise ValueError(f"cc_pair not found: cc_pair={cc_pair_id}")
 
             if not cc_pair.connector:
                 raise ValueError(
-                    f"Connector not found: connector_id={cc_pair.connector_id}"
+                    f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}"
                 )
 
             if not cc_pair.credential:
                 raise ValueError(
-                    f"Credential not found: credential_id={cc_pair.credential_id}"
+                    f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
                 )
 
             rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
 
-            # Define the callback function
-            def redis_increment_callback(amount: int) -> None:
-                lock.reacquire()
-                r.incrby(rci.generator_progress_key, amount)
+            # define a callback class
+            callback = RunIndexingCallback(
+                rcs.fence_key, rci.generator_progress_key, lock, r
+            )
 
             run_indexing_entrypoint(
                 index_attempt_id,
                 tenant_id,
                 cc_pair_id,
                 is_ee,
-                progress_callback=redis_increment_callback,
+                callback=callback,
             )
 
             # get back the total number of indexed docs and return it
@@ -499,9 +545,10 @@ def connector_indexing_task(
             r.set(rci.generator_complete_key, HTTPStatus.OK.value)
     except Exception as e:
-        task_logger.exception(f"Failed to run indexing for cc_pair_id={cc_pair_id}.")
+        task_logger.exception(f"Indexing failed: cc_pair={cc_pair_id}")
         if attempt:
-            mark_attempt_failed(attempt, db_session, failure_reason=str(e))
+            with get_session_with_tenant(tenant_id) as db_session:
+                mark_attempt_failed(attempt, db_session, failure_reason=str(e))
 
         r.delete(rci.generator_lock_key)
         r.delete(rci.generator_progress_key)

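Note: indexing now has two layers of protection against racing a deletion: the dispatcher skips cc_pairs that have a deletion fence up, and the task itself re-checks both the deletion fence and the stop fence before doing any work. The pre-flight check boils down to the following (illustrative helper, not from this commit):

```python
# Illustrative pre-flight guard, not code from this commit.
from redis import Redis


def can_start_work(r: Redis, deletion_fence_key: str, stop_fence_key: str) -> bool:
    """Indexing/pruning must not start while the connector is being deleted
    or has been told to stop."""
    if r.exists(deletion_fence_key):
        return False
    if r.exists(stop_fence_key):
        return False
    return True
```
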
View File

@@ -11,8 +11,11 @@ from redis import Redis
 from sqlalchemy.orm import Session
 
 from danswer.background.celery.apps.app_base import task_logger
+from danswer.background.celery.celery_redis import RedisConnectorDeletion
 from danswer.background.celery.celery_redis import RedisConnectorPruning
+from danswer.background.celery.celery_redis import RedisConnectorStop
 from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
+from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
 from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
 from danswer.configs.app_configs import JOB_TIMEOUT
 from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
@@ -168,6 +171,10 @@ def try_creating_prune_generator_task(
         return None
 
     # skip pruning if the cc_pair is deleting
+    rcd = RedisConnectorDeletion(cc_pair.id)
+    if r.exists(rcd.fence_key):
+        return None
+
     db_session.refresh(cc_pair)
     if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
         return None
@@ -234,7 +241,7 @@ def connector_pruning_generator_task(
     acquired = lock.acquire(blocking=False)
     if not acquired:
         task_logger.warning(
-            f"Pruning task already running, exiting...: cc_pair_id={cc_pair_id}"
+            f"Pruning task already running, exiting...: cc_pair={cc_pair_id}"
         )
         return None
 
@@ -252,11 +259,6 @@ def connector_pruning_generator_task(
             )
             return
 
-        # Define the callback function
-        def redis_increment_callback(amount: int) -> None:
-            lock.reacquire()
-            r.incrby(rcp.generator_progress_key, amount)
-
         runnable_connector = instantiate_connector(
             db_session,
             cc_pair.connector.source,
@@ -265,9 +267,14 @@ def connector_pruning_generator_task(
             cc_pair.credential,
         )
 
+        rcs = RedisConnectorStop(cc_pair_id)
+        callback = RunIndexingCallback(
+            rcs.fence_key, rcp.generator_progress_key, lock, r
+        )
+
         # a list of docs in the source
         all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
-            runnable_connector, redis_increment_callback
+            runnable_connector, callback
         )
 
         # a list of docs in our local index
@@ -285,7 +292,7 @@ def connector_pruning_generator_task(
         task_logger.info(
             f"Pruning set collected: "
-            f"cc_pair_id={cc_pair.id} "
+            f"cc_pair={cc_pair_id} "
             f"docs_to_remove={len(doc_ids_to_remove)} "
             f"doc_source={cc_pair.connector.source}"
         )
@@ -293,7 +300,7 @@ def connector_pruning_generator_task(
         rcp.documents_to_prune = set(doc_ids_to_remove)
 
         task_logger.info(
-            f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
+            f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair_id}"
         )
         tasks_generated = rcp.generate_tasks(
             self.app, db_session, r, None, tenant_id
@@ -303,12 +310,14 @@ def connector_pruning_generator_task(
         task_logger.info(
             f"RedisConnectorPruning.generate_tasks finished. "
-            f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
+            f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
         )
 
         r.set(rcp.generator_complete_key, tasks_generated)
     except Exception as e:
-        task_logger.exception(f"Failed to run pruning for connector id {connector_id}.")
+        task_logger.exception(
+            f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
+        )
 
         r.delete(rcp.generator_progress_key)
         r.delete(rcp.taskset_key)

View File

@@ -0,0 +1,8 @@
+from datetime import datetime
+
+from pydantic import BaseModel
+
+
+class RedisConnectorDeletionFenceData(BaseModel):
+    num_tasks: int | None
+    submitted: datetime

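Note: the payload round-trips through Redis as JSON via pydantic v2 (model_dump_json / model_validate_json, both used by this commit), and num_tasks=None doubles as the "fence is still being set up" marker. A quick round-trip:

```python
# Round-trip of the fence payload; import path as added by this commit.
from datetime import datetime, timezone

from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
    RedisConnectorDeletionFenceData,
)

fence = RedisConnectorDeletionFenceData(
    num_tasks=None, submitted=datetime.now(timezone.utc)
)
raw = fence.model_dump_json()  # this JSON string is what sits at the fence key

decoded = RedisConnectorDeletionFenceData.model_validate_json(raw)
assert decoded.num_tasks is None  # monitors treat this as "not ready yet"
```
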
View File

@@ -1,9 +1,6 @@
-from datetime import datetime
-
 from celery import shared_task
 from celery import Task
 from celery.exceptions import SoftTimeLimitExceeded
-from pydantic import BaseModel
 
 from danswer.access.access import get_access_for_document
 from danswer.background.celery.apps.app_base import task_logger
@@ -23,13 +20,6 @@ from danswer.server.documents.models import ConnectorCredentialPairIdentifier
 DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
 
 
-class RedisConnectorIndexingFenceData(BaseModel):
-    index_attempt_id: int | None
-    started: datetime | None
-    submitted: datetime
-    celery_task_id: str | None
-
-
 @shared_task(
     name="document_by_cc_pair_cleanup_task",
     bind=True,

View File

@@ -23,6 +23,9 @@ from danswer.background.celery.celery_redis import RedisConnectorIndexing
 from danswer.background.celery.celery_redis import RedisConnectorPruning
 from danswer.background.celery.celery_redis import RedisDocumentSet
 from danswer.background.celery.celery_redis import RedisUserGroup
+from danswer.background.celery.tasks.shared.RedisConnectorDeletionFenceData import (
+    RedisConnectorDeletionFenceData,
+)
 from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
     RedisConnectorIndexingFenceData,
 )
@@ -368,7 +371,7 @@ def monitor_document_set_taskset(
     count = cast(int, r.scard(rds.taskset_key))
     task_logger.info(
-        f"Document set sync progress: document_set_id={document_set_id} "
+        f"Document set sync progress: document_set={document_set_id} "
         f"remaining={count} initial={initial_count}"
     )
     if count > 0:
@@ -383,12 +386,12 @@ def monitor_document_set_taskset(
             # if there are no connectors, then delete the document set.
             delete_document_set(document_set_row=document_set, db_session=db_session)
             task_logger.info(
-                f"Successfully deleted document set with ID: '{document_set_id}'!"
+                f"Successfully deleted document set: document_set={document_set_id}"
             )
     else:
         mark_document_set_as_synced(document_set_id, db_session)
         task_logger.info(
-            f"Successfully synced document set with ID: '{document_set_id}'!"
+            f"Successfully synced document set: document_set={document_set_id}"
         )
 
     r.delete(rds.taskset_key)
@@ -408,19 +411,29 @@ def monitor_connector_deletion_taskset(
     rcd = RedisConnectorDeletion(cc_pair_id)
 
-    fence_value = r.get(rcd.fence_key)
+    # read related data and evaluate/print task progress
+    fence_value = cast(bytes, r.get(rcd.fence_key))
     if fence_value is None:
         return
 
     try:
-        initial_count = int(cast(int, fence_value))
+        fence_json = fence_value.decode("utf-8")
+        fence_data = RedisConnectorDeletionFenceData.model_validate_json(
+            cast(str, fence_json)
+        )
     except ValueError:
-        task_logger.error("The value is not an integer.")
+        task_logger.exception(
+            "monitor_connector_deletion_taskset: fence_data not decodeable."
+        )
+        raise
+
+    # the fence is setting up but isn't ready yet
+    if fence_data.num_tasks is None:
        return
 
     count = cast(int, r.scard(rcd.taskset_key))
     task_logger.info(
-        f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={initial_count}"
+        f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={fence_data.num_tasks}"
     )
     if count > 0:
         return
@@ -483,7 +496,7 @@ def monitor_connector_deletion_taskset(
     )
     if not connector or not len(connector.credentials):
         task_logger.info(
-            "Found no credentials left for connector, deleting connector"
+            "Connector deletion - Found no credentials left for connector, deleting connector"
         )
         db_session.delete(connector)
         db_session.commit()
@@ -493,17 +506,17 @@ def monitor_connector_deletion_taskset(
         error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
         add_deletion_failure_message(db_session, cc_pair_id, error_message)
         task_logger.exception(
-            f"Failed to run connector_deletion. "
+            f"Connector deletion exceptioned: "
             f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
         )
         raise e
 
     task_logger.info(
-        f"Successfully deleted cc_pair: "
+        f"Connector deletion succeeded: "
         f"cc_pair={cc_pair_id} "
         f"connector={cc_pair.connector_id} "
         f"credential={cc_pair.credential_id} "
-        f"docs_deleted={initial_count}"
+        f"docs_deleted={fence_data.num_tasks}"
     )
 
     r.delete(rcd.taskset_key)
@@ -618,6 +631,7 @@ def monitor_ccpair_indexing_taskset(
         return
 
     # Read result state BEFORE generator_complete_key to avoid a race condition
+    # never use any blocking methods on the result from inside a task!
     result: AsyncResult = AsyncResult(fence_data.celery_task_id)
     result_state = result.state

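Note: monitoring is a countdown: the fence records the initial task count, each completed deletion task removes itself from the taskset, and an empty taskset triggers finalization. A sketch of the progress check (key shapes are assumptions for the sketch):

```python
# Illustrative progress check mirroring monitor_connector_deletion_taskset.
import json

import redis

r = redis.Redis()
fence_key = "connectordeletion_fence_42"      # assumed key shape
taskset_key = "connectordeletion_taskset_42"  # assumed key shape

raw = r.get(fence_key)
if raw is not None:
    fence = json.loads(raw)
    if fence["num_tasks"] is not None:    # None: fence still setting up
        remaining = r.scard(taskset_key)  # tasks deregister as they finish
        print(f"remaining={remaining} initial={fence['num_tasks']}")
        if remaining == 0:
            pass  # safe to finalize: drop DB rows, clear taskset and fence
```
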
View File

@@ -1,6 +1,7 @@
 import time
 import traceback
-from collections.abc import Callable
+from abc import ABC
+from abc import abstractmethod
 from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
@@ -41,6 +42,19 @@ logger = setup_logger()
 INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
 
 
+class RunIndexingCallbackInterface(ABC):
+    """Defines a callback interface to be passed to
+    run_indexing_entrypoint."""
+
+    @abstractmethod
+    def should_stop(self) -> bool:
+        """Signal to stop the looping function in flight."""
+
+    @abstractmethod
+    def progress(self, amount: int) -> None:
+        """Send progress updates to the caller."""
+
+
 def _get_connector_runner(
     db_session: Session,
     attempt: IndexAttempt,
@@ -92,7 +106,7 @@ def _run_indexing(
     db_session: Session,
     index_attempt: IndexAttempt,
     tenant_id: str | None,
-    progress_callback: Callable[[int], None] | None = None,
+    callback: RunIndexingCallbackInterface | None = None,
 ) -> None:
     """
     1. Get documents which are either new or updated from specified application
@@ -206,6 +220,11 @@ def _run_indexing(
             # index being built. We want to populate it even for paused connectors
             # Often paused connectors are sources that aren't updated frequently but the
             # contents still need to be initially pulled.
+            if callback:
+                if callback.should_stop():
+                    raise RuntimeError("Connector stop signal detected")
+
+            # TODO: should we move this into the above callback instead?
             db_session.refresh(db_cc_pair)
             if (
                 (
@@ -263,8 +282,8 @@ def _run_indexing(
                 # be inaccurate
                 db_session.commit()
 
-                if progress_callback:
-                    progress_callback(len(doc_batch))
+                if callback:
+                    callback.progress(len(doc_batch))
 
                 # This new value is updated every batch, so UI can refresh per batch update
                 update_docs_indexed(
@@ -394,7 +413,7 @@ def run_indexing_entrypoint(
     tenant_id: str | None,
     connector_credential_pair_id: int,
     is_ee: bool = False,
-    progress_callback: Callable[[int], None] | None = None,
+    callback: RunIndexingCallbackInterface | None = None,
 ) -> None:
     try:
         if is_ee:
@@ -417,7 +436,7 @@ def run_indexing_entrypoint(
             f"credentials='{attempt.connector_credential_pair.connector_id}'"
         )
 
-        _run_indexing(db_session, attempt, tenant_id, progress_callback)
+        _run_indexing(db_session, attempt, tenant_id, callback)
 
         logger.info(
             f"Indexing finished for tenant {tenant_id}: "

View File

@@ -19,7 +19,6 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair
 from danswer.db.connector_credential_pair import (
     update_connector_credential_pair_from_id,
 )
-from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
 from danswer.db.engine import get_current_tenant_id
 from danswer.db.engine import get_session
 from danswer.db.enums import ConnectorCredentialPairStatus
@@ -175,15 +174,19 @@ def create_deletion_attempt_for_connector_id(
         cc_pair_id=cc_pair.id, db_session=db_session, include_secondary_index=True
     )
 
+    # TODO(rkuo): 2024-10-24 - check_deletion_attempt_is_allowed shouldn't be necessary
+    # any more due to background locking improvements.
+    # Remove the below permanently if everything is behaving for 30 days.
+
     # Check if the deletion attempt should be allowed
-    deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
-        connector_credential_pair=cc_pair, db_session=db_session
-    )
-    if deletion_attempt_disallowed_reason:
-        raise HTTPException(
-            status_code=400,
-            detail=deletion_attempt_disallowed_reason,
-        )
+    # deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
+    #     connector_credential_pair=cc_pair, db_session=db_session
+    # )
+    # if deletion_attempt_disallowed_reason:
+    #     raise HTTPException(
+    #         status_code=400,
+    #         detail=deletion_attempt_disallowed_reason,
+    #     )
 
     # mark as deleting
     update_connector_credential_pair_from_id(