use with for update instead of serializable (#2848)

* use with for update instead of serializable

* remove tenant logic handled now by get_session_with_tenant

* remove usage of begin_nested ... it's not necessary
This commit is contained in:
rkuo-danswer 2024-10-18 13:35:23 -07:00 committed by GitHub
parent 55de519364
commit 12cbbe6cee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 82 additions and 96 deletions

View File

@ -4,7 +4,6 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
@ -20,11 +19,10 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import transition_attempt_to_in_progress
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
@ -379,51 +377,6 @@ def _run_indexing(
)
def _prepare_index_attempt(
    db_session: Session, index_attempt_id: int, tenant_id: str | None
) -> IndexAttempt:
    """Claim a NOT_STARTED index attempt and flag it as in progress.

    Args:
        db_session: Session used for the lookup, the status update and the
            final commit/rollback.
        index_attempt_id: ID of the IndexAttempt to claim.
        tenant_id: When not None, the session's search_path is pointed at this
            tenant before the attempt is looked up.

    Returns:
        The IndexAttempt that was marked as in progress.

    Raises:
        RuntimeError: If no attempt exists for the ID, or if the attempt is
            not in NOT_STARTED status.
        Exception: Any other failure is logged and re-raised after the
            session has been rolled back.
    """
    # Run the check-then-update below under SERIALIZABLE isolation so the
    # attempt cannot change between reading its status and marking it as
    # in_progress. This setting will be discarded after the next commit:
    # https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#setting-isolation-for-individual-transactions
    db_session.connection(execution_options={"isolation_level": "SERIALIZABLE"})  # type: ignore

    try:
        if tenant_id is not None:
            # Switch to the tenant's search path, then read it back so the
            # effective value ends up in the log.
            db_session.execute(text(f'SET search_path TO "{tenant_id}"'))
            active_search_path = db_session.execute(
                text("SHOW search_path")
            ).scalar()
            logger.info(f"Current search path set to: {active_search_path}")

        attempt = get_index_attempt(
            db_session=db_session,
            index_attempt_id=index_attempt_id,
        )

        # Guard clauses: the attempt must exist and must not have been
        # picked up already.
        if attempt is None:
            raise RuntimeError(
                f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
            )
        if attempt.status != IndexingStatus.NOT_STARTED:
            raise RuntimeError(
                f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. "
                f"Current status is '{attempt.status}'."
            )

        mark_attempt_in_progress(attempt, db_session)

        # only commit once, to make sure this all happens in a single transaction
        db_session.commit()
    except Exception:
        db_session.rollback()
        logger.exception("_prepare_index_attempt exceptioned.")
        raise

    return attempt
def run_indexing_entrypoint(
index_attempt_id: int,
tenant_id: str | None,
@ -440,7 +393,7 @@ def run_indexing_entrypoint(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_tenant(tenant_id) as db_session:
attempt = _prepare_index_attempt(db_session, index_attempt_id, tenant_id)
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
logger.info(
f"Indexing starting for tenant {tenant_id}: "

View File

@ -101,59 +101,92 @@ def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
return list(new_attempts.all())
def transition_attempt_to_in_progress(
    index_attempt_id: int,
    db_session: Session,
) -> IndexAttempt:
    """Move an index attempt from NOT_STARTED to IN_PROGRESS under a row lock.

    The row is selected with ``FOR UPDATE`` and the status check, the update
    and the commit all happen on ``db_session``.

    Args:
        index_attempt_id: ID of the IndexAttempt to transition.
        db_session: Session used for the locked select and the commit.

    Returns:
        The IndexAttempt after it has been marked IN_PROGRESS.

    Raises:
        RuntimeError: If no attempt exists for ``index_attempt_id``, or if the
            attempt is not currently in NOT_STARTED status.
        Exception: Any other failure is logged and re-raised after the
            session has been rolled back.
    """
    try:
        # scalar_one_or_none() returns None for a missing row, so the explicit
        # RuntimeError below is reachable; scalar_one() would raise
        # NoResultFound before the None check could ever run.
        attempt = db_session.execute(
            select(IndexAttempt)
            .where(IndexAttempt.id == index_attempt_id)
            .with_for_update()
        ).scalar_one_or_none()

        if attempt is None:
            raise RuntimeError(
                f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
            )

        if attempt.status != IndexingStatus.NOT_STARTED:
            raise RuntimeError(
                f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. "
                f"Current status is '{attempt.status}'."
            )

        attempt.status = IndexingStatus.IN_PROGRESS
        # Keep an already-recorded start time; otherwise let the database
        # stamp the current time.
        attempt.time_started = attempt.time_started or func.now()  # type: ignore
        db_session.commit()
        return attempt
    except Exception:
        db_session.rollback()
        logger.exception("transition_attempt_to_in_progress exceptioned.")
        raise
def mark_attempt_in_progress(
    index_attempt: IndexAttempt,
    db_session: Session,
) -> None:
    """Set an index attempt's status to IN_PROGRESS under a row lock.

    The row is re-selected by ``index_attempt.id`` with ``FOR UPDATE``,
    updated, and committed on ``db_session``.

    Args:
        index_attempt: The attempt to update; its ``id`` selects the row and
            its ``time_started`` (if set) is carried over.
        db_session: Session used for the locked select and the commit.

    Raises:
        Exception: Any failure (including ``scalar_one()`` not finding exactly
            one row) is re-raised after the session has been rolled back.
    """
    try:
        attempt = db_session.execute(
            select(IndexAttempt)
            .where(IndexAttempt.id == index_attempt.id)
            .with_for_update()
        ).scalar_one()

        attempt.status = IndexingStatus.IN_PROGRESS
        # Keep the start time already present on the caller's object;
        # otherwise let the database stamp the current time.
        attempt.time_started = index_attempt.time_started or func.now()  # type: ignore
        db_session.commit()
    except Exception:
        db_session.rollback()
        raise
def mark_attempt_succeeded(
    index_attempt: IndexAttempt,
    db_session: Session,
) -> None:
    """Set an index attempt's status to SUCCESS under a row lock.

    The row is re-selected by ``index_attempt.id`` with ``FOR UPDATE``,
    updated, and committed on ``db_session``.

    Args:
        index_attempt: The attempt to update; only its ``id`` is read.
        db_session: Session used for the locked select and the commit.

    Raises:
        Exception: Any failure (including ``scalar_one()`` not finding exactly
            one row) is re-raised after the session has been rolled back.
    """
    try:
        attempt = db_session.execute(
            select(IndexAttempt)
            .where(IndexAttempt.id == index_attempt.id)
            .with_for_update()
        ).scalar_one()

        attempt.status = IndexingStatus.SUCCESS
        db_session.commit()
    except Exception:
        db_session.rollback()
        raise
def mark_attempt_partially_succeeded(
    index_attempt: IndexAttempt,
    db_session: Session,
) -> None:
    """Set an index attempt's status to COMPLETED_WITH_ERRORS under a row lock.

    The row is re-selected by ``index_attempt.id`` with ``FOR UPDATE``,
    updated, and committed on ``db_session``.

    Args:
        index_attempt: The attempt to update; only its ``id`` is read.
        db_session: Session used for the locked select and the commit.

    Raises:
        Exception: Any failure (including ``scalar_one()`` not finding exactly
            one row) is re-raised after the session has been rolled back.
    """
    try:
        attempt = db_session.execute(
            select(IndexAttempt)
            .where(IndexAttempt.id == index_attempt.id)
            .with_for_update()
        ).scalar_one()

        attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS
        db_session.commit()
    except Exception:
        db_session.rollback()
        raise
def mark_attempt_failed(
@ -162,20 +195,20 @@ def mark_attempt_failed(
failure_reason: str = "Unknown",
full_exception_trace: str | None = None,
) -> None:
with db_session.begin_nested():
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt.id)
.with_for_update()
).scalar_one()
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt.id)
.with_for_update()
).scalar_one()
attempt.status = IndexingStatus.FAILED
attempt.error_msg = failure_reason
attempt.full_exception_trace = full_exception_trace
db_session.commit()
except Exception:
db_session.rollback()
attempt.status = IndexingStatus.FAILED
attempt.error_msg = failure_reason
attempt.full_exception_trace = full_exception_trace
db_session.commit()
except Exception:
db_session.rollback()
raise
source = index_attempt.connector_credential_pair.connector.source
optional_telemetry(record_type=RecordType.FAILURE, data={"connector": source})