no serializable, use with_for_update to lock the row.

This commit is contained in:
Richard Kuo (Danswer) 2024-10-18 11:07:54 -07:00
parent 7906d9edc8
commit e12785d277
2 changed files with 47 additions and 41 deletions

View File

@ -20,11 +20,10 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import transition_attempt_to_in_progress
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
@ -382,46 +381,15 @@ def _run_indexing(
def _prepare_index_attempt(
db_session: Session, index_attempt_id: int, tenant_id: str | None
) -> IndexAttempt:
# make sure that the index attempt can't change in between checking the
# status and marking it as in_progress. This setting will be discarded
# after the next commit:
# https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#setting-isolation-for-individual-transactions
db_session.connection(execution_options={"isolation_level": "SERIALIZABLE"}) # type: ignore
try:
if tenant_id is not None:
# Explicitly set the search path for the given tenant
db_session.execute(text(f'SET search_path TO "{tenant_id}"'))
# Verify the search path was set correctly
result = db_session.execute(text("SHOW search_path"))
current_search_path = result.scalar()
logger.info(f"Current search path set to: {current_search_path}")
if tenant_id is not None:
# Explicitly set the search path for the given tenant
db_session.execute(text(f'SET search_path TO "{tenant_id}"'))
# Verify the search path was set correctly
result = db_session.execute(text("SHOW search_path"))
current_search_path = result.scalar()
logger.info(f"Current search path set to: {current_search_path}")
attempt = get_index_attempt(
db_session=db_session,
index_attempt_id=index_attempt_id,
)
if attempt is None:
raise RuntimeError(
f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
)
if attempt.status != IndexingStatus.NOT_STARTED:
raise RuntimeError(
f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. "
f"Current status is '{attempt.status}'."
)
mark_attempt_in_progress(attempt, db_session)
# only commit once, to make sure this all happens in a single transaction
db_session.commit()
except Exception:
db_session.rollback()
logger.exception("_prepare_index_attempt exceptioned.")
raise
return attempt
return transition_attempt_to_in_progress(index_attempt_id, db_session)
def run_indexing_entrypoint(

View File

@ -101,6 +101,40 @@ def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
return list(new_attempts.all())
def transition_attempt_to_in_progress(
index_attempt_id: int,
db_session: Session,
) -> IndexAttempt:
"""Locks the row when we try to update"""
with db_session.begin_nested():
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt_id)
.with_for_update()
).scalar_one()
if attempt is None:
raise RuntimeError(
f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
)
if attempt.status != IndexingStatus.NOT_STARTED:
raise RuntimeError(
f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. "
f"Current status is '{attempt.status}'."
)
attempt.status = IndexingStatus.IN_PROGRESS
attempt.time_started = attempt.time_started or func.now() # type: ignore
db_session.commit()
return attempt
except Exception:
db_session.rollback()
logger.exception("transition_attempt_to_in_progress exceptioned.")
raise
def mark_attempt_in_progress(
index_attempt: IndexAttempt,
db_session: Session,
@ -118,6 +152,7 @@ def mark_attempt_in_progress(
db_session.commit()
except Exception:
db_session.rollback()
raise
def mark_attempt_succeeded(
@ -136,6 +171,7 @@ def mark_attempt_succeeded(
db_session.commit()
except Exception:
db_session.rollback()
raise
def mark_attempt_partially_succeeded(
@ -154,6 +190,7 @@ def mark_attempt_partially_succeeded(
db_session.commit()
except Exception:
db_session.rollback()
raise
def mark_attempt_failed(
@ -176,6 +213,7 @@ def mark_attempt_failed(
db_session.commit()
except Exception:
db_session.rollback()
raise
source = index_attempt.connector_credential_pair.connector.source
optional_telemetry(record_type=RecordType.FAILURE, data={"connector": source})