harden join function (#4424)

* harden join function

* remove log spam

* use time.monotonic

* add pid logging

* client only celery app

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
This commit is contained in:
rkuo-danswer
2025-04-02 01:04:00 -07:00
committed by GitHub
parent be20586ba1
commit 8a8526dbbb
12 changed files with 85 additions and 32 deletions

View File

@@ -1,5 +1,6 @@
import logging
import multiprocessing
import os
import time
from typing import Any
from typing import cast
@@ -305,7 +306,7 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None:
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as a secondary celery worker.")
logger.info(f"Running as a secondary celery worker: pid={os.getpid()}")
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5

View File

@@ -0,0 +1,7 @@
from celery import Celery
import onyx.background.celery.apps.app_base as app_base
# Celery app for client-side use; it is named after this module and loads its
# settings from the "client" config module referenced below.
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.client")
# Replace the app's default Task class with the project's TenantAwareTask.
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]

View File

@@ -1,4 +1,5 @@
import logging
import os
from typing import Any
from typing import cast
@@ -95,7 +96,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
logger.info("Running as the primary celery worker.")
logger.info(f"Running as the primary celery worker: pid={os.getpid()}")
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -0,0 +1,16 @@
import onyx.background.celery.configs.base as shared_config
# Client celery settings. Every value below is copied unchanged from the shared
# base config; only the settings listed here are defined by this module.

# Broker connection settings
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
# Redis connection settings
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
# Result backend settings
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
# Task defaults
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

View File

@@ -0,0 +1,20 @@
"""Factory stub for a client-only celery app (not a celery worker / celery beat).
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this app.
This is an app stub purely for sending tasks as a client.
"""
from celery import Celery
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
# Runs at import time, before get_app() is called further down in this module.
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
    """Return the client celery app.

    There is no EE variant to look up here, so this always hands back the
    ``celery_app`` object defined in ``onyx.background.celery.apps.client``.

    Returns:
        Celery: the module-level ``celery_app`` of the client app module.
    """
    # NOTE(review): the import is deferred to call time, presumably so that
    # set_is_ee_based_on_env_variable() has already run - confirm.
    import onyx.background.celery.apps.client as client_module

    return client_module.celery_app
# Module-level app instance, built once when this module is imported.
app = get_app()

View File

@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from urllib.parse import parse_qs
from urllib.parse import quote
from urllib.parse import urljoin
from urllib.parse import urlparse
import requests
@@ -342,9 +343,14 @@ def build_confluence_document_id(
Returns:
str: The document id
"""
if is_cloud and not base_url.endswith("/wiki"):
base_url += "/wiki"
return f"{base_url}{content_url}"
# NOTE: urljoin is tricky and will drop the last segment of the base if it doesn't
# end with "/" because it believes that makes it a file.
final_url = base_url.rstrip("/") + "/"
if is_cloud and not final_url.endswith("/wiki/"):
final_url = urljoin(final_url, "wiki") + "/"
final_url = urljoin(final_url, content_url.lstrip("/"))
return final_url
def datetime_from_string(datetime_string: str) -> datetime:

View File

@@ -21,7 +21,7 @@ from onyx.background.celery.tasks.external_group_syncing.tasks import (
from onyx.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
)
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.indexing.models import IndexAttemptErrorPydantic
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
@@ -219,7 +219,7 @@ def update_cc_pair_status(
continue
# Revoke the task to prevent it from running
primary_app.control.revoke(index_payload.celery_task_id)
client_app.control.revoke(index_payload.celery_task_id)
# If it is running, then signaling for termination will get the
# watchdog thread to kill the spawned task
@@ -238,7 +238,7 @@ def update_cc_pair_status(
db_session.commit()
# this speeds up the start of indexing by firing the check immediately
primary_app.send_task(
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
kwargs=dict(tenant_id=tenant_id),
priority=OnyxCeleryPriority.HIGH,
@@ -376,7 +376,7 @@ def prune_cc_pair(
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_prune_generator_task(
primary_app, cc_pair, db_session, r, tenant_id
client_app, cc_pair, db_session, r, tenant_id
)
if not payload_id:
raise HTTPException(
@@ -450,7 +450,7 @@ def sync_cc_pair(
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_permissions_sync_task(
primary_app, cc_pair_id, r, tenant_id
client_app, cc_pair_id, r, tenant_id
)
if not payload_id:
raise HTTPException(
@@ -524,7 +524,7 @@ def sync_cc_pair_groups(
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_external_group_sync_task(
primary_app, cc_pair_id, r, tenant_id
client_app, cc_pair_id, r, tenant_id
)
if not payload_id:
raise HTTPException(
@@ -634,7 +634,7 @@ def associate_credential_to_connector(
)
# trigger indexing immediately
primary_app.send_task(
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
priority=OnyxCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},

View File

@@ -20,7 +20,7 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.app_configs import ENABLED_CONNECTOR_TYPES
from onyx.configs.app_configs import MOCK_CONNECTOR_FILE_PATH
from onyx.configs.constants import DocumentSource
@@ -928,7 +928,7 @@ def create_connector_with_mock_credential(
)
# trigger indexing immediately
primary_app.send_task(
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
priority=OnyxCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},
@@ -1314,7 +1314,7 @@ def trigger_indexing_for_cc_pair(
# run the beat task to pick up the triggers immediately
priority = OnyxCeleryPriority.HIGHEST if is_user_file else OnyxCeleryPriority.HIGH
logger.info(f"Sending indexing check task with priority {priority}")
primary_app.send_task(
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
priority=priority,
kwargs={"tenant_id": tenant_id},

View File

@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.document_set import check_document_sets_are_public
@@ -52,7 +52,7 @@ def create_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
primary_app.send_task(
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},
priority=OnyxCeleryPriority.HIGH,
@@ -85,7 +85,7 @@ def patch_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
primary_app.send_task(
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},
priority=OnyxCeleryPriority.HIGH,
@@ -108,7 +108,7 @@ def delete_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
primary_app.send_task(
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},
priority=OnyxCeleryPriority.HIGH,

View File

@@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import KV_GEN_AI_KEY_CHECK_TIME
@@ -192,7 +192,7 @@ def create_deletion_attempt_for_connector_id(
db_session.commit()
# run the beat task to pick up this deletion from the db immediately
primary_app.send_task(
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
priority=OnyxCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},

View File

@@ -165,17 +165,18 @@ class DocumentManager:
doc["fields"]["document_id"]: doc["fields"] for doc in retrieved_docs_dict
}
# NOTE(rkuo): too much log spam
# Left this here for debugging purposes.
import json
# import json
print("DEBUGGING DOCUMENTS")
print(retrieved_docs)
for doc in retrieved_docs.values():
printable_doc = doc.copy()
print(printable_doc.keys())
printable_doc.pop("embeddings")
printable_doc.pop("title_embedding")
print(json.dumps(printable_doc, indent=2))
# print("DEBUGGING DOCUMENTS")
# print(retrieved_docs)
# for doc in retrieved_docs.values():
# printable_doc = doc.copy()
# print(printable_doc.keys())
# printable_doc.pop("embeddings")
# printable_doc.pop("title_embedding")
# print(json.dumps(printable_doc, indent=2))
for document in cc_pair.documents:
retrieved_doc = retrieved_docs.get(document.id)

View File

@@ -1,3 +1,4 @@
import time
from datetime import datetime
from datetime import timedelta
from urllib.parse import urlencode
@@ -191,7 +192,7 @@ class IndexAttemptManager:
user_performing_action: DATestUser | None = None,
) -> None:
"""Wait for an IndexAttempt to complete"""
start = datetime.now()
start = time.monotonic()
while True:
index_attempt = IndexAttemptManager.get_index_attempt_by_id(
index_attempt_id=index_attempt_id,
@@ -203,7 +204,7 @@ class IndexAttemptManager:
print(f"IndexAttempt {index_attempt_id} completed")
return
elapsed = (datetime.now() - start).total_seconds()
elapsed = time.monotonic() - start
if elapsed > timeout:
raise TimeoutError(
f"IndexAttempt {index_attempt_id} did not complete within {timeout} seconds"