From 12cbbe6ceeca812cb0ebf523b5aff29c6d264db2 Mon Sep 17 00:00:00 2001 From: rkuo-danswer Date: Fri, 18 Oct 2024 13:35:23 -0700 Subject: [PATCH] use with for update instead of serializable (#2848) * use with for update instead of serializable * remove tenant logic handled now by get_session_with_tenant * remove usage of begin_nested ... it's not necessary --- .../background/indexing/run_indexing.py | 51 +------ backend/danswer/db/index_attempt.py | 127 +++++++++++------- 2 files changed, 82 insertions(+), 96 deletions(-) diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index b4cfea97a..32878bfa4 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -4,7 +4,6 @@ from datetime import datetime from datetime import timedelta from datetime import timezone -from sqlalchemy import text from sqlalchemy.orm import Session from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt @@ -20,11 +19,10 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim from danswer.db.connector_credential_pair import update_connector_credential_pair from danswer.db.engine import get_session_with_tenant from danswer.db.enums import ConnectorCredentialPairStatus -from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import mark_attempt_failed -from danswer.db.index_attempt import mark_attempt_in_progress from danswer.db.index_attempt import mark_attempt_partially_succeeded from danswer.db.index_attempt import mark_attempt_succeeded +from danswer.db.index_attempt import transition_attempt_to_in_progress from danswer.db.index_attempt import update_docs_indexed from danswer.db.models import IndexAttempt from danswer.db.models import IndexingStatus @@ -379,51 +377,6 @@ def _run_indexing( ) -def _prepare_index_attempt( - db_session: Session, index_attempt_id: int, tenant_id: str | None -) -> IndexAttempt: - # make sure that the index attempt can't change in between checking the - # status and marking it as in_progress. This setting will be discarded - # after the next commit: - # https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#setting-isolation-for-individual-transactions - db_session.connection(execution_options={"isolation_level": "SERIALIZABLE"}) # type: ignore - try: - if tenant_id is not None: - # Explicitly set the search path for the given tenant - db_session.execute(text(f'SET search_path TO "{tenant_id}"')) - # Verify the search path was set correctly - result = db_session.execute(text("SHOW search_path")) - current_search_path = result.scalar() - logger.info(f"Current search path set to: {current_search_path}") - - attempt = get_index_attempt( - db_session=db_session, - index_attempt_id=index_attempt_id, - ) - - if attempt is None: - raise RuntimeError( - f"Unable to find IndexAttempt for ID '{index_attempt_id}'" - ) - - if attempt.status != IndexingStatus.NOT_STARTED: - raise RuntimeError( - f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. " - f"Current status is '{attempt.status}'." - ) - - mark_attempt_in_progress(attempt, db_session) - - # only commit once, to make sure this all happens in a single transaction - db_session.commit() - except Exception: - db_session.rollback() - logger.exception("_prepare_index_attempt exceptioned.") - raise - - return attempt - - def run_indexing_entrypoint( index_attempt_id: int, tenant_id: str | None, @@ -440,7 +393,7 @@ def run_indexing_entrypoint( index_attempt_id, connector_credential_pair_id ) with get_session_with_tenant(tenant_id) as db_session: - attempt = _prepare_index_attempt(db_session, index_attempt_id, tenant_id) + attempt = transition_attempt_to_in_progress(index_attempt_id, db_session) logger.info( f"Indexing starting for tenant {tenant_id}: " diff --git a/backend/danswer/db/index_attempt.py b/backend/danswer/db/index_attempt.py index 21a1bbd23..5d214d778 100644 --- a/backend/danswer/db/index_attempt.py +++ b/backend/danswer/db/index_attempt.py @@ -101,59 +101,92 @@ def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]: return list(new_attempts.all()) +def transition_attempt_to_in_progress( + index_attempt_id: int, + db_session: Session, +) -> IndexAttempt: + """Locks the row when we try to update""" + try: + attempt = db_session.execute( + select(IndexAttempt) + .where(IndexAttempt.id == index_attempt_id) + .with_for_update() + ).scalar_one() + + if attempt is None: + raise RuntimeError( + f"Unable to find IndexAttempt for ID '{index_attempt_id}'" + ) + + if attempt.status != IndexingStatus.NOT_STARTED: + raise RuntimeError( + f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. " + f"Current status is '{attempt.status}'." + ) + + attempt.status = IndexingStatus.IN_PROGRESS + attempt.time_started = attempt.time_started or func.now() # type: ignore + db_session.commit() + return attempt + except Exception: + db_session.rollback() + logger.exception("transition_attempt_to_in_progress exceptioned.") + raise + + def mark_attempt_in_progress( index_attempt: IndexAttempt, db_session: Session, ) -> None: - with db_session.begin_nested(): - try: - attempt = db_session.execute( - select(IndexAttempt) - .where(IndexAttempt.id == index_attempt.id) - .with_for_update() - ).scalar_one() + try: + attempt = db_session.execute( + select(IndexAttempt) + .where(IndexAttempt.id == index_attempt.id) + .with_for_update() + ).scalar_one() - attempt.status = IndexingStatus.IN_PROGRESS - attempt.time_started = index_attempt.time_started or func.now() # type: ignore - db_session.commit() - except Exception: - db_session.rollback() + attempt.status = IndexingStatus.IN_PROGRESS + attempt.time_started = index_attempt.time_started or func.now() # type: ignore + db_session.commit() + except Exception: + db_session.rollback() + raise def mark_attempt_succeeded( index_attempt: IndexAttempt, db_session: Session, ) -> None: - with db_session.begin_nested(): - try: - attempt = db_session.execute( - select(IndexAttempt) - .where(IndexAttempt.id == index_attempt.id) - .with_for_update() - ).scalar_one() + try: + attempt = db_session.execute( + select(IndexAttempt) + .where(IndexAttempt.id == index_attempt.id) + .with_for_update() + ).scalar_one() - attempt.status = IndexingStatus.SUCCESS - db_session.commit() - except Exception: - db_session.rollback() + attempt.status = IndexingStatus.SUCCESS + db_session.commit() + except Exception: + db_session.rollback() + raise def mark_attempt_partially_succeeded( index_attempt: IndexAttempt, db_session: Session, ) -> None: - with db_session.begin_nested(): - try: - attempt = db_session.execute( - select(IndexAttempt) - .where(IndexAttempt.id == index_attempt.id) - .with_for_update() - ).scalar_one() + try: + attempt = db_session.execute( + select(IndexAttempt) + .where(IndexAttempt.id == index_attempt.id) + .with_for_update() + ).scalar_one() - attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS - db_session.commit() - except Exception: - db_session.rollback() + attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS + db_session.commit() + except Exception: + db_session.rollback() + raise def mark_attempt_failed( @@ -162,20 +195,20 @@ def mark_attempt_failed( failure_reason: str = "Unknown", full_exception_trace: str | None = None, ) -> None: - with db_session.begin_nested(): - try: - attempt = db_session.execute( - select(IndexAttempt) - .where(IndexAttempt.id == index_attempt.id) - .with_for_update() - ).scalar_one() + try: + attempt = db_session.execute( + select(IndexAttempt) + .where(IndexAttempt.id == index_attempt.id) + .with_for_update() + ).scalar_one() - attempt.status = IndexingStatus.FAILED - attempt.error_msg = failure_reason - attempt.full_exception_trace = full_exception_trace - db_session.commit() - except Exception: - db_session.rollback() + attempt.status = IndexingStatus.FAILED + attempt.error_msg = failure_reason + attempt.full_exception_trace = full_exception_trace + db_session.commit() + except Exception: + db_session.rollback() + raise source = index_attempt.connector_credential_pair.connector.source optional_telemetry(record_type=RecordType.FAILURE, data={"connector": source})