select only doc_id (#3920)

* select only doc_id

* select more doc ids

* fix user group

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
This commit is contained in:
rkuo-danswer 2025-02-05 23:00:40 -08:00 committed by GitHub
parent a0a1b431be
commit 6ccb3f085a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 71 additions and 44 deletions

View File

@ -218,14 +218,14 @@ def fetch_user_groups_for_user(
return db_session.scalars(stmt).all()
def construct_document_select_by_usergroup(
def construct_document_id_select_by_usergroup(
user_group_id: int,
) -> Select:
"""This returns a statement that should be executed using
.yield_per() to minimize overhead. The primary consumers of this function
are background processing task generators."""
stmt = (
select(Document)
select(Document.id)
.join(
DocumentByConnectorCredentialPair,
Document.id == DocumentByConnectorCredentialPair.id,

View File

@ -179,11 +179,14 @@ def try_generate_document_cc_pair_cleanup_tasks(
if tasks_generated is None:
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
)
try:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
)
except Exception:
pass
except TaskDependencyError:
redis_connector.delete.set_fence(None)

View File

@ -105,6 +105,32 @@ def construct_document_select_for_connector_credential_pair_by_needs_sync(
return stmt
def construct_document_id_select_for_connector_credential_pair_by_needs_sync(
connector_id: int, credential_id: int
) -> Select:
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
stmt = (
select(DbDocument.id)
.where(
DbDocument.id.in_(initial_doc_ids_stmt),
or_(
DbDocument.last_modified
> DbDocument.last_synced, # last_modified is newer than last_synced
DbDocument.last_synced.is_(None), # never synced
),
)
.distinct()
)
return stmt
def get_all_documents_needing_vespa_sync_for_cc_pair(
db_session: Session, cc_pair_id: int
) -> list[DbDocument]:

View File

@ -545,7 +545,7 @@ def fetch_documents_for_document_set_paginated(
return documents, documents[-1].id if documents else None
def construct_document_select_by_docset(
def construct_document_id_select_by_docset(
document_set_id: int,
current_only: bool = True,
) -> Select:
@ -554,7 +554,7 @@ def construct_document_select_by_docset(
are background processing task generators."""
stmt = (
select(Document)
select(Document.id)
.join(
DocumentByConnectorCredentialPair,
DocumentByConnectorCredentialPair.id == Document.id,

View File

@ -16,9 +16,8 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
construct_document_id_select_for_connector_credential_pair_by_needs_sync,
)
from onyx.db.models import Document
from onyx.redis.redis_object_helper import RedisObjectHelper
@ -72,7 +71,8 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
last_lock_time = time.monotonic()
async_results = []
num_tasks_sent = 0
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=int(self._id),
@ -80,14 +80,14 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
stmt = construct_document_id_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
num_docs = 0
for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
doc = cast(Document, doc)
for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
doc_id = cast(str, doc_id)
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
@ -98,7 +98,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
num_docs += 1
# check if we should skip the document (typically because it's already syncing)
if doc.id in self.skip_docs:
if doc_id in self.skip_docs:
continue
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
@ -114,21 +114,21 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
celery_app.send_task(
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
kwargs=dict(document_id=doc_id, tenant_id=tenant_id),
queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
)
async_results.append(result)
self.skip_docs.add(doc.id)
num_tasks_sent += 1
self.skip_docs.add(doc_id)
if len(async_results) >= max_tasks:
if num_tasks_sent >= max_tasks:
break
return len(async_results), num_docs
return num_tasks_sent, num_docs
class RedisGlobalConnectorCredentialPair:

View File

@ -14,8 +14,7 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.document_set import construct_document_select_by_docset
from onyx.db.models import Document
from onyx.db.document_set import construct_document_id_select_by_docset
from onyx.redis.redis_object_helper import RedisObjectHelper
@ -66,10 +65,11 @@ class RedisDocumentSet(RedisObjectHelper):
"""
last_lock_time = time.monotonic()
async_results = []
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
doc = cast(Document, doc)
num_tasks_sent = 0
stmt = construct_document_id_select_by_docset(int(self._id), current_only=False)
for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
doc_id = cast(str, doc_id)
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
@ -86,17 +86,17 @@ class RedisDocumentSet(RedisObjectHelper):
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
celery_app.send_task(
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
kwargs=dict(document_id=doc_id, tenant_id=tenant_id),
queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.LOW,
)
async_results.append(result)
num_tasks_sent += 1
return len(async_results), len(async_results)
return num_tasks_sent, num_tasks_sent
def reset(self) -> None:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)

View File

@ -14,7 +14,6 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.models import Document
from onyx.redis.redis_object_helper import RedisObjectHelper
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version
@ -66,23 +65,22 @@ class RedisUserGroup(RedisObjectHelper):
user group up to date over multiple batches.
"""
last_lock_time = time.monotonic()
async_results = []
num_tasks_sent = 0
if not global_version.is_ee_version():
return 0, 0
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
construct_document_id_select_by_usergroup = fetch_versioned_implementation(
"onyx.db.user_group",
"construct_document_select_by_usergroup",
"construct_document_id_select_by_usergroup",
)
except ModuleNotFoundError:
return 0, 0
stmt = construct_document_select_by_usergroup(int(self._id))
for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
doc = cast(Document, doc)
stmt = construct_document_id_select_by_usergroup(int(self._id))
for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
doc_id = cast(str, doc_id)
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
@ -99,17 +97,17 @@ class RedisUserGroup(RedisObjectHelper):
# add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task(
celery_app.send_task(
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
kwargs=dict(document_id=doc_id, tenant_id=tenant_id),
queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.LOW,
)
async_results.append(result)
num_tasks_sent += 1
return len(async_results), len(async_results)
return num_tasks_sent, num_tasks_sent
def reset(self) -> None:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)