Lock improvement

Weves
2024-05-08 15:48:09 -07:00
committed by Chris Weaver
parent 8cbf7c8097
commit 7ed176b7cc
6 changed files with 166 additions and 143 deletions

View File

@@ -103,8 +103,9 @@ def sync_document_set_task(document_set_id: int) -> None:
         logger.debug(f"Syncing document sets for: {document_ids}")

         # Acquires a lock on the documents so that no other process can modify them
-        prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
+        with prepare_to_modify_documents(
+            db_session=db_session, document_ids=document_ids
+        ):
             # get current state of document sets for these documents
             document_set_map = {
                 document_id: document_sets
@@ -127,9 +128,6 @@ def sync_document_set_task(document_set_id: int) -> None:
                 ]
                 document_index.update(update_requests=update_requests)

-    # Commit to release the locks
-    db_session.commit()
-
     with Session(get_sqlalchemy_engine()) as db_session:
         try:
             documents_to_update = fetch_documents_for_document_set(
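
The shape of this change, in isolation: the lock helper is now a context manager, so the transaction holding the row locks always ends when the block exits, instead of relying on a manually placed commit. A minimal caller-side sketch (not from this commit; sync_under_lock and the update_index callable are hypothetical stand-ins):

from collections.abc import Callable

from sqlalchemy.orm import Session

from danswer.db.document import prepare_to_modify_documents


def sync_under_lock(
    db_session: Session,
    document_ids: list[str],
    update_index: Callable[[list[str]], None],  # e.g. the Vespa update step
) -> None:
    with prepare_to_modify_documents(
        db_session=db_session, document_ids=document_ids
    ):
        update_index(document_ids)
    # exiting the block finishes the transaction, so the row locks are
    # released whether update_index succeeded or raised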

View File

@@ -19,8 +19,8 @@ from danswer.db.connector import fetch_connector_by_id
 from danswer.db.connector_credential_pair import (
     delete_connector_credential_pair__no_commit,
 )
-from danswer.db.document import delete_document_by_connector_credential_pair
-from danswer.db.document import delete_documents_complete
+from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
+from danswer.db.document import delete_documents_complete__no_commit
 from danswer.db.document import get_document_connector_cnts
 from danswer.db.document import get_documents_for_connector_credential_pair
 from danswer.db.document import prepare_to_modify_documents
@@ -54,8 +54,9 @@ def _delete_connector_credential_pair_batch(
     with Session(get_sqlalchemy_engine()) as db_session:
         # acquire lock for all documents in this batch so that indexing can't
         # override the deletion
-        prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
+        with prepare_to_modify_documents(
+            db_session=db_session, document_ids=document_ids
+        ):
             document_connector_cnts = get_document_connector_cnts(
                 db_session=db_session, document_ids=document_ids
             )
@@ -68,7 +69,7 @@ def _delete_connector_credential_pair_batch(
             document_index.delete(doc_ids=document_ids_to_delete)

-            delete_documents_complete(
+            delete_documents_complete__no_commit(
                 db_session=db_session,
                 document_ids=document_ids_to_delete,
             )
@@ -96,7 +97,7 @@ def _delete_connector_credential_pair_batch(
             document_index.update(update_requests=update_requests)

-            delete_document_by_connector_credential_pair(
+            delete_document_by_connector_credential_pair__no_commit(
                 db_session=db_session,
                 document_ids=document_ids_to_update,
                 connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
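
Taken together, the deletion batch now holds the row locks across both the external-index delete and the Postgres delete. A rough sketch of the ordering the lock protects (the delete_batch wrapper is hypothetical, and document_index is passed in untyped to keep the sketch minimal):

from sqlalchemy.orm import Session

from danswer.db.document import (
    delete_documents_complete__no_commit,
    prepare_to_modify_documents,
)


def delete_batch(db_session: Session, document_index, document_ids: list[str]) -> None:
    with prepare_to_modify_documents(
        db_session=db_session, document_ids=document_ids
    ):
        # remove from the external index first, then from Postgres; the held
        # locks stop a concurrent indexing run from re-adding rows in between
        document_index.delete(doc_ids=document_ids)
        delete_documents_complete__no_commit(
            db_session=db_session, document_ids=document_ids
        )
    # commit-on-exit makes both deletes durable and releases the locks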

View File

@@ -1,4 +1,6 @@
+import contextlib
 import time
+from collections.abc import Generator
 from collections.abc import Sequence
 from datetime import datetime
 from uuid import UUID
@@ -9,15 +11,17 @@ from sqlalchemy import func
 from sqlalchemy import or_
 from sqlalchemy import select
 from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.engine.util import TransactionalContext
+from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import Session

 from danswer.configs.constants import DEFAULT_BOOST
-from danswer.db.feedback import delete_document_feedback_for_documents
+from danswer.db.feedback import delete_document_feedback_for_documents__no_commit
 from danswer.db.models import ConnectorCredentialPair
 from danswer.db.models import Credential
 from danswer.db.models import Document as DbDocument
 from danswer.db.models import DocumentByConnectorCredentialPair
-from danswer.db.tag import delete_document_tags_for_documents
+from danswer.db.tag import delete_document_tags_for_documents__no_commit
 from danswer.db.utils import model_to_dict
 from danswer.document_index.interfaces import DocumentMetadata
 from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@@ -242,7 +246,7 @@ def upsert_documents_complete(
     )

-def delete_document_by_connector_credential_pair(
+def delete_document_by_connector_credential_pair__no_commit(
     db_session: Session,
     document_ids: list[str],
     connector_credential_pair_identifier: ConnectorCredentialPairIdentifier
@@ -263,19 +267,22 @@ def delete_document_by_connector_credential_pair(
     db_session.execute(stmt)

-def delete_documents(db_session: Session, document_ids: list[str]) -> None:
+def delete_documents__no_commit(db_session: Session, document_ids: list[str]) -> None:
     db_session.execute(delete(DbDocument).where(DbDocument.id.in_(document_ids)))

-def delete_documents_complete(db_session: Session, document_ids: list[str]) -> None:
+def delete_documents_complete__no_commit(
+    db_session: Session, document_ids: list[str]
+) -> None:
     logger.info(f"Deleting {len(document_ids)} documents from the DB")
-    delete_document_by_connector_credential_pair(db_session, document_ids)
-    delete_document_feedback_for_documents(
+    delete_document_by_connector_credential_pair__no_commit(db_session, document_ids)
+    delete_document_feedback_for_documents__no_commit(
         document_ids=document_ids, db_session=db_session
     )
-    delete_document_tags_for_documents(document_ids=document_ids, db_session=db_session)
-    delete_documents(db_session, document_ids)
-    db_session.commit()
+    delete_document_tags_for_documents__no_commit(
+        document_ids=document_ids, db_session=db_session
+    )
+    delete_documents__no_commit(db_session, document_ids)

 def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool:
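
The __no_commit suffix is the convention this commit leans on: helpers only issue statements, and exactly one commit belongs to the caller (in the lock path, the commit that closes prepare_to_modify_documents' transaction). A minimal sketch of a standalone caller (the purge_documents name is hypothetical):

from sqlalchemy.orm import Session

from danswer.db.document import delete_documents_complete__no_commit


def purge_documents(db_session: Session, document_ids: list[str]) -> None:
    # feedback, tags, and document rows are deleted with no intermediate
    # commits, so a failure anywhere rolls the whole batch back atomically
    delete_documents_complete__no_commit(
        db_session=db_session, document_ids=document_ids
    )
    db_session.commit()  # the single commit, owned by the caller
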
@@ -288,12 +295,18 @@ def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool
     document IDs in a single call).
     """
     stmt = (
-        select(DbDocument)
+        select(DbDocument.id)
         .where(DbDocument.id.in_(document_ids))
         .with_for_update(nowait=True)
     )
     # will raise exception if any of the documents are already locked
-    db_session.execute(stmt)
+    documents = db_session.scalars(stmt).all()
+
+    # make sure we found every document
+    if len(documents) != len(document_ids):
+        logger.warning("Didn't find row for all specified document IDs. Aborting.")
+        return False
+
     return True
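
For context on the behavior this relies on: the statement compiles to SELECT ... FOR UPDATE NOWAIT, so if another transaction already holds a lock on any matching row, Postgres errors out immediately instead of blocking, and SQLAlchemy surfaces that as an OperationalError. A hedged sketch of that failure mode (the try_lock_rows helper is hypothetical):

from sqlalchemy import select
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session

from danswer.db.models import Document as DbDocument


def try_lock_rows(db_session: Session, document_ids: list[str]) -> bool:
    stmt = (
        select(DbDocument.id)  # selecting just the id avoids hydrating ORM objects
        .where(DbDocument.id.in_(document_ids))
        .with_for_update(nowait=True)
    )
    try:
        locked_ids = db_session.scalars(stmt).all()
    except OperationalError:
        # Postgres answered "could not obtain lock on row" right away;
        # the failed statement aborts the transaction, so roll it back
        db_session.rollback()
        return False
    return len(locked_ids) == len(document_ids)
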
@@ -301,20 +314,34 @@ _NUM_LOCK_ATTEMPTS = 10
 _LOCK_RETRY_DELAY = 30

-def prepare_to_modify_documents(db_session: Session, document_ids: list[str]) -> None:
+@contextlib.contextmanager
+def prepare_to_modify_documents(
+    db_session: Session, document_ids: list[str], retry_delay: int = _LOCK_RETRY_DELAY
+) -> Generator[TransactionalContext, None, None]:
     """Try and acquire locks for the documents to prevent other jobs from
     modifying them at the same time (e.g. avoid race conditions). This should be
     called ahead of any modification to Vespa. Locks should be released by the
-    caller as soon as updates are complete by finishing the transaction."""
+    caller as soon as updates are complete by finishing the transaction.
+
+    NOTE: only one commit is allowed within the context manager returned by this function.
+    Multiple commits will result in a sqlalchemy.exc.InvalidRequestError.
+    NOTE: this function will commit any existing transaction.
+    """
+    db_session.commit()  # ensure that we're not in a transaction
+
     lock_acquired = False
     for _ in range(_NUM_LOCK_ATTEMPTS):
         try:
-            lock_acquired = acquire_document_locks(
-                db_session=db_session, document_ids=document_ids
-            )
-        except Exception as e:
+            with db_session.begin() as transaction:
+                lock_acquired = acquire_document_locks(
+                    db_session=db_session, document_ids=document_ids
+                )
+                if lock_acquired:
+                    yield transaction
+                    break
+        except OperationalError as e:
             logger.info(f"Failed to acquire locks for documents, retrying. Error: {e}")
-            time.sleep(_LOCK_RETRY_DELAY)
+            time.sleep(retry_delay)

     if not lock_acquired:
         raise RuntimeError(
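
The retry loop plus yield is the standard @contextlib.contextmanager shape: acquire inside try, hand control to the caller at the yield, and release on the way out. A self-contained toy of the same pattern, with a plain threading.Lock standing in for the Postgres row locks (not Danswer code):

import contextlib
import threading
import time
from collections.abc import Generator

_lock = threading.Lock()


@contextlib.contextmanager
def hold_lock(attempts: int = 10, retry_delay: float = 0.1) -> Generator[None, None, None]:
    for _ in range(attempts):
        if _lock.acquire(blocking=False):  # non-blocking, like FOR UPDATE NOWAIT
            try:
                yield  # the caller's critical section runs here
            finally:
                _lock.release()  # always released, even if the body raises
            return
        time.sleep(retry_delay)  # back off and retry, as the commit does
    raise RuntimeError("could not acquire lock after all attempts")

Usage is simply "with hold_lock(): ...". With the real function's defaults of 10 attempts and a 30-second delay, it gives up after roughly five minutes.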

View File

@@ -129,7 +129,7 @@ def create_doc_retrieval_feedback(
     db_session.commit()

-def delete_document_feedback_for_documents(
+def delete_document_feedback_for_documents__no_commit(
     document_ids: list[str], db_session: Session
 ) -> None:
     """NOTE: does not commit transaction so that this can be used as part of a

View File

@@ -117,12 +117,11 @@ def get_tags_by_value_prefix_for_source_types(
     return list(tags)

-def delete_document_tags_for_documents(
+def delete_document_tags_for_documents__no_commit(
     document_ids: list[str], db_session: Session
 ) -> None:
     stmt = delete(Document__Tag).where(Document__Tag.document_id.in_(document_ids))
     db_session.execute(stmt)
-    db_session.commit()

     orphan_tags_query = (
         select(Tag.id)
@@ -136,4 +135,3 @@ def delete_document_tags_for_documents(
     if orphan_tags:
         delete_orphan_tags_stmt = delete(Tag).where(Tag.id.in_(orphan_tags))
         db_session.execute(delete_orphan_tags_stmt)
-        db_session.commit()

View File

@@ -161,8 +161,7 @@ def index_doc_batch(
     # Acquires a lock on the documents so that no other process can modify them
     # NOTE: don't need to acquire till here, since this is when the actual race condition
     # with Vespa can occur.
-    prepare_to_modify_documents(db_session=db_session, document_ids=updatable_ids)
-
+    with prepare_to_modify_documents(db_session=db_session, document_ids=updatable_ids):
         # Attach the latest status from Postgres (source of truth for access) to each
         # chunk. This access status will be attached to each chunk in the document index
         # TODO: attach document sets to the chunk based on the status of Postgres as well
@@ -200,7 +199,9 @@ def index_doc_batch(
         insertion_records = document_index.index(chunks=access_aware_chunks)

         successful_doc_ids = [record.document_id for record in insertion_records]
-        successful_docs = [doc for doc in updatable_docs if doc.id in successful_doc_ids]
+        successful_docs = [
+            doc for doc in updatable_docs if doc.id in successful_doc_ids
+        ]

         # Update the time of latest version of the doc successfully indexed
         ids_to_new_updated_at = {}
@@ -213,8 +214,6 @@ def index_doc_batch(
             ids_to_new_updated_at=ids_to_new_updated_at, db_session=db_session
         )

-    db_session.commit()
-
     return len([r for r in insertion_records if r.already_existed is False]), len(
         chunks
     )
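
Background on why the trailing commits could be dropped across these files: SQLAlchemy's Session.begin() returns a context manager that commits on a clean exit and rolls back on an exception, so the commit now happens when the "with db_session.begin()" block inside prepare_to_modify_documents closes. A small standalone demonstration of that behavior (in-memory SQLite stand-in engine, illustration only):

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("sqlite://")  # stand-in for the real Postgres engine

with Session(engine) as session:
    with session.begin():  # BEGIN
        session.execute(text("CREATE TABLE t (id INTEGER)"))
        session.execute(text("INSERT INTO t VALUES (1)"))
    # leaving the inner block issued COMMIT; no explicit session.commit() needed
    count = session.execute(text("SELECT COUNT(*) FROM t")).scalar_one()
    assert count == 1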