Add pagination to document set syncing + improve speed

Weves
2024-05-14 15:24:11 -07:00
committed by Chris Weaver
parent 6f90308278
commit 05bc6b1c65
3 changed files with 75 additions and 33 deletions

View File

@@ -15,7 +15,7 @@ from danswer.db.document import prepare_to_modify_documents
 from danswer.db.document_set import delete_document_set
 from danswer.db.document_set import fetch_document_sets
 from danswer.db.document_set import fetch_document_sets_for_documents
-from danswer.db.document_set import fetch_documents_for_document_set
+from danswer.db.document_set import fetch_documents_for_document_set_paginated
 from danswer.db.document_set import get_document_set_by_id
 from danswer.db.document_set import mark_document_set_as_synced
 from danswer.db.engine import build_connection_string
@@ -27,7 +27,6 @@ from danswer.db.tasks import get_latest_task
 from danswer.document_index.document_index_utils import get_both_index_names
 from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.interfaces import UpdateRequest
-from danswer.utils.batching import batch_generator
 from danswer.utils.logger import setup_logger

 logger = setup_logger()
@@ -38,7 +37,7 @@ celery_backend_url = f"db+{connection_string}"
 celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)

-_SYNC_BATCH_SIZE = 1000
+_SYNC_BATCH_SIZE = 100

 #####
@@ -126,18 +125,21 @@ def sync_document_set_task(document_set_id: int) -> None:
     with Session(get_sqlalchemy_engine()) as db_session:
         try:
-            documents_to_update = fetch_documents_for_document_set(
-                document_set_id=document_set_id,
-                db_session=db_session,
-                current_only=False,
-            )
-            for document_batch in batch_generator(
-                documents_to_update, _SYNC_BATCH_SIZE
-            ):
+            cursor = None
+            while True:
+                document_batch, cursor = fetch_documents_for_document_set_paginated(
+                    document_set_id=document_set_id,
+                    db_session=db_session,
+                    current_only=False,
+                    last_document_id=cursor,
+                    limit=_SYNC_BATCH_SIZE,
+                )
                 _sync_document_batch(
                     document_ids=[document.id for document in document_batch],
                     db_session=db_session,
                 )
+                if cursor is None:
+                    break

             # if there are no connectors, then delete the document set. Otherwise, just
             # mark it as successfully synced.
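
For readers skimming the diff, here is a minimal, self-contained sketch of the cursor loop introduced above. fetch_page is an illustrative stand-in for fetch_documents_for_document_set_paginated; like the real function, it returns (page, cursor) and yields a cursor of None only once a fetch comes back empty, so the loop issues one final empty query before stopping.

# Illustrative sketch only -- fetch_page stands in for
# fetch_documents_for_document_set_paginated and uses an in-memory list.
from collections.abc import Sequence

_FAKE_IDS = sorted(f"doc-{i:04d}" for i in range(250))  # made-up data
PAGE_SIZE = 100


def fetch_page(last_id: str | None, limit: int) -> tuple[Sequence[str], str | None]:
    # Keyset pagination: everything strictly after the cursor, in id order.
    after_cursor = [d for d in _FAKE_IDS if last_id is None or d > last_id]
    page = after_cursor[:limit]
    return page, (page[-1] if page else None)


def sync_all() -> None:
    cursor: str | None = None
    while True:
        page, cursor = fetch_page(last_id=cursor, limit=PAGE_SIZE)
        print(f"syncing {len(page)} documents")  # _sync_document_batch(...) in the real task
        if cursor is None:
            break


if __name__ == "__main__":
    sync_all()

Compared with loading every document id into memory and slicing it with batch_generator, this bounds memory to one page at a time and keeps each query cheap via the id > cursor predicate.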

View File

@@ -382,9 +382,13 @@ def fetch_user_document_sets(
     )


-def fetch_documents_for_document_set(
-    document_set_id: int, db_session: Session, current_only: bool = True
-) -> Sequence[Document]:
+def fetch_documents_for_document_set_paginated(
+    document_set_id: int,
+    db_session: Session,
+    current_only: bool = True,
+    last_document_id: str | None = None,
+    limit: int = 100,
+) -> tuple[Sequence[Document], str | None]:
     stmt = (
         select(Document)
         .join(
@@ -411,14 +415,19 @@ def fetch_documents_for_document_set(
             == DocumentSet__ConnectorCredentialPair.document_set_id,
         )
         .where(DocumentSetDBModel.id == document_set_id)
+        .order_by(Document.id)
+        .limit(limit)
     )
+    if last_document_id is not None:
+        stmt = stmt.where(Document.id > last_document_id)
     if current_only:
         stmt = stmt.where(
             DocumentSet__ConnectorCredentialPair.is_current == True  # noqa: E712
         )
     stmt = stmt.distinct()

-    return db_session.scalars(stmt).all()
+    documents = db_session.scalars(stmt).all()
+    return documents, documents[-1].id if documents else None


 def fetch_document_sets_for_documents(
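
A hedged usage sketch of the new paginated fetch from a caller's perspective; the iter_document_ids helper and the get_sqlalchemy_engine import path are illustrative assumptions, not part of this commit.

# Illustrative only: walk every document in a document set one page at a time.
from sqlalchemy.orm import Session

from danswer.db.document_set import fetch_documents_for_document_set_paginated
from danswer.db.engine import get_sqlalchemy_engine  # assumed import path


def iter_document_ids(document_set_id: int) -> list[str]:
    seen: list[str] = []
    with Session(get_sqlalchemy_engine()) as db_session:
        cursor = None
        while True:
            docs, cursor = fetch_documents_for_document_set_paginated(
                document_set_id=document_set_id,
                db_session=db_session,
                current_only=True,
                last_document_id=cursor,
                limit=100,
            )
            seen.extend(doc.id for doc in docs)
            if cursor is None:
                break
    return seen

Ordering by Document.id and filtering with Document.id > last_document_id keeps pages stable under concurrent inserts, which OFFSET-based paging does not guarantee.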

View File

@@ -90,7 +90,7 @@ SEARCH_ENDPOINT = f"{VESPA_APP_CONTAINER_URL}/search/"
 _BATCH_SIZE = 128  # Specific to Vespa
 _NUM_THREADS = (
-    16  # since Vespa doesn't allow batching of inserts / updates, we use threads
+    32  # since Vespa doesn't allow batching of inserts / updates, we use threads
 )
 # up from 500ms for now, since we've seen quite a few timeouts
 # in the long term, we are looking to improve the performance of Vespa
@@ -847,9 +847,46 @@ class VespaIndex(DocumentIndex):
     def update(self, update_requests: list[UpdateRequest]) -> None:
         logger.info(f"Updating {len(update_requests)} documents in Vespa")
-        start = time.time()
+        update_start = time.monotonic()

         processed_updates_requests: list[_VespaUpdateRequest] = []
+        all_doc_chunk_ids: dict[str, list[str]] = {}
+
+        # Fetch all chunks for each document ahead of time
+        index_names = [self.index_name]
+        if self.secondary_index_name:
+            index_names.append(self.secondary_index_name)
+
+        chunk_id_start_time = time.monotonic()
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=_NUM_THREADS
+        ) as executor:
+            future_to_doc_chunk_ids = {
+                executor.submit(
+                    _get_vespa_chunk_ids_by_document_id,
+                    document_id=document_id,
+                    index_name=index_name,
+                ): (document_id, index_name)
+                for index_name in index_names
+                for update_request in update_requests
+                for document_id in update_request.document_ids
+            }
+            for future in concurrent.futures.as_completed(future_to_doc_chunk_ids):
+                document_id, index_name = future_to_doc_chunk_ids[future]
+                try:
+                    doc_chunk_ids = future.result()
+                    if document_id not in all_doc_chunk_ids:
+                        all_doc_chunk_ids[document_id] = []
+                    all_doc_chunk_ids[document_id].extend(doc_chunk_ids)
+                except Exception as e:
+                    logger.error(
+                        f"Error retrieving chunk IDs for document {document_id} in index {index_name}: {e}"
+                    )
+        logger.debug(
+            f"Took {time.monotonic() - chunk_id_start_time:.2f} seconds to fetch all Vespa chunk IDs"
+        )
+
+        # Build the _VespaUpdateRequest objects
         for update_request in update_requests:
             update_dict: dict[str, dict] = {"fields": {}}
             if update_request.boost is not None:
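
The hunk above replaces per-document, serial chunk-ID lookups with a thread-pool fan-out keyed by (document_id, index_name). Below is a self-contained sketch of that submit/as_completed pattern, with fetch_chunk_ids standing in for _get_vespa_chunk_ids_by_document_id.

# Self-contained sketch of the fan-out pattern: submit one task per
# (document_id, index_name) pair, then collect results as they complete.
import concurrent.futures


def fetch_chunk_ids(document_id: str, index_name: str) -> list[str]:
    # Illustrative stand-in; the real code queries Vespa per document.
    return [f"{index_name}:{document_id}:chunk-{i}" for i in range(3)]


def collect_chunk_ids(document_ids: list[str], index_names: list[str]) -> dict[str, list[str]]:
    results: dict[str, list[str]] = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
        future_to_key = {
            executor.submit(fetch_chunk_ids, doc_id, index_name): (doc_id, index_name)
            for index_name in index_names
            for doc_id in document_ids
        }
        for future in concurrent.futures.as_completed(future_to_key):
            doc_id, _index_name = future_to_key[future]
            results.setdefault(doc_id, []).extend(future.result())
    return results


if __name__ == "__main__":
    print(collect_chunk_ids(["doc-1", "doc-2"], ["primary", "secondary"]))

Keying the futures by (document_id, index_name) lets results from the primary and secondary index be merged into a single chunk-ID list per document, which the request-building loop below relies on.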
@@ -873,26 +910,20 @@ class VespaIndex(DocumentIndex):
                 logger.error("Update request received but nothing to update")
                 continue

-            index_names = [self.index_name]
-            if self.secondary_index_name:
-                index_names.append(self.secondary_index_name)
-
-            for index_name in index_names:
-                for document_id in update_request.document_ids:
-                    for doc_chunk_id in _get_vespa_chunk_ids_by_document_id(
-                        document_id=document_id, index_name=index_name
-                    ):
-                        processed_updates_requests.append(
-                            _VespaUpdateRequest(
-                                document_id=document_id,
-                                url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}",
-                                update_request=update_dict,
-                            )
-                        )
+            for document_id in update_request.document_ids:
+                for doc_chunk_id in all_doc_chunk_ids[document_id]:
+                    processed_updates_requests.append(
+                        _VespaUpdateRequest(
+                            document_id=document_id,
+                            url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
+                            update_request=update_dict,
+                        )
+                    )

         self._apply_updates_batched(processed_updates_requests)
         logger.info(
-            "Finished updating Vespa documents in %s seconds", time.time() - start
+            "Finished updating Vespa documents in %.2f seconds",
+            time.monotonic() - update_start,
        )

     def delete(self, doc_ids: list[str]) -> None:
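
One incidental but deliberate change above: elapsed time is now measured with time.monotonic() rather than time.time() and formatted to two decimals; a monotonic clock cannot jump backwards if the system clock is adjusted mid-operation, so it is the safer choice for durations. A tiny sketch of the pattern, with the sleep standing in for the batched Vespa updates:

# Measuring elapsed time for a log line with a monotonic clock.
import time


def timed_work() -> None:
    start = time.monotonic()
    time.sleep(0.25)  # stand-in for the batched Vespa updates
    print(f"Finished updating Vespa documents in {time.monotonic() - start:.2f} seconds")


if __name__ == "__main__":
    timed_work()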