Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-03-26 17:51:54 +01:00)

Lock improvement

This commit is contained in:
parent 8cbf7c8097
commit 7ed176b7cc
@@ -103,32 +103,30 @@ def sync_document_set_task(document_set_id: int) -> None:
         logger.debug(f"Syncing document sets for: {document_ids}")

         # Acquires a lock on the documents so that no other process can modify them
-        prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
-
-        # get current state of document sets for these documents
-        document_set_map = {
-            document_id: document_sets
-            for document_id, document_sets in fetch_document_sets_for_documents(
-                document_ids=document_ids, db_session=db_session
-            )
-        }
-
-        # update Vespa
-        curr_ind_name, sec_ind_name = get_both_index_names(db_session)
-        document_index = get_default_document_index(
-            primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
-        )
-        update_requests = [
-            UpdateRequest(
-                document_ids=[document_id],
-                document_sets=set(document_set_map.get(document_id, [])),
-            )
-            for document_id in document_ids
-        ]
-        document_index.update(update_requests=update_requests)
-
-        # Commit to release the locks
-        db_session.commit()
+        with prepare_to_modify_documents(
+            db_session=db_session, document_ids=document_ids
+        ):
+            # get current state of document sets for these documents
+            document_set_map = {
+                document_id: document_sets
+                for document_id, document_sets in fetch_document_sets_for_documents(
+                    document_ids=document_ids, db_session=db_session
+                )
+            }
+
+            # update Vespa
+            curr_ind_name, sec_ind_name = get_both_index_names(db_session)
+            document_index = get_default_document_index(
+                primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
+            )
+            update_requests = [
+                UpdateRequest(
+                    document_ids=[document_id],
+                    document_sets=set(document_set_map.get(document_id, [])),
+                )
+                for document_id in document_ids
+            ]
+            document_index.update(update_requests=update_requests)

     with Session(get_sqlalchemy_engine()) as db_session:
         try:
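The hunk above replaces the paired acquire-then-commit calls with a `with` block, so the lifetime of the document locks is tied to a scope instead of depending on a later `db_session.commit()` being reached. A minimal, dependency-free sketch of that control-flow difference (the names `hold_locks` and `sync` are hypothetical stand-ins, not danswer APIs):

import contextlib
from collections.abc import Generator


@contextlib.contextmanager
def hold_locks(ids: list[str]) -> Generator[None, None, None]:
    # stand-in for prepare_to_modify_documents: acquire, yield, always release
    print(f"locking {ids}")
    try:
        yield
    finally:
        # in the real code the transaction ends here, which releases the row locks
        print(f"releasing {ids}")


def sync(ids: list[str]) -> None:
    with hold_locks(ids):
        print("updating the external index while rows are locked")
    # the release above runs even if the body raises


sync(["doc-1", "doc-2"])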
@@ -19,8 +19,8 @@ from danswer.db.connector import fetch_connector_by_id
 from danswer.db.connector_credential_pair import (
     delete_connector_credential_pair__no_commit,
 )
-from danswer.db.document import delete_document_by_connector_credential_pair
-from danswer.db.document import delete_documents_complete
+from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
+from danswer.db.document import delete_documents_complete__no_commit
 from danswer.db.document import get_document_connector_cnts
 from danswer.db.document import get_documents_for_connector_credential_pair
 from danswer.db.document import prepare_to_modify_documents
@@ -54,57 +54,58 @@ def _delete_connector_credential_pair_batch(
     with Session(get_sqlalchemy_engine()) as db_session:
         # acquire lock for all documents in this batch so that indexing can't
         # override the deletion
-        prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
-
-        document_connector_cnts = get_document_connector_cnts(
-            db_session=db_session, document_ids=document_ids
-        )
-
-        # figure out which docs need to be completely deleted
-        document_ids_to_delete = [
-            document_id for document_id, cnt in document_connector_cnts if cnt == 1
-        ]
-        logger.debug(f"Deleting documents: {document_ids_to_delete}")
-
-        document_index.delete(doc_ids=document_ids_to_delete)
-
-        delete_documents_complete(
-            db_session=db_session,
-            document_ids=document_ids_to_delete,
-        )
-
-        # figure out which docs need to be updated
-        document_ids_to_update = [
-            document_id for document_id, cnt in document_connector_cnts if cnt > 1
-        ]
-        access_for_documents = get_access_for_documents(
-            document_ids=document_ids_to_update,
-            db_session=db_session,
-            cc_pair_to_delete=ConnectorCredentialPairIdentifier(
-                connector_id=connector_id,
-                credential_id=credential_id,
-            ),
-        )
-        update_requests = [
-            UpdateRequest(
-                document_ids=[document_id],
-                access=access,
-            )
-            for document_id, access in access_for_documents.items()
-        ]
-        logger.debug(f"Updating documents: {document_ids_to_update}")
-
-        document_index.update(update_requests=update_requests)
-
-        delete_document_by_connector_credential_pair(
-            db_session=db_session,
-            document_ids=document_ids_to_update,
-            connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
-                connector_id=connector_id,
-                credential_id=credential_id,
-            ),
-        )
-        db_session.commit()
+        with prepare_to_modify_documents(
+            db_session=db_session, document_ids=document_ids
+        ):
+            document_connector_cnts = get_document_connector_cnts(
+                db_session=db_session, document_ids=document_ids
+            )
+
+            # figure out which docs need to be completely deleted
+            document_ids_to_delete = [
+                document_id for document_id, cnt in document_connector_cnts if cnt == 1
+            ]
+            logger.debug(f"Deleting documents: {document_ids_to_delete}")
+
+            document_index.delete(doc_ids=document_ids_to_delete)
+
+            delete_documents_complete__no_commit(
+                db_session=db_session,
+                document_ids=document_ids_to_delete,
+            )
+
+            # figure out which docs need to be updated
+            document_ids_to_update = [
+                document_id for document_id, cnt in document_connector_cnts if cnt > 1
+            ]
+            access_for_documents = get_access_for_documents(
+                document_ids=document_ids_to_update,
+                db_session=db_session,
+                cc_pair_to_delete=ConnectorCredentialPairIdentifier(
+                    connector_id=connector_id,
+                    credential_id=credential_id,
+                ),
+            )
+            update_requests = [
+                UpdateRequest(
+                    document_ids=[document_id],
+                    access=access,
+                )
+                for document_id, access in access_for_documents.items()
+            ]
+            logger.debug(f"Updating documents: {document_ids_to_update}")
+
+            document_index.update(update_requests=update_requests)
+
+            delete_document_by_connector_credential_pair__no_commit(
+                db_session=db_session,
+                document_ids=document_ids_to_update,
+                connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
+                    connector_id=connector_id,
+                    credential_id=credential_id,
+                ),
+            )
+            db_session.commit()


 def cleanup_synced_entities(
@@ -1,4 +1,6 @@
+import contextlib
 import time
+from collections.abc import Generator
 from collections.abc import Sequence
 from datetime import datetime
 from uuid import UUID
@@ -9,15 +11,17 @@ from sqlalchemy import func
 from sqlalchemy import or_
 from sqlalchemy import select
 from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.engine.util import TransactionalContext
+from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import Session

 from danswer.configs.constants import DEFAULT_BOOST
-from danswer.db.feedback import delete_document_feedback_for_documents
+from danswer.db.feedback import delete_document_feedback_for_documents__no_commit
 from danswer.db.models import ConnectorCredentialPair
 from danswer.db.models import Credential
 from danswer.db.models import Document as DbDocument
 from danswer.db.models import DocumentByConnectorCredentialPair
-from danswer.db.tag import delete_document_tags_for_documents
+from danswer.db.tag import delete_document_tags_for_documents__no_commit
 from danswer.db.utils import model_to_dict
 from danswer.document_index.interfaces import DocumentMetadata
 from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@@ -242,7 +246,7 @@ def upsert_documents_complete(
     )


-def delete_document_by_connector_credential_pair(
+def delete_document_by_connector_credential_pair__no_commit(
     db_session: Session,
     document_ids: list[str],
     connector_credential_pair_identifier: ConnectorCredentialPairIdentifier
@@ -263,19 +267,22 @@ def delete_document_by_connector_credential_pair(
     db_session.execute(stmt)


-def delete_documents(db_session: Session, document_ids: list[str]) -> None:
+def delete_documents__no_commit(db_session: Session, document_ids: list[str]) -> None:
     db_session.execute(delete(DbDocument).where(DbDocument.id.in_(document_ids)))


-def delete_documents_complete(db_session: Session, document_ids: list[str]) -> None:
+def delete_documents_complete__no_commit(
+    db_session: Session, document_ids: list[str]
+) -> None:
     logger.info(f"Deleting {len(document_ids)} documents from the DB")
-    delete_document_by_connector_credential_pair(db_session, document_ids)
-    delete_document_feedback_for_documents(
+    delete_document_by_connector_credential_pair__no_commit(db_session, document_ids)
+    delete_document_feedback_for_documents__no_commit(
         document_ids=document_ids, db_session=db_session
     )
-    delete_document_tags_for_documents(document_ids=document_ids, db_session=db_session)
-    delete_documents(db_session, document_ids)
-    db_session.commit()
+    delete_document_tags_for_documents__no_commit(
+        document_ids=document_ids, db_session=db_session
+    )
+    delete_documents__no_commit(db_session, document_ids)


 def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool:
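The `__no_commit` suffix adopted above marks helpers that only execute statements and leave the commit to the caller, so several deletions can share one transaction (and one lock scope). A small self-contained sketch of the convention, using a hypothetical `Doc` model and an in-memory SQLite engine rather than the danswer schema:

from sqlalchemy import String, create_engine, delete
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Doc(Base):  # hypothetical stand-in for the real Document model
    __tablename__ = "doc"
    id: Mapped[str] = mapped_column(String, primary_key=True)


def delete_docs__no_commit(db_session: Session, document_ids: list[str]) -> None:
    # runs inside the caller's transaction; deliberately no commit here
    db_session.execute(delete(Doc).where(Doc.id.in_(document_ids)))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Doc(id="a"), Doc(id="b")])
    session.commit()

    delete_docs__no_commit(session, ["a"])
    session.commit()  # the single commit at the call site ends the transaction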
@@ -288,12 +295,18 @@ def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool
     document IDs in a single call).
     """
     stmt = (
-        select(DbDocument)
+        select(DbDocument.id)
         .where(DbDocument.id.in_(document_ids))
         .with_for_update(nowait=True)
     )
     # will raise exception if any of the documents are already locked
-    db_session.execute(stmt)
+    documents = db_session.scalars(stmt).all()
+
+    # make sure we found every document
+    if len(documents) != len(document_ids):
+        logger.warning("Didn't find row for all specified document IDs. Aborting.")
+        return False
+
     return True
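`with_for_update(nowait=True)` renders as `SELECT ... FOR UPDATE NOWAIT` on PostgreSQL, so if another transaction already holds a row lock the statement fails immediately (surfacing as an `OperationalError`, which the retry loop below catches) instead of blocking. A quick sketch showing the generated SQL (hypothetical `Doc` model standing in for `DbDocument`; no database connection needed):

from sqlalchemy import String, select
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Doc(Base):  # hypothetical stand-in for DbDocument
    __tablename__ = "document"
    id: Mapped[str] = mapped_column(String, primary_key=True)


stmt = (
    select(Doc.id)
    .where(Doc.id.in_(["doc-1", "doc-2"]))
    .with_for_update(nowait=True)
)
print(stmt.compile(dialect=postgresql.dialect()))
# -> SELECT document.id FROM document
#    WHERE document.id IN (...) FOR UPDATE NOWAIT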
@@ -301,20 +314,34 @@ _NUM_LOCK_ATTEMPTS = 10
 _LOCK_RETRY_DELAY = 30


-def prepare_to_modify_documents(db_session: Session, document_ids: list[str]) -> None:
+@contextlib.contextmanager
+def prepare_to_modify_documents(
+    db_session: Session, document_ids: list[str], retry_delay: int = _LOCK_RETRY_DELAY
+) -> Generator[TransactionalContext, None, None]:
     """Try and acquire locks for the documents to prevent other jobs from
     modifying them at the same time (e.g. avoid race conditions). This should be
     called ahead of any modification to Vespa. Locks should be released by the
-    caller as soon as updates are complete by finishing the transaction."""
+    caller as soon as updates are complete by finishing the transaction.
+
+    NOTE: only one commit is allowed within the context manager returned by this funtion.
+    Multiple commits will result in a sqlalchemy.exc.InvalidRequestError.
+    NOTE: this function will commit any existing transaction.
+    """
+    db_session.commit()  # ensure that we're not in a transaction
+
     lock_acquired = False
     for _ in range(_NUM_LOCK_ATTEMPTS):
         try:
-            lock_acquired = acquire_document_locks(
-                db_session=db_session, document_ids=document_ids
-            )
-        except Exception as e:
+            with db_session.begin() as transaction:
+                lock_acquired = acquire_document_locks(
+                    db_session=db_session, document_ids=document_ids
+                )
+                if lock_acquired:
+                    yield transaction
+                    break
+        except OperationalError as e:
             logger.info(f"Failed to acquire locks for documents, retrying. Error: {e}")
-            time.sleep(_LOCK_RETRY_DELAY)
+            time.sleep(retry_delay)

     if not lock_acquired:
         raise RuntimeError(
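With this change, callers consume the lock as a scope: the generator commits any open transaction, retries the `FOR UPDATE NOWAIT` lock acquisition inside `db_session.begin()`, yields once the rows are locked, and the locks are released when that transaction finishes. A usage sketch matching the call sites in this commit (document IDs are placeholders; this assumes a configured danswer environment):

from sqlalchemy.orm import Session

from danswer.db.document import prepare_to_modify_documents
from danswer.db.engine import get_sqlalchemy_engine

with Session(get_sqlalchemy_engine()) as db_session:
    with prepare_to_modify_documents(
        db_session=db_session, document_ids=["doc-1", "doc-2"]
    ):
        # update Vespa / Postgres here while the rows are locked;
        # at most one db_session.commit() is allowed inside this block
        ...
    # leaving the block ends the transaction and releases the row locks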
@@ -129,7 +129,7 @@ def create_doc_retrieval_feedback(
     db_session.commit()


-def delete_document_feedback_for_documents(
+def delete_document_feedback_for_documents__no_commit(
     document_ids: list[str], db_session: Session
 ) -> None:
     """NOTE: does not commit transaction so that this can be used as part of a
@@ -117,12 +117,11 @@ def get_tags_by_value_prefix_for_source_types(
     return list(tags)


-def delete_document_tags_for_documents(
+def delete_document_tags_for_documents__no_commit(
     document_ids: list[str], db_session: Session
 ) -> None:
     stmt = delete(Document__Tag).where(Document__Tag.document_id.in_(document_ids))
     db_session.execute(stmt)
-    db_session.commit()

     orphan_tags_query = (
         select(Tag.id)
@@ -136,4 +135,3 @@ def delete_document_tags_for_documents(
     if orphan_tags:
         delete_orphan_tags_stmt = delete(Tag).where(Tag.id.in_(orphan_tags))
         db_session.execute(delete_orphan_tags_stmt)
-        db_session.commit()
@@ -161,59 +161,58 @@ def index_doc_batch(
     # Acquires a lock on the documents so that no other process can modify them
     # NOTE: don't need to acquire till here, since this is when the actual race condition
     # with Vespa can occur.
-    prepare_to_modify_documents(db_session=db_session, document_ids=updatable_ids)
-
-    # Attach the latest status from Postgres (source of truth for access) to each
-    # chunk. This access status will be attached to each chunk in the document index
-    # TODO: attach document sets to the chunk based on the status of Postgres as well
-    document_id_to_access_info = get_access_for_documents(
-        document_ids=updatable_ids, db_session=db_session
-    )
-    document_id_to_document_set = {
-        document_id: document_sets
-        for document_id, document_sets in fetch_document_sets_for_documents(
-            document_ids=updatable_ids, db_session=db_session
-        )
-    }
-    access_aware_chunks = [
-        DocMetadataAwareIndexChunk.from_index_chunk(
-            index_chunk=chunk,
-            access=document_id_to_access_info[chunk.source_document.id],
-            document_sets=set(
-                document_id_to_document_set.get(chunk.source_document.id, [])
-            ),
-            boost=(
-                id_to_db_doc_map[chunk.source_document.id].boost
-                if chunk.source_document.id in id_to_db_doc_map
-                else DEFAULT_BOOST
-            ),
-        )
-        for chunk in chunks_with_embeddings
-    ]
-
-    logger.debug(
-        f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in chunks]}"
-    )
-    # A document will not be spread across different batches, so all the
-    # documents with chunks in this set, are fully represented by the chunks
-    # in this set
-    insertion_records = document_index.index(chunks=access_aware_chunks)
-
-    successful_doc_ids = [record.document_id for record in insertion_records]
-    successful_docs = [doc for doc in updatable_docs if doc.id in successful_doc_ids]
-
-    # Update the time of latest version of the doc successfully indexed
-    ids_to_new_updated_at = {}
-    for doc in successful_docs:
-        if doc.doc_updated_at is None:
-            continue
-        ids_to_new_updated_at[doc.id] = doc.doc_updated_at
-
-    update_docs_updated_at(
-        ids_to_new_updated_at=ids_to_new_updated_at, db_session=db_session
-    )
-
-    db_session.commit()
+    with prepare_to_modify_documents(db_session=db_session, document_ids=updatable_ids):
+        # Attach the latest status from Postgres (source of truth for access) to each
+        # chunk. This access status will be attached to each chunk in the document index
+        # TODO: attach document sets to the chunk based on the status of Postgres as well
+        document_id_to_access_info = get_access_for_documents(
+            document_ids=updatable_ids, db_session=db_session
+        )
+        document_id_to_document_set = {
+            document_id: document_sets
+            for document_id, document_sets in fetch_document_sets_for_documents(
+                document_ids=updatable_ids, db_session=db_session
+            )
+        }
+        access_aware_chunks = [
+            DocMetadataAwareIndexChunk.from_index_chunk(
+                index_chunk=chunk,
+                access=document_id_to_access_info[chunk.source_document.id],
+                document_sets=set(
+                    document_id_to_document_set.get(chunk.source_document.id, [])
+                ),
+                boost=(
+                    id_to_db_doc_map[chunk.source_document.id].boost
+                    if chunk.source_document.id in id_to_db_doc_map
+                    else DEFAULT_BOOST
+                ),
+            )
+            for chunk in chunks_with_embeddings
+        ]
+
+        logger.debug(
+            f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in chunks]}"
+        )
+        # A document will not be spread across different batches, so all the
+        # documents with chunks in this set, are fully represented by the chunks
+        # in this set
+        insertion_records = document_index.index(chunks=access_aware_chunks)
+
+        successful_doc_ids = [record.document_id for record in insertion_records]
+        successful_docs = [
+            doc for doc in updatable_docs if doc.id in successful_doc_ids
+        ]
+
+        # Update the time of latest version of the doc successfully indexed
+        ids_to_new_updated_at = {}
+        for doc in successful_docs:
+            if doc.doc_updated_at is None:
+                continue
+            ids_to_new_updated_at[doc.id] = doc.doc_updated_at
+
+        update_docs_updated_at(
+            ids_to_new_updated_at=ids_to_new_updated_at, db_session=db_session
+        )

     return len([r for r in insertion_records if r.already_existed is False]), len(
         chunks