Add pagination to document set syncing + improve speed

Weves 2024-05-14 15:24:11 -07:00 committed by Chris Weaver
parent 6f90308278
commit 05bc6b1c65
3 changed files with 75 additions and 33 deletions
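In short: instead of loading every document in the set up front and batching in memory, the sync task now pages through the set with a keyset cursor. A minimal, standalone sketch of that loop (not part of this diff; fetch_page and ROWS are hypothetical stand-ins for fetch_documents_for_document_set_paginated and the documents table):

from typing import Sequence

ROWS = [f"doc-{i:04d}" for i in range(2500)]  # pretend document IDs, already sorted

def fetch_page(last_id: str | None, limit: int) -> tuple[Sequence[str], str | None]:
    # return the next `limit` IDs strictly after `last_id`, plus a cursor (None once exhausted)
    start = 0 if last_id is None else next((i for i, r in enumerate(ROWS) if r > last_id), len(ROWS))
    page = ROWS[start : start + limit]
    return page, (page[-1] if page else None)

def sync_all(batch_size: int = 100) -> int:
    synced = 0
    cursor: str | None = None
    while True:
        page, cursor = fetch_page(cursor, batch_size)
        synced += len(page)  # the real task calls _sync_document_batch here
        if cursor is None:  # an empty page means the previous one was the last
            break
    return synced

print(sync_all())  # 2500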

View File

@@ -15,7 +15,7 @@ from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_documents
-from danswer.db.document_set import fetch_documents_for_document_set
+from danswer.db.document_set import fetch_documents_for_document_set_paginated
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import build_connection_string
@@ -27,7 +27,6 @@ from danswer.db.tasks import get_latest_task
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
-from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -38,7 +37,7 @@ celery_backend_url = f"db+{connection_string}"
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
-_SYNC_BATCH_SIZE = 1000
+_SYNC_BATCH_SIZE = 100
#####
@@ -126,18 +125,21 @@ def sync_document_set_task(document_set_id: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
try:
-documents_to_update = fetch_documents_for_document_set(
-document_set_id=document_set_id,
-db_session=db_session,
-current_only=False,
-)
-for document_batch in batch_generator(
-documents_to_update, _SYNC_BATCH_SIZE
-):
+cursor = None
+while True:
+document_batch, cursor = fetch_documents_for_document_set_paginated(
+document_set_id=document_set_id,
+db_session=db_session,
+current_only=False,
+last_document_id=cursor,
+limit=_SYNC_BATCH_SIZE,
+)
_sync_document_batch(
document_ids=[document.id for document in document_batch],
db_session=db_session,
)
+if cursor is None:
+break
# if there are no connectors, then delete the document set. Otherwise, just
# mark it as successfully synced.

View File

@@ -382,9 +382,13 @@ def fetch_user_document_sets(
)
-def fetch_documents_for_document_set(
-document_set_id: int, db_session: Session, current_only: bool = True
-) -> Sequence[Document]:
+def fetch_documents_for_document_set_paginated(
+document_set_id: int,
+db_session: Session,
+current_only: bool = True,
+last_document_id: str | None = None,
+limit: int = 100,
+) -> tuple[Sequence[Document], str | None]:
stmt = (
select(Document)
.join(
@@ -411,14 +415,19 @@ def fetch_documents_for_document_set(
== DocumentSet__ConnectorCredentialPair.document_set_id,
)
.where(DocumentSetDBModel.id == document_set_id)
+.order_by(Document.id)
+.limit(limit)
)
+if last_document_id is not None:
+stmt = stmt.where(Document.id > last_document_id)
if current_only:
stmt = stmt.where(
DocumentSet__ConnectorCredentialPair.is_current == True # noqa: E712
)
stmt = stmt.distinct()
-return db_session.scalars(stmt).all()
+documents = db_session.scalars(stmt).all()
+return documents, documents[-1].id if documents else None
def fetch_document_sets_for_documents(
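The statement above is standard keyset pagination: order by the primary key, limit, and on subsequent calls filter on id greater than the last id seen. A rough standalone sketch of the same pattern against an in-memory SQLite table (the real query also joins the document-set and connector/credential-pair tables and honors current_only; the Doc model and fetch_page helper here are illustrative only):

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Doc(Base):
    __tablename__ = "doc"
    id: Mapped[str] = mapped_column(primary_key=True)

def fetch_page(db_session: Session, last_id: str | None, limit: int = 100):
    stmt = select(Doc).order_by(Doc.id).limit(limit)
    if last_id is not None:
        # keyset predicate: resume strictly after the last row of the previous page
        stmt = stmt.where(Doc.id > last_id)
    docs = db_session.scalars(stmt).all()
    return docs, (docs[-1].id if docs else None)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add_all(Doc(id=f"doc-{i:03d}") for i in range(250))
    session.commit()
    page, cursor = fetch_page(session, last_id=None)
    print(len(page), cursor)  # 100 doc-099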

View File

@@ -90,7 +90,7 @@ SEARCH_ENDPOINT = f"{VESPA_APP_CONTAINER_URL}/search/"
_BATCH_SIZE = 128 # Specific to Vespa
_NUM_THREADS = (
-16 # since Vespa doesn't allow batching of inserts / updates, we use threads
+32 # since Vespa doesn't allow batching of inserts / updates, we use threads
)
# up from 500ms for now, since we've seen quite a few timeouts
# in the long term, we are looking to improve the performance of Vespa
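Since each chunk update is an individual HTTP call, the only way to overlap them is a bounded thread pool, which is why the worker count matters here. A toy sketch of the effect (send_update is a stand-in for the real per-chunk request; the URLs are made up):

import concurrent.futures
import time

_NUM_THREADS = 32

def send_update(chunk_url: str) -> None:
    time.sleep(0.01)  # pretend network round-trip; real code would issue the HTTP call here

urls = [f"http://vespa:8080/fake/{i}" for i in range(256)]  # hypothetical chunk URLs

start = time.monotonic()
with concurrent.futures.ThreadPoolExecutor(max_workers=_NUM_THREADS) as executor:
    list(executor.map(send_update, urls))  # ~256/32 round-trips of wall-clock latency instead of 256
print(f"{time.monotonic() - start:.2f}s")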
@@ -847,9 +847,46 @@ class VespaIndex(DocumentIndex):
def update(self, update_requests: list[UpdateRequest]) -> None:
logger.info(f"Updating {len(update_requests)} documents in Vespa")
-start = time.time()
+update_start = time.monotonic()
processed_updates_requests: list[_VespaUpdateRequest] = []
+all_doc_chunk_ids: dict[str, list[str]] = {}
+# Fetch all chunks for each document ahead of time
+index_names = [self.index_name]
+if self.secondary_index_name:
+index_names.append(self.secondary_index_name)
+chunk_id_start_time = time.monotonic()
+with concurrent.futures.ThreadPoolExecutor(
+max_workers=_NUM_THREADS
+) as executor:
+future_to_doc_chunk_ids = {
+executor.submit(
+_get_vespa_chunk_ids_by_document_id,
+document_id=document_id,
+index_name=index_name,
+): (document_id, index_name)
+for index_name in index_names
+for update_request in update_requests
+for document_id in update_request.document_ids
+}
+for future in concurrent.futures.as_completed(future_to_doc_chunk_ids):
+document_id, index_name = future_to_doc_chunk_ids[future]
+try:
+doc_chunk_ids = future.result()
+if document_id not in all_doc_chunk_ids:
+all_doc_chunk_ids[document_id] = []
+all_doc_chunk_ids[document_id].extend(doc_chunk_ids)
+except Exception as e:
+logger.error(
+f"Error retrieving chunk IDs for document {document_id} in index {index_name}: {e}"
+)
+logger.debug(
+f"Took {time.monotonic() - chunk_id_start_time:.2f} seconds to fetch all Vespa chunk IDs"
+)
+# Build the _VespaUpdateRequest objects
for update_request in update_requests:
update_dict: dict[str, dict] = {"fields": {}}
if update_request.boost is not None:
@@ -873,26 +910,20 @@ class VespaIndex(DocumentIndex):
logger.error("Update request received but nothing to update")
continue
-index_names = [self.index_name]
-if self.secondary_index_name:
-index_names.append(self.secondary_index_name)
-for index_name in index_names:
-for document_id in update_request.document_ids:
-for doc_chunk_id in _get_vespa_chunk_ids_by_document_id(
-document_id=document_id, index_name=index_name
-):
-processed_updates_requests.append(
-_VespaUpdateRequest(
-document_id=document_id,
-url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}",
-update_request=update_dict,
-)
+for document_id in update_request.document_ids:
+for doc_chunk_id in all_doc_chunk_ids[document_id]:
+processed_updates_requests.append(
+_VespaUpdateRequest(
+document_id=document_id,
+url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
+update_request=update_dict,
+)
+)
self._apply_updates_batched(processed_updates_requests)
logger.info(
"Finished updating Vespa documents in %s seconds", time.time() - start
"Finished updating Vespa documents in %.2f seconds",
time.monotonic() - update_start,
)
def delete(self, doc_ids: list[str]) -> None:
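For reference, the chunk-ID prefetch introduced above boils down to mapping each future back to its (document, index) pair and merging results per document as they complete. A self-contained sketch (get_chunk_ids, the document IDs, and the index names are made up):

import concurrent.futures

def get_chunk_ids(document_id: str, index_name: str) -> list[str]:
    return [f"{index_name}:{document_id}:{chunk}" for chunk in range(3)]  # fake chunk IDs

document_ids = ["doc-a", "doc-b"]
index_names = ["primary_index", "secondary_index"]

all_doc_chunk_ids: dict[str, list[str]] = {}
with concurrent.futures.ThreadPoolExecutor(max_workers=8) as executor:
    future_to_key = {
        executor.submit(get_chunk_ids, document_id=d, index_name=i): (d, i)
        for i in index_names
        for d in document_ids
    }
    for future in concurrent.futures.as_completed(future_to_key):
        document_id, index_name = future_to_key[future]
        try:
            all_doc_chunk_ids.setdefault(document_id, []).extend(future.result())
        except Exception as e:
            print(f"chunk ID fetch failed for {document_id} in {index_name}: {e}")

print({doc: len(ids) for doc, ids in all_doc_chunk_ids.items()})  # 6 chunk IDs per document (3 per index)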