select only doc_id (#3920)

* select only doc_id

* select more doc ids

* fix user group

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
This commit is contained in:
rkuo-danswer
2025-02-05 23:00:40 -08:00
committed by GitHub
parent a0a1b431be
commit 6ccb3f085a
7 changed files with 71 additions and 44 deletions

View File

@ -218,14 +218,14 @@ def fetch_user_groups_for_user(
return db_session.scalars(stmt).all() return db_session.scalars(stmt).all()
def construct_document_select_by_usergroup( def construct_document_id_select_by_usergroup(
user_group_id: int, user_group_id: int,
) -> Select: ) -> Select:
"""This returns a statement that should be executed using """This returns a statement that should be executed using
.yield_per() to minimize overhead. The primary consumers of this function .yield_per() to minimize overhead. The primary consumers of this function
are background processing task generators.""" are background processing task generators."""
stmt = ( stmt = (
select(Document) select(Document.id)
.join( .join(
DocumentByConnectorCredentialPair, DocumentByConnectorCredentialPair,
Document.id == DocumentByConnectorCredentialPair.id, Document.id == DocumentByConnectorCredentialPair.id,

View File

@ -179,11 +179,14 @@ def try_generate_document_cc_pair_cleanup_tasks(
if tasks_generated is None: if tasks_generated is None:
raise ValueError("RedisConnectorDeletion.generate_tasks returned None") raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
try:
insert_sync_record( insert_sync_record(
db_session=db_session, db_session=db_session,
entity_id=cc_pair_id, entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION, sync_type=SyncType.CONNECTOR_DELETION,
) )
except Exception:
pass
except TaskDependencyError: except TaskDependencyError:
redis_connector.delete.set_fence(None) redis_connector.delete.set_fence(None)

View File

@ -105,6 +105,32 @@ def construct_document_select_for_connector_credential_pair_by_needs_sync(
return stmt return stmt
def construct_document_id_select_for_connector_credential_pair_by_needs_sync(
    connector_id: int, credential_id: int
) -> Select:
    """Build a SELECT of the distinct ids of documents that need syncing
    for one connector / credential pair.

    A document is included when its id appears among the
    DocumentByConnectorCredentialPair rows matching both `connector_id` and
    `credential_id`, and it has either never been synced (`last_synced` is
    NULL) or was modified after its last sync. Only the id column is selected,
    not the full document row.
    """
    # Subquery: ids of every document attached to this connector/credential pair.
    cc_pair_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
        and_(
            DocumentByConnectorCredentialPair.connector_id == connector_id,
            DocumentByConnectorCredentialPair.credential_id == credential_id,
        )
    )

    belongs_to_cc_pair = DbDocument.id.in_(cc_pair_doc_ids_stmt)
    needs_sync = or_(
        # last_modified is newer than last_synced
        DbDocument.last_modified > DbDocument.last_synced,
        # never synced
        DbDocument.last_synced.is_(None),
    )

    # Multiple criteria passed to where() are joined with AND.
    return select(DbDocument.id).where(belongs_to_cc_pair, needs_sync).distinct()
def get_all_documents_needing_vespa_sync_for_cc_pair( def get_all_documents_needing_vespa_sync_for_cc_pair(
db_session: Session, cc_pair_id: int db_session: Session, cc_pair_id: int
) -> list[DbDocument]: ) -> list[DbDocument]:

View File

@ -545,7 +545,7 @@ def fetch_documents_for_document_set_paginated(
return documents, documents[-1].id if documents else None return documents, documents[-1].id if documents else None
def construct_document_select_by_docset( def construct_document_id_select_by_docset(
document_set_id: int, document_set_id: int,
current_only: bool = True, current_only: bool = True,
) -> Select: ) -> Select:
@ -554,7 +554,7 @@ def construct_document_select_by_docset(
are background processing task generators.""" are background processing task generators."""
stmt = ( stmt = (
select(Document) select(Document.id)
.join( .join(
DocumentByConnectorCredentialPair, DocumentByConnectorCredentialPair,
DocumentByConnectorCredentialPair.id == Document.id, DocumentByConnectorCredentialPair.id == Document.id,

View File

@ -16,9 +16,8 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants from onyx.configs.constants import OnyxRedisConstants
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import ( from onyx.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync, construct_document_id_select_for_connector_credential_pair_by_needs_sync,
) )
from onyx.db.models import Document
from onyx.redis.redis_object_helper import RedisObjectHelper from onyx.redis.redis_object_helper import RedisObjectHelper
@ -72,7 +71,8 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
last_lock_time = time.monotonic() last_lock_time = time.monotonic()
async_results = [] num_tasks_sent = 0
cc_pair = get_connector_credential_pair_from_id( cc_pair = get_connector_credential_pair_from_id(
db_session=db_session, db_session=db_session,
cc_pair_id=int(self._id), cc_pair_id=int(self._id),
@ -80,14 +80,14 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
if not cc_pair: if not cc_pair:
return None return None
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync( stmt = construct_document_id_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id cc_pair.connector_id, cc_pair.credential_id
) )
num_docs = 0 num_docs = 0
for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
doc = cast(Document, doc) doc_id = cast(str, doc_id)
current_time = time.monotonic() current_time = time.monotonic()
if current_time - last_lock_time >= ( if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
@ -98,7 +98,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
num_docs += 1 num_docs += 1
# check if we should skip the document (typically because it's already syncing) # check if we should skip the document (typically because it's already syncing)
if doc.id in self.skip_docs: if doc_id in self.skip_docs:
continue continue
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac" # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
@ -114,21 +114,21 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
) )
# Priority on sync's triggered by new indexing should be medium # Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task( celery_app.send_task(
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK, OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
kwargs=dict(document_id=doc.id, tenant_id=tenant_id), kwargs=dict(document_id=doc_id, tenant_id=tenant_id),
queue=OnyxCeleryQueues.VESPA_METADATA_SYNC, queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id, task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM, priority=OnyxCeleryPriority.MEDIUM,
) )
async_results.append(result) num_tasks_sent += 1
self.skip_docs.add(doc.id) self.skip_docs.add(doc_id)
if len(async_results) >= max_tasks: if num_tasks_sent >= max_tasks:
break break
return len(async_results), num_docs return num_tasks_sent, num_docs
class RedisGlobalConnectorCredentialPair: class RedisGlobalConnectorCredentialPair:

View File

@ -14,8 +14,7 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants from onyx.configs.constants import OnyxRedisConstants
from onyx.db.document_set import construct_document_select_by_docset from onyx.db.document_set import construct_document_id_select_by_docset
from onyx.db.models import Document
from onyx.redis.redis_object_helper import RedisObjectHelper from onyx.redis.redis_object_helper import RedisObjectHelper
@ -66,10 +65,11 @@ class RedisDocumentSet(RedisObjectHelper):
""" """
last_lock_time = time.monotonic() last_lock_time = time.monotonic()
async_results = [] num_tasks_sent = 0
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): stmt = construct_document_id_select_by_docset(int(self._id), current_only=False)
doc = cast(Document, doc) for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
doc_id = cast(str, doc_id)
current_time = time.monotonic() current_time = time.monotonic()
if current_time - last_lock_time >= ( if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
@ -86,17 +86,17 @@ class RedisDocumentSet(RedisObjectHelper):
# add to the set BEFORE creating the task. # add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id) redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task( celery_app.send_task(
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK, OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
kwargs=dict(document_id=doc.id, tenant_id=tenant_id), kwargs=dict(document_id=doc_id, tenant_id=tenant_id),
queue=OnyxCeleryQueues.VESPA_METADATA_SYNC, queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id, task_id=custom_task_id,
priority=OnyxCeleryPriority.LOW, priority=OnyxCeleryPriority.LOW,
) )
async_results.append(result) num_tasks_sent += 1
return len(async_results), len(async_results) return num_tasks_sent, num_tasks_sent
def reset(self) -> None: def reset(self) -> None:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)

View File

@ -14,7 +14,6 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants from onyx.configs.constants import OnyxRedisConstants
from onyx.db.models import Document
from onyx.redis.redis_object_helper import RedisObjectHelper from onyx.redis.redis_object_helper import RedisObjectHelper
from onyx.utils.variable_functionality import fetch_versioned_implementation from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version from onyx.utils.variable_functionality import global_version
@ -66,23 +65,22 @@ class RedisUserGroup(RedisObjectHelper):
user group up to date over multiple batches. user group up to date over multiple batches.
""" """
last_lock_time = time.monotonic() last_lock_time = time.monotonic()
num_tasks_sent = 0
async_results = []
if not global_version.is_ee_version(): if not global_version.is_ee_version():
return 0, 0 return 0, 0
try: try:
construct_document_select_by_usergroup = fetch_versioned_implementation( construct_document_id_select_by_usergroup = fetch_versioned_implementation(
"onyx.db.user_group", "onyx.db.user_group",
"construct_document_select_by_usergroup", "construct_document_id_select_by_usergroup",
) )
except ModuleNotFoundError: except ModuleNotFoundError:
return 0, 0 return 0, 0
stmt = construct_document_select_by_usergroup(int(self._id)) stmt = construct_document_id_select_by_usergroup(int(self._id))
for doc in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
doc = cast(Document, doc) doc_id = cast(str, doc_id)
current_time = time.monotonic() current_time = time.monotonic()
if current_time - last_lock_time >= ( if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
@ -99,17 +97,17 @@ class RedisUserGroup(RedisObjectHelper):
# add to the set BEFORE creating the task. # add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id) redis_client.sadd(self.taskset_key, custom_task_id)
result = celery_app.send_task( celery_app.send_task(
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK, OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
kwargs=dict(document_id=doc.id, tenant_id=tenant_id), kwargs=dict(document_id=doc_id, tenant_id=tenant_id),
queue=OnyxCeleryQueues.VESPA_METADATA_SYNC, queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id, task_id=custom_task_id,
priority=OnyxCeleryPriority.LOW, priority=OnyxCeleryPriority.LOW,
) )
async_results.append(result) num_tasks_sent += 1
return len(async_results), len(async_results) return num_tasks_sent, num_tasks_sent
def reset(self) -> None: def reset(self) -> None:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)