Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-03-29 11:12:02 +01:00)
Add pagination to user group syncing

parent a2597d5f21
commit fc3ed76d12
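In outline, the commit replaces a single unbounded fetch of a user group's documents with keyset (cursor) pagination: each query orders by Document.id, caps the page at a limit, and resumes strictly after the last id returned. Below is a minimal, self-contained sketch of that query pattern, assuming SQLAlchemy 2.x against in-memory SQLite; Doc is a hypothetical stand-in model, not the Danswer Document class.

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Doc(Base):
    __tablename__ = "doc"
    id: Mapped[str] = mapped_column(String, primary_key=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(Doc(id=f"doc_{i:03d}") for i in range(250))
    session.commit()

    last_id: str | None = None
    seen = 0
    while True:
        # Order by the key and resume strictly after the previous page's last id.
        stmt = select(Doc.id).order_by(Doc.id).limit(100)
        if last_id is not None:
            stmt = stmt.where(Doc.id > last_id)
        page = session.scalars(stmt).all()
        seen += len(page)
        if not page:
            break
        last_id = page[-1]
    print(seen)  # 250

Keyset pagination keeps each query bounded regardless of group size, at the cost of reading a view that can drift while the loop runs, which is the trade-off the NOTE in the diff below acknowledges.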
@@ -45,9 +45,12 @@ def fetch_user_groups_for_user(
     return db_session.scalars(stmt).all()
 
 
-def fetch_documents_for_user_group(
-    db_session: Session, user_group_id: int
-) -> Sequence[Document]:
+def fetch_documents_for_user_group_paginated(
+    db_session: Session,
+    user_group_id: int,
+    last_document_id: str | None = None,
+    limit: int = 100,
+) -> tuple[Sequence[Document], str | None]:
     stmt = (
         select(Document)
         .join(
@@ -72,8 +75,13 @@ def fetch_documents_for_user_group(
             UserGroup__ConnectorCredentialPair.user_group_id == UserGroup.id,
         )
         .where(UserGroup.id == user_group_id)
+        .order_by(Document.id)
+        .limit(limit)
     )
-    return db_session.scalars(stmt).all()
+    if last_document_id is not None:
+        stmt = stmt.where(Document.id > last_document_id)
+    documents = db_session.scalars(stmt).all()
+    return documents, documents[-1].id if documents else None
 
 
 def fetch_user_groups_for_documents(
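One subtlety in the new return line: Python's conditional expression binds tighter than the tuple comma, so "return documents, documents[-1].id if documents else None" yields a two-tuple whose second element is conditional, rather than choosing between a tuple and None. A quick self-contained check of that parse:

# The conditional applies only to the second element of the tuple:
rows: list[str] = []
result = rows, rows[-1] if rows else None  # no IndexError on the empty list
assert result == ([], None)

rows = ["a", "b"]
result = rows, rows[-1] if rows else None
assert result == (["a", "b"], "b")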
@@ -7,16 +7,15 @@ from danswer.db.embedding_model import get_secondary_db_embedding_model
 from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.interfaces import DocumentIndex
 from danswer.document_index.interfaces import UpdateRequest
-from danswer.utils.batching import batch_generator
 from danswer.utils.logger import setup_logger
 from ee.danswer.db.user_group import delete_user_group
-from ee.danswer.db.user_group import fetch_documents_for_user_group
+from ee.danswer.db.user_group import fetch_documents_for_user_group_paginated
 from ee.danswer.db.user_group import fetch_user_group
 from ee.danswer.db.user_group import mark_user_group_as_synced
 
 logger = setup_logger()
 
-_SYNC_BATCH_SIZE = 512
+_SYNC_BATCH_SIZE = 100
 
 
 def _sync_user_group_batch(
@@ -62,17 +61,26 @@ def sync_user_groups(user_group_id: int, db_session: Session) -> None:
     if user_group is None:
         raise ValueError(f"User group '{user_group_id}' does not exist")
 
-    documents_to_update = fetch_documents_for_user_group(
-        db_session=db_session,
-        user_group_id=user_group_id,
-    )
-    for document_batch in batch_generator(documents_to_update, _SYNC_BATCH_SIZE):
+    cursor = None
+    while True:
+        # NOTE: this may miss some documents, but that is okay. Any new documents added
+        # will be added with the correct group membership
+        document_batch, cursor = fetch_documents_for_user_group_paginated(
+            db_session=db_session,
+            user_group_id=user_group_id,
+            last_document_id=cursor,
+            limit=_SYNC_BATCH_SIZE,
+        )
+
         _sync_user_group_batch(
            document_ids=[document.id for document in document_batch],
            document_index=document_index,
            db_session=db_session,
         )
+
+        if cursor is None:
+            break
 
     if user_group.is_up_for_deletion:
         delete_user_group(db_session=db_session, user_group=user_group)
     else:
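For the record on boundary behavior: when a page comes back exactly full, the returned cursor is still the last id, so the loop makes one more round trip that yields an empty batch (and one _sync_user_group_batch call with an empty id list) before terminating. A small stub that mirrors the paginated contract, using hypothetical names rather than the real Danswer helpers:

from collections.abc import Sequence

call_count = 0


def fake_fetch_paginated(
    last_document_id: str | None = None, limit: int = 100
) -> tuple[Sequence[str], str | None]:
    """Stub with the same contract: (page, cursor), cursor None on an empty page."""
    global call_count
    call_count += 1
    ids = [f"doc_{i:03d}" for i in range(200)]  # exactly two full pages
    if last_document_id is not None:
        ids = [i for i in ids if i > last_document_id]
    page = ids[:limit]
    return page, page[-1] if page else None


cursor = None
synced = 0
while True:
    batch, cursor = fake_fetch_paginated(last_document_id=cursor, limit=100)
    synced += len(batch)  # the real loop runs _sync_user_group_batch here, even when empty
    if cursor is None:
        break

assert synced == 200
assert call_count == 3  # 100 + 100 + one trailing empty page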