From fc3ed76d129ed8e0c826c97d45012aab74d8d283 Mon Sep 17 00:00:00 2001 From: Weves Date: Tue, 14 May 2024 14:53:51 -0700 Subject: [PATCH] Add pagination to user group syncing --- backend/ee/danswer/db/user_group.py | 16 ++++++++++++---- backend/ee/danswer/user_groups/sync.py | 24 ++++++++++++++++-------- 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/backend/ee/danswer/db/user_group.py b/backend/ee/danswer/db/user_group.py index beff6c58c..7a263d33c 100644 --- a/backend/ee/danswer/db/user_group.py +++ b/backend/ee/danswer/db/user_group.py @@ -45,9 +45,12 @@ def fetch_user_groups_for_user( return db_session.scalars(stmt).all() -def fetch_documents_for_user_group( - db_session: Session, user_group_id: int -) -> Sequence[Document]: +def fetch_documents_for_user_group_paginated( + db_session: Session, + user_group_id: int, + last_document_id: str | None = None, + limit: int = 100, +) -> tuple[Sequence[Document], str | None]: stmt = ( select(Document) .join( @@ -72,8 +75,13 @@ def fetch_documents_for_user_group( UserGroup__ConnectorCredentialPair.user_group_id == UserGroup.id, ) .where(UserGroup.id == user_group_id) + .order_by(Document.id) + .limit(limit) ) - return db_session.scalars(stmt).all() + if last_document_id is not None: + stmt = stmt.where(Document.id > last_document_id) + documents = db_session.scalars(stmt).all() + return documents, documents[-1].id if documents else None def fetch_user_groups_for_documents( diff --git a/backend/ee/danswer/user_groups/sync.py b/backend/ee/danswer/user_groups/sync.py index 4e6ce31d9..e33655ba2 100644 --- a/backend/ee/danswer/user_groups/sync.py +++ b/backend/ee/danswer/user_groups/sync.py @@ -7,16 +7,15 @@ from danswer.db.embedding_model import get_secondary_db_embedding_model from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import DocumentIndex from danswer.document_index.interfaces import UpdateRequest -from danswer.utils.batching import batch_generator from danswer.utils.logger import setup_logger from ee.danswer.db.user_group import delete_user_group -from ee.danswer.db.user_group import fetch_documents_for_user_group +from ee.danswer.db.user_group import fetch_documents_for_user_group_paginated from ee.danswer.db.user_group import fetch_user_group from ee.danswer.db.user_group import mark_user_group_as_synced logger = setup_logger() -_SYNC_BATCH_SIZE = 512 +_SYNC_BATCH_SIZE = 100 def _sync_user_group_batch( @@ -62,17 +61,26 @@ def sync_user_groups(user_group_id: int, db_session: Session) -> None: if user_group is None: raise ValueError(f"User group '{user_group_id}' does not exist") - documents_to_update = fetch_documents_for_user_group( - db_session=db_session, - user_group_id=user_group_id, - ) - for document_batch in batch_generator(documents_to_update, _SYNC_BATCH_SIZE): + cursor = None + while True: + # NOTE: this may miss some documents, but that is okay. Any new documents added + # will be added with the correct group membership + document_batch, cursor = fetch_documents_for_user_group_paginated( + db_session=db_session, + user_group_id=user_group_id, + last_document_id=cursor, + limit=_SYNC_BATCH_SIZE, + ) + _sync_user_group_batch( document_ids=[document.id for document in document_batch], document_index=document_index, db_session=db_session, ) + if cursor is None: + break + if user_group.is_up_for_deletion: delete_user_group(db_session=db_session, user_group=user_group) else: