diff --git a/backend/danswer/background/celery/tasks/vespa/tasks.py b/backend/danswer/background/celery/tasks/vespa/tasks.py
index c43de3a85..e958d63ad 100644
--- a/backend/danswer/background/celery/tasks/vespa/tasks.py
+++ b/backend/danswer/background/celery/tasks/vespa/tasks.py
@@ -31,6 +31,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_f
 from danswer.db.connector_credential_pair import get_connector_credential_pairs
 from danswer.db.document import count_documents_by_needs_sync
 from danswer.db.document import get_document
+from danswer.db.document import get_document_ids_for_connector_credential_pair
 from danswer.db.document import mark_document_as_synced
 from danswer.db.document_set import delete_document_set
 from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
@@ -363,7 +364,7 @@ def monitor_connector_deletion_taskset(
     count = cast(int, r.scard(rcd.taskset_key))
     task_logger.info(
-        f"Connector deletion progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
+        f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={initial_count}"
     )
     if count > 0:
         return
@@ -372,16 +373,27 @@ def monitor_connector_deletion_taskset(
     cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
     if not cc_pair:
         task_logger.warning(
-            f"monitor_connector_deletion_taskset - cc_pair_id not found: cc_pair_id={cc_pair_id}"
+            f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
         )
         return

     try:
+        doc_ids = get_document_ids_for_connector_credential_pair(
+            db_session, cc_pair.connector_id, cc_pair.credential_id
+        )
+        if len(doc_ids) > 0:
+            # if this happens, documents somehow got added while deletion was in progress. Likely a bug
+            # gating off pruning and indexing work before deletion starts
+            task_logger.warning(
+                f"Connector deletion - documents still found after taskset completion: "
+                f"cc_pair={cc_pair_id} num={len(doc_ids)}"
+            )
+
         # clean up the rest of the related Postgres entities
         # index attempts
         delete_index_attempts(
             db_session=db_session,
-            cc_pair_id=cc_pair.id,
+            cc_pair_id=cc_pair_id,
         )

         # document sets
@@ -398,7 +410,7 @@ def monitor_connector_deletion_taskset(
             noop_fallback,
         )
         cleanup_user_groups(
-            cc_pair_id=cc_pair.id,
+            cc_pair_id=cc_pair_id,
             db_session=db_session,
         )

@@ -420,20 +432,21 @@ def monitor_connector_deletion_taskset(
         db_session.delete(connector)
         db_session.commit()
     except Exception as e:
+        db_session.rollback()
         stack_trace = traceback.format_exc()
         error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
-        add_deletion_failure_message(db_session, cc_pair.id, error_message)
+        add_deletion_failure_message(db_session, cc_pair_id, error_message)
         task_logger.exception(
             f"Failed to run connector_deletion. "
-            f"cc_pair_id={cc_pair_id} connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
+            f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
         )
         raise e

     task_logger.info(
         f"Successfully deleted cc_pair: "
-        f"cc_pair_id={cc_pair_id} "
-        f"connector_id={cc_pair.connector_id} "
-        f"credential_id={cc_pair.credential_id} "
+        f"cc_pair={cc_pair_id} "
+        f"connector={cc_pair.connector_id} "
+        f"credential={cc_pair.credential_id} "
         f"docs_deleted={initial_count}"
     )
diff --git a/backend/danswer/db/document_set.py b/backend/danswer/db/document_set.py
index 0ba6c4e9a..b5af99b22 100644
--- a/backend/danswer/db/document_set.py
+++ b/backend/danswer/db/document_set.py
@@ -398,7 +398,7 @@ def mark_document_set_as_to_be_deleted(
 def delete_document_set_cc_pair_relationship__no_commit(
     connector_id: int, credential_id: int, db_session: Session
-) -> None:
+) -> int:
     """Deletes all rows from DocumentSet__ConnectorCredentialPair where the
     connector_credential_pair_id matches the given cc_pair_id."""
     delete_stmt = delete(DocumentSet__ConnectorCredentialPair).where(
@@ -409,7 +409,8 @@ def delete_document_set_cc_pair_relationship__no_commit(
             == ConnectorCredentialPair.id,
         )
     )
-    db_session.execute(delete_stmt)
+    result = db_session.execute(delete_stmt)
+    return result.rowcount  # type: ignore


 def fetch_document_sets(