Lock improvement

Weves authored on 2024-05-08 15:48:09 -07:00; committed by Chris Weaver
parent 8cbf7c8097
commit 7ed176b7cc
6 changed files with 166 additions and 143 deletions

View File

@@ -103,32 +103,30 @@ def sync_document_set_task(document_set_id: int) -> None:
         logger.debug(f"Syncing document sets for: {document_ids}")

         # Acquires a lock on the documents so that no other process can modify them
-        prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
-
-        # get current state of document sets for these documents
-        document_set_map = {
-            document_id: document_sets
-            for document_id, document_sets in fetch_document_sets_for_documents(
-                document_ids=document_ids, db_session=db_session
-            )
-        }
-
-        # update Vespa
-        curr_ind_name, sec_ind_name = get_both_index_names(db_session)
-        document_index = get_default_document_index(
-            primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
-        )
-        update_requests = [
-            UpdateRequest(
-                document_ids=[document_id],
-                document_sets=set(document_set_map.get(document_id, [])),
-            )
-            for document_id in document_ids
-        ]
-        document_index.update(update_requests=update_requests)
-
-        # Commit to release the locks
-        db_session.commit()
+        with prepare_to_modify_documents(
+            db_session=db_session, document_ids=document_ids
+        ):
+            # get current state of document sets for these documents
+            document_set_map = {
+                document_id: document_sets
+                for document_id, document_sets in fetch_document_sets_for_documents(
+                    document_ids=document_ids, db_session=db_session
+                )
+            }
+
+            # update Vespa
+            curr_ind_name, sec_ind_name = get_both_index_names(db_session)
+            document_index = get_default_document_index(
+                primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
+            )
+            update_requests = [
+                UpdateRequest(
+                    document_ids=[document_id],
+                    document_sets=set(document_set_map.get(document_id, [])),
+                )
+                for document_id in document_ids
+            ]
+            document_index.update(update_requests=update_requests)

     with Session(get_sqlalchemy_engine()) as db_session:
         try:
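Note: the hunk above shows the general shape of the call-site change in this commit: the bare lock call plus a manual "Commit to release the locks" is replaced by a context manager that owns the transaction. A minimal sketch of that pattern, assuming only what this diff shows about prepare_to_modify_documents (push_updates_to_index and the danswer.db.engine import path are illustrative stand-ins, not part of this commit):

from sqlalchemy.orm import Session

from danswer.db.document import prepare_to_modify_documents
from danswer.db.engine import get_sqlalchemy_engine  # assumed module path


def push_updates_to_index(document_ids: list[str]) -> None:
    # Hypothetical stand-in for the Vespa update done while the locks are held.
    print(f"updating index for {document_ids}")


def sync_documents(document_ids: list[str]) -> None:
    with Session(get_sqlalchemy_engine()) as db_session:
        # Row locks are acquired inside a transaction opened by the context
        # manager and released when that transaction finishes.
        with prepare_to_modify_documents(
            db_session=db_session, document_ids=document_ids
        ):
            push_updates_to_index(document_ids)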

View File

@@ -19,8 +19,8 @@ from danswer.db.connector import fetch_connector_by_id
 from danswer.db.connector_credential_pair import (
     delete_connector_credential_pair__no_commit,
 )
-from danswer.db.document import delete_document_by_connector_credential_pair
-from danswer.db.document import delete_documents_complete
+from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
+from danswer.db.document import delete_documents_complete__no_commit
 from danswer.db.document import get_document_connector_cnts
 from danswer.db.document import get_documents_for_connector_credential_pair
 from danswer.db.document import prepare_to_modify_documents
@@ -54,57 +54,58 @@ def _delete_connector_credential_pair_batch(
     with Session(get_sqlalchemy_engine()) as db_session:
         # acquire lock for all documents in this batch so that indexing can't
         # override the deletion
-        prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
-
-        document_connector_cnts = get_document_connector_cnts(
-            db_session=db_session, document_ids=document_ids
-        )
-
-        # figure out which docs need to be completely deleted
-        document_ids_to_delete = [
-            document_id for document_id, cnt in document_connector_cnts if cnt == 1
-        ]
-        logger.debug(f"Deleting documents: {document_ids_to_delete}")
-
-        document_index.delete(doc_ids=document_ids_to_delete)
-        delete_documents_complete(
-            db_session=db_session,
-            document_ids=document_ids_to_delete,
-        )
-
-        # figure out which docs need to be updated
-        document_ids_to_update = [
-            document_id for document_id, cnt in document_connector_cnts if cnt > 1
-        ]
-        access_for_documents = get_access_for_documents(
-            document_ids=document_ids_to_update,
-            db_session=db_session,
-            cc_pair_to_delete=ConnectorCredentialPairIdentifier(
-                connector_id=connector_id,
-                credential_id=credential_id,
-            ),
-        )
-        update_requests = [
-            UpdateRequest(
-                document_ids=[document_id],
-                access=access,
-            )
-            for document_id, access in access_for_documents.items()
-        ]
-        logger.debug(f"Updating documents: {document_ids_to_update}")
-
-        document_index.update(update_requests=update_requests)
-
-        delete_document_by_connector_credential_pair(
-            db_session=db_session,
-            document_ids=document_ids_to_update,
-            connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
-                connector_id=connector_id,
-                credential_id=credential_id,
-            ),
-        )
-
-        db_session.commit()
+        with prepare_to_modify_documents(
+            db_session=db_session, document_ids=document_ids
+        ):
+            document_connector_cnts = get_document_connector_cnts(
+                db_session=db_session, document_ids=document_ids
+            )
+
+            # figure out which docs need to be completely deleted
+            document_ids_to_delete = [
+                document_id for document_id, cnt in document_connector_cnts if cnt == 1
+            ]
+            logger.debug(f"Deleting documents: {document_ids_to_delete}")
+
+            document_index.delete(doc_ids=document_ids_to_delete)
+            delete_documents_complete__no_commit(
+                db_session=db_session,
+                document_ids=document_ids_to_delete,
+            )
+
+            # figure out which docs need to be updated
+            document_ids_to_update = [
+                document_id for document_id, cnt in document_connector_cnts if cnt > 1
+            ]
+            access_for_documents = get_access_for_documents(
+                document_ids=document_ids_to_update,
+                db_session=db_session,
+                cc_pair_to_delete=ConnectorCredentialPairIdentifier(
+                    connector_id=connector_id,
+                    credential_id=credential_id,
+                ),
+            )
+            update_requests = [
+                UpdateRequest(
+                    document_ids=[document_id],
+                    access=access,
+                )
+                for document_id, access in access_for_documents.items()
+            ]
+            logger.debug(f"Updating documents: {document_ids_to_update}")
+
+            document_index.update(update_requests=update_requests)
+
+            delete_document_by_connector_credential_pair__no_commit(
+                db_session=db_session,
+                document_ids=document_ids_to_update,
+                connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
+                    connector_id=connector_id,
+                    credential_id=credential_id,
+                ),
+            )
+
+            db_session.commit()


 def cleanup_synced_entities(
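Note: the deletion batch above now composes the renamed __no_commit helpers inside the lock context and issues exactly one commit at the end, matching the one-commit restriction documented on prepare_to_modify_documents later in this commit. A rough sketch of that shape (delete_rows_for is a hypothetical stand-in for the __no_commit helpers, not a function from this commit):

from sqlalchemy.orm import Session

from danswer.db.document import prepare_to_modify_documents


def delete_rows_for(db_session: Session, document_ids: list[str]) -> None:
    # Hypothetical __no_commit-style helper: issue statements on db_session
    # but never commit; the caller decides when the transaction ends.
    ...


def delete_batch(db_session: Session, document_ids: list[str]) -> None:
    with prepare_to_modify_documents(
        db_session=db_session, document_ids=document_ids
    ):
        delete_rows_for(db_session, document_ids)
        # Exactly one commit inside the context manager, per its docstring.
        db_session.commit()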

View File

@@ -1,4 +1,6 @@
+import contextlib
 import time
+from collections.abc import Generator
 from collections.abc import Sequence
 from datetime import datetime
 from uuid import UUID
@@ -9,15 +11,17 @@ from sqlalchemy import func
 from sqlalchemy import or_
 from sqlalchemy import select
 from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.engine.util import TransactionalContext
+from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import Session

 from danswer.configs.constants import DEFAULT_BOOST
-from danswer.db.feedback import delete_document_feedback_for_documents
+from danswer.db.feedback import delete_document_feedback_for_documents__no_commit
 from danswer.db.models import ConnectorCredentialPair
 from danswer.db.models import Credential
 from danswer.db.models import Document as DbDocument
 from danswer.db.models import DocumentByConnectorCredentialPair
-from danswer.db.tag import delete_document_tags_for_documents
+from danswer.db.tag import delete_document_tags_for_documents__no_commit
 from danswer.db.utils import model_to_dict
 from danswer.document_index.interfaces import DocumentMetadata
 from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@@ -242,7 +246,7 @@ def upsert_documents_complete(
     )


-def delete_document_by_connector_credential_pair(
+def delete_document_by_connector_credential_pair__no_commit(
     db_session: Session,
     document_ids: list[str],
     connector_credential_pair_identifier: ConnectorCredentialPairIdentifier
@@ -263,19 +267,22 @@ def delete_document_by_connector_credential_pair(
     db_session.execute(stmt)


-def delete_documents(db_session: Session, document_ids: list[str]) -> None:
+def delete_documents__no_commit(db_session: Session, document_ids: list[str]) -> None:
     db_session.execute(delete(DbDocument).where(DbDocument.id.in_(document_ids)))


-def delete_documents_complete(db_session: Session, document_ids: list[str]) -> None:
+def delete_documents_complete__no_commit(
+    db_session: Session, document_ids: list[str]
+) -> None:
     logger.info(f"Deleting {len(document_ids)} documents from the DB")
-    delete_document_by_connector_credential_pair(db_session, document_ids)
-    delete_document_feedback_for_documents(
+    delete_document_by_connector_credential_pair__no_commit(db_session, document_ids)
+    delete_document_feedback_for_documents__no_commit(
         document_ids=document_ids, db_session=db_session
     )
-    delete_document_tags_for_documents(document_ids=document_ids, db_session=db_session)
-    delete_documents(db_session, document_ids)
-    db_session.commit()
+    delete_document_tags_for_documents__no_commit(
+        document_ids=document_ids, db_session=db_session
+    )
+    delete_documents__no_commit(db_session, document_ids)


 def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool:
@@ -288,12 +295,18 @@ def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool:
     document IDs in a single call).
     """
     stmt = (
-        select(DbDocument)
+        select(DbDocument.id)
         .where(DbDocument.id.in_(document_ids))
         .with_for_update(nowait=True)
     )
     # will raise exception if any of the documents are already locked
-    db_session.execute(stmt)
+    documents = db_session.scalars(stmt).all()
+
+    # make sure we found every document
+    if len(documents) != len(document_ids):
+        logger.warning("Didn't find row for all specified document IDs. Aborting.")
+        return False
+
     return True
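Note: acquire_document_locks relies on SELECT ... FOR UPDATE NOWAIT, so if another transaction already holds one of the row locks, Postgres errors out immediately instead of blocking, and SQLAlchemy surfaces that as an OperationalError (which the retry loop in the next hunk catches). A standalone illustration of that behaviour, independent of the Danswer models (the table definition and connection string here are assumptions):

from sqlalchemy import create_engine, select
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Document(Base):  # hypothetical minimal table, not the Danswer model
    __tablename__ = "document"
    id: Mapped[str] = mapped_column(primary_key=True)


engine = create_engine("postgresql+psycopg2://localhost/example")  # assumed DSN

with Session(engine) as session:
    stmt = (
        select(Document.id)
        .where(Document.id.in_(["doc-1", "doc-2"]))
        .with_for_update(nowait=True)  # fail fast instead of waiting on the lock
    )
    try:
        locked_ids = session.scalars(stmt).all()
    except OperationalError:
        # Another transaction holds one of these row locks; back off and retry.
        locked_ids = []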
@@ -301,20 +314,34 @@ _NUM_LOCK_ATTEMPTS = 10
 _LOCK_RETRY_DELAY = 30


-def prepare_to_modify_documents(db_session: Session, document_ids: list[str]) -> None:
+@contextlib.contextmanager
+def prepare_to_modify_documents(
+    db_session: Session, document_ids: list[str], retry_delay: int = _LOCK_RETRY_DELAY
+) -> Generator[TransactionalContext, None, None]:
     """Try and acquire locks for the documents to prevent other jobs from
     modifying them at the same time (e.g. avoid race conditions). This should be
     called ahead of any modification to Vespa. Locks should be released by the
-    caller as soon as updates are complete by finishing the transaction."""
+    caller as soon as updates are complete by finishing the transaction.
+
+    NOTE: only one commit is allowed within the context manager returned by this function.
+    Multiple commits will result in a sqlalchemy.exc.InvalidRequestError.
+    NOTE: this function will commit any existing transaction.
+    """
+    db_session.commit()  # ensure that we're not in a transaction
+
     lock_acquired = False
     for _ in range(_NUM_LOCK_ATTEMPTS):
         try:
-            lock_acquired = acquire_document_locks(
-                db_session=db_session, document_ids=document_ids
-            )
-        except Exception as e:
+            with db_session.begin() as transaction:
+                lock_acquired = acquire_document_locks(
+                    db_session=db_session, document_ids=document_ids
+                )
+                if lock_acquired:
+                    yield transaction
+                    break
+        except OperationalError as e:
             logger.info(f"Failed to acquire locks for documents, retrying. Error: {e}")
-        time.sleep(_LOCK_RETRY_DELAY)
+        time.sleep(retry_delay)

     if not lock_acquired:
         raise RuntimeError(
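Note: as rewritten above, prepare_to_modify_documents opens the transaction itself via db_session.begin() and yields it, and its docstring allows at most one commit inside the block (more than one raises sqlalchemy.exc.InvalidRequestError). A small usage sketch under those assumptions (the surrounding function is illustrative, not code from this commit):

from sqlalchemy.orm import Session

from danswer.db.document import prepare_to_modify_documents


def update_documents(db_session: Session, document_ids: list[str]) -> None:
    # retry_delay can be lowered (e.g. in tests) to avoid 30-second sleeps
    # between lock attempts.
    with prepare_to_modify_documents(
        db_session=db_session, document_ids=document_ids, retry_delay=1
    ):
        # ... update the document index and Postgres rows here ...

        # Optional: at most one commit inside the block. Leaving it out lets
        # the transaction opened by the context manager commit on exit.
        db_session.commit()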

View File

@@ -129,7 +129,7 @@ def create_doc_retrieval_feedback(
     db_session.commit()


-def delete_document_feedback_for_documents(
+def delete_document_feedback_for_documents__no_commit(
     document_ids: list[str], db_session: Session
 ) -> None:
     """NOTE: does not commit transaction so that this can be used as part of a

View File

@@ -117,12 +117,11 @@ def get_tags_by_value_prefix_for_source_types(
     return list(tags)


-def delete_document_tags_for_documents(
+def delete_document_tags_for_documents__no_commit(
     document_ids: list[str], db_session: Session
 ) -> None:
     stmt = delete(Document__Tag).where(Document__Tag.document_id.in_(document_ids))
     db_session.execute(stmt)
-    db_session.commit()

     orphan_tags_query = (
         select(Tag.id)
@@ -136,4 +135,3 @@ def delete_document_tags_for_documents(
     if orphan_tags:
         delete_orphan_tags_stmt = delete(Tag).where(Tag.id.in_(orphan_tags))
         db_session.execute(delete_orphan_tags_stmt)
-    db_session.commit()

View File

@@ -161,59 +161,58 @@ def index_doc_batch(
     # Acquires a lock on the documents so that no other process can modify them
     # NOTE: don't need to acquire till here, since this is when the actual race condition
     # with Vespa can occur.
-    prepare_to_modify_documents(db_session=db_session, document_ids=updatable_ids)
-
-    # Attach the latest status from Postgres (source of truth for access) to each
-    # chunk. This access status will be attached to each chunk in the document index
-    # TODO: attach document sets to the chunk based on the status of Postgres as well
-    document_id_to_access_info = get_access_for_documents(
-        document_ids=updatable_ids, db_session=db_session
-    )
-    document_id_to_document_set = {
-        document_id: document_sets
-        for document_id, document_sets in fetch_document_sets_for_documents(
-            document_ids=updatable_ids, db_session=db_session
-        )
-    }
-    access_aware_chunks = [
-        DocMetadataAwareIndexChunk.from_index_chunk(
-            index_chunk=chunk,
-            access=document_id_to_access_info[chunk.source_document.id],
-            document_sets=set(
-                document_id_to_document_set.get(chunk.source_document.id, [])
-            ),
-            boost=(
-                id_to_db_doc_map[chunk.source_document.id].boost
-                if chunk.source_document.id in id_to_db_doc_map
-                else DEFAULT_BOOST
-            ),
-        )
-        for chunk in chunks_with_embeddings
-    ]
-
-    logger.debug(
-        f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in chunks]}"
-    )
-    # A document will not be spread across different batches, so all the
-    # documents with chunks in this set, are fully represented by the chunks
-    # in this set
-    insertion_records = document_index.index(chunks=access_aware_chunks)
-
-    successful_doc_ids = [record.document_id for record in insertion_records]
-    successful_docs = [doc for doc in updatable_docs if doc.id in successful_doc_ids]
-
-    # Update the time of latest version of the doc successfully indexed
-    ids_to_new_updated_at = {}
-    for doc in successful_docs:
-        if doc.doc_updated_at is None:
-            continue
-        ids_to_new_updated_at[doc.id] = doc.doc_updated_at
-
-    update_docs_updated_at(
-        ids_to_new_updated_at=ids_to_new_updated_at, db_session=db_session
-    )
-
-    db_session.commit()
+    with prepare_to_modify_documents(db_session=db_session, document_ids=updatable_ids):
+        # Attach the latest status from Postgres (source of truth for access) to each
+        # chunk. This access status will be attached to each chunk in the document index
+        # TODO: attach document sets to the chunk based on the status of Postgres as well
+        document_id_to_access_info = get_access_for_documents(
+            document_ids=updatable_ids, db_session=db_session
+        )
+        document_id_to_document_set = {
+            document_id: document_sets
+            for document_id, document_sets in fetch_document_sets_for_documents(
+                document_ids=updatable_ids, db_session=db_session
+            )
+        }
+        access_aware_chunks = [
+            DocMetadataAwareIndexChunk.from_index_chunk(
+                index_chunk=chunk,
+                access=document_id_to_access_info[chunk.source_document.id],
+                document_sets=set(
+                    document_id_to_document_set.get(chunk.source_document.id, [])
+                ),
+                boost=(
+                    id_to_db_doc_map[chunk.source_document.id].boost
+                    if chunk.source_document.id in id_to_db_doc_map
+                    else DEFAULT_BOOST
+                ),
+            )
+            for chunk in chunks_with_embeddings
+        ]
+
+        logger.debug(
+            f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in chunks]}"
+        )
+        # A document will not be spread across different batches, so all the
+        # documents with chunks in this set, are fully represented by the chunks
+        # in this set
+        insertion_records = document_index.index(chunks=access_aware_chunks)
+
+        successful_doc_ids = [record.document_id for record in insertion_records]
+        successful_docs = [
+            doc for doc in updatable_docs if doc.id in successful_doc_ids
+        ]
+
+        # Update the time of latest version of the doc successfully indexed
+        ids_to_new_updated_at = {}
+        for doc in successful_docs:
+            if doc.doc_updated_at is None:
+                continue
+            ids_to_new_updated_at[doc.id] = doc.doc_updated_at
+
+        update_docs_updated_at(
+            ids_to_new_updated_at=ids_to_new_updated_at, db_session=db_session
+        )

     return len([r for r in insertion_records if r.already_existed is False]), len(
         chunks