def _prepare_index_attempt(
    db_session: Session, index_attempt_id: int, tenant_id: str | None
) -> IndexAttempt:
    """Claim a NOT_STARTED index attempt by moving it to IN_PROGRESS.

    The status check and the status update run in one SERIALIZABLE
    transaction that is committed exactly once, here. Any failure (including
    a failure to commit) rolls the session back, is logged, and is re-raised,
    so a returned attempt has always been committed as IN_PROGRESS by this
    call.

    Args:
        db_session: Session used for the whole check-and-update transaction.
        index_attempt_id: ID of the IndexAttempt to claim.
        tenant_id: If not None, the search path is set to this tenant's
            schema before the attempt is looked up.

    Returns:
        The IndexAttempt that was transitioned to IN_PROGRESS.

    Raises:
        RuntimeError: If the attempt does not exist or is not NOT_STARTED.
        Exception: Any database error raised while reading, updating or
            committing is propagated after rollback.
    """
    # Imported locally so this function does not depend on `func` already
    # being imported at module level.
    from sqlalchemy import func

    # make sure that the index attempt can't change in between checking the
    # status and marking it as in_progress. This setting will be discarded
    # after the next commit:
    # https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#setting-isolation-for-individual-transactions
    db_session.connection(execution_options={"isolation_level": "SERIALIZABLE"})  # type: ignore
    try:
        if tenant_id is not None:
            # The tenant ID is interpolated into the statement as a quoted
            # identifier, so double any embedded double quote to keep it from
            # terminating the identifier early.
            quoted_tenant_id = tenant_id.replace('"', '""')
            # Explicitly set the search path for the given tenant
            db_session.execute(text(f'SET search_path TO "{quoted_tenant_id}"'))
            # Verify the search path was set correctly
            result = db_session.execute(text("SHOW search_path"))
            current_search_path = result.scalar()
            logger.info(f"Current search path set to: {current_search_path}")

        attempt = get_index_attempt(
            db_session=db_session,
            index_attempt_id=index_attempt_id,
        )

        if attempt is None:
            raise RuntimeError(
                f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
            )

        if attempt.status != IndexingStatus.NOT_STARTED:
            raise RuntimeError(
                f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. "
                f"Current status is '{attempt.status}'."
            )

        # Do the transition inline rather than through
        # mark_attempt_in_progress(): that helper commits on its own and rolls
        # back without re-raising, which would hide a failed or conflicting
        # update from this function.
        attempt.status = IndexingStatus.IN_PROGRESS
        attempt.time_started = attempt.time_started or func.now()  # type: ignore

        # only commit once, to make sure this all happens in a single transaction
        db_session.commit()
    except Exception:
        db_session.rollback()
        logger.exception("_prepare_index_attempt exceptioned.")
        raise

    return attempt
- ) - - attempt.status = IndexingStatus.IN_PROGRESS - attempt.time_started = attempt.time_started or func.now() # type: ignore - db_session.commit() - return attempt - except Exception: - db_session.rollback() - logger.exception("transition_attempt_to_in_progress exceptioned.") - raise - - def mark_attempt_in_progress( index_attempt: IndexAttempt, db_session: Session, @@ -152,7 +118,6 @@ def mark_attempt_in_progress( db_session.commit() except Exception: db_session.rollback() - raise def mark_attempt_succeeded( @@ -171,7 +136,6 @@ def mark_attempt_succeeded( db_session.commit() except Exception: db_session.rollback() - raise def mark_attempt_partially_succeeded( @@ -190,7 +154,6 @@ def mark_attempt_partially_succeeded( db_session.commit() except Exception: db_session.rollback() - raise def mark_attempt_failed( @@ -213,7 +176,6 @@ def mark_attempt_failed( db_session.commit() except Exception: db_session.rollback() - raise source = index_attempt.connector_credential_pair.connector.source optional_telemetry(record_type=RecordType.FAILURE, data={"connector": source})