diff --git a/backend/onyx/background/celery/apps/app_base.py b/backend/onyx/background/celery/apps/app_base.py
index 22529a66c2..40a98f38ab 100644
--- a/backend/onyx/background/celery/apps/app_base.py
+++ b/backend/onyx/background/celery/apps/app_base.py
@@ -161,9 +161,34 @@ def on_task_postrun(
     return
 
 
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
     """The first signal sent on celery worker startup"""
-    multiprocessing.set_start_method("spawn")  # fork is unsafe, set to spawn
+
+    # NOTE(rkuo): start method "fork" is unsafe and we really need it to be "spawn",
+    # but something is blocking set_start_method from working in the cloud unless
+    # force=True, so we use force=True as a fallback.
+
+    all_start_methods: list[str] = multiprocessing.get_all_start_methods()
+    logger.info(f"Multiprocessing all start methods: {all_start_methods}")
+
+    try:
+        multiprocessing.set_start_method("spawn")  # fork is unsafe, set to spawn
+    except Exception:
+        logger.info(
+            "Multiprocessing set_start_method exceptioned. Trying force=True..."
+        )
+        try:
+            multiprocessing.set_start_method(
+                "spawn", force=True
+            )  # fork is unsafe, set to spawn
+        except Exception:
+            logger.info(
+                "Multiprocessing set_start_method exceptioned even with force=True."
+            )
+
+    logger.info(
+        f"Multiprocessing selected start method: {multiprocessing.get_start_method()}"
+    )
 
 
 def wait_for_redis(sender: Any, **kwargs: Any) -> None:
diff --git a/backend/onyx/background/celery/apps/heavy.py b/backend/onyx/background/celery/apps/heavy.py
index f45e6df9aa..4854940fd9 100644
--- a/backend/onyx/background/celery/apps/heavy.py
+++ b/backend/onyx/background/celery/apps/heavy.py
@@ -1,9 +1,9 @@
-import multiprocessing
 from typing import Any
 
 from celery import Celery
 from celery import signals
 from celery import Task
+from celery.apps.worker import Worker
 from celery.signals import celeryd_init
 from celery.signals import worker_init
 from celery.signals import worker_ready
@@ -49,17 +49,16 @@ def on_task_postrun(
 
 
 @celeryd_init.connect
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
     app_base.on_celeryd_init(sender, conf, **kwargs)
 
 
 @worker_init.connect
-def on_worker_init(sender: Any, **kwargs: Any) -> None:
+def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     logger.info("worker_init signal received.")
-    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
 
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
-    SqlEngine.init_engine(pool_size=4, max_overflow=12)
+    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)  # type: ignore
 
     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
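For reference, the failure mode behind the try/except fallback in app_base.py: multiprocessing.set_start_method raises RuntimeError once a start method has already been fixed for the process, and force=True overrides it. A minimal standalone sketch of that behavior (force_spawn is a hypothetical helper, not part of this patch):

    import multiprocessing

    def force_spawn() -> None:
        # set_start_method raises RuntimeError if a start method has
        # already been fixed for this process; force=True overrides it
        try:
            multiprocessing.set_start_method("spawn")
        except RuntimeError:
            multiprocessing.set_start_method("spawn", force=True)

    if __name__ == "__main__":
        force_spawn()
        assert multiprocessing.get_start_method() == "spawn"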
diff --git a/backend/onyx/background/celery/apps/indexing.py b/backend/onyx/background/celery/apps/indexing.py
index 9262b632dc..e222da5e3b 100644
--- a/backend/onyx/background/celery/apps/indexing.py
+++ b/backend/onyx/background/celery/apps/indexing.py
@@ -1,9 +1,9 @@
-import multiprocessing
 from typing import Any
 
 from celery import Celery
 from celery import signals
 from celery import Task
+from celery.apps.worker import Worker
 from celery.signals import celeryd_init
 from celery.signals import worker_init
 from celery.signals import worker_process_init
@@ -50,22 +50,21 @@ def on_task_postrun(
 
 
 @celeryd_init.connect
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
     app_base.on_celeryd_init(sender, conf, **kwargs)
 
 
 @worker_init.connect
-def on_worker_init(sender: Any, **kwargs: Any) -> None:
+def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     logger.info("worker_init signal received.")
-    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
 
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
-    # rkuo: been seeing transient connection exceptions here, so upping the connection count
-    # from just concurrency/concurrency to concurrency/concurrency*2
-    SqlEngine.init_engine(
-        pool_size=sender.concurrency, max_overflow=sender.concurrency * 2
-    )
+    # rkuo: Transient errors ("SSL connection has been closed unexpectedly")
+    # keep happening in the indexing watchdog threads. Actually setting the
+    # spawn method in the cloud fixes 95% of these. Setting pre ping might
+    # help even more, but we're not worrying about that yet.
+    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)  # type: ignore
 
     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
diff --git a/backend/onyx/background/celery/apps/light.py b/backend/onyx/background/celery/apps/light.py
index e6567b1477..abc2cfab12 100644
--- a/backend/onyx/background/celery/apps/light.py
+++ b/backend/onyx/background/celery/apps/light.py
@@ -1,9 +1,9 @@
-import multiprocessing
 from typing import Any
 
 from celery import Celery
 from celery import signals
 from celery import Task
+from celery.apps.worker import Worker
 from celery.signals import celeryd_init
 from celery.signals import worker_init
 from celery.signals import worker_ready
@@ -15,7 +15,6 @@ from onyx.db.engine import SqlEngine
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import MULTI_TENANT
 
-
 logger = setup_logger()
 
 celery_app = Celery(__name__)
@@ -49,17 +48,18 @@ def on_task_postrun(
 
 
 @celeryd_init.connect
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
     app_base.on_celeryd_init(sender, conf, **kwargs)
 
 
 @worker_init.connect
-def on_worker_init(sender: Any, **kwargs: Any) -> None:
+def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     logger.info("worker_init signal received.")
-    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
+
+    logger.info(f"Concurrency: {sender.concurrency}")  # type: ignore
 
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
-    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
+    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)  # type: ignore
 
     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
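The pool sizing change (pool_size=sender.concurrency, max_overflow=8) gives every worker slot a persistent connection plus a fixed burst allowance, instead of scaling overflow with concurrency. Assuming SqlEngine.init_engine forwards these arguments to SQLAlchemy's create_engine (the DSN below is a placeholder), the equivalent configuration looks like:

    from sqlalchemy import create_engine

    engine = create_engine(
        "postgresql+psycopg2://user:pass@db:5432/onyx",  # placeholder DSN
        pool_size=16,        # persistent pool, e.g. the worker's concurrency
        max_overflow=8,      # short-lived extra connections under burst load
        pool_pre_ping=True,  # the "pre ping" the comment mentions: revalidate
                             # each pooled connection on checkout
    )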
diff --git a/backend/onyx/background/celery/apps/primary.py b/backend/onyx/background/celery/apps/primary.py
index caa697f883..8056e3d5e1 100644
--- a/backend/onyx/background/celery/apps/primary.py
+++ b/backend/onyx/background/celery/apps/primary.py
@@ -1,5 +1,4 @@
 import logging
-import multiprocessing
 from typing import Any
 from typing import cast
 
@@ -7,6 +6,7 @@ from celery import bootsteps  # type: ignore
 from celery import Celery
 from celery import signals
 from celery import Task
+from celery.apps.worker import Worker
 from celery.exceptions import WorkerShutdown
 from celery.signals import celeryd_init
 from celery.signals import worker_init
@@ -73,14 +73,13 @@ def on_task_postrun(
 
 
 @celeryd_init.connect
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
     app_base.on_celeryd_init(sender, conf, **kwargs)
 
 
 @worker_init.connect
-def on_worker_init(sender: Any, **kwargs: Any) -> None:
+def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     logger.info("worker_init signal received.")
-    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
 
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
     SqlEngine.init_engine(pool_size=8, max_overflow=0)
@@ -135,7 +134,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
         raise WorkerShutdown("Primary worker lock could not be acquired!")
 
     # tacking on our own user data to the sender
-    sender.primary_worker_lock = lock
+    sender.primary_worker_lock = lock  # type: ignore
 
     # As currently designed, when this worker starts as "primary", we reinitialize redis
     # to a clean state (for our purposes, anyway)
diff --git a/backend/onyx/background/celery/tasks/indexing/tasks.py b/backend/onyx/background/celery/tasks/indexing/tasks.py
index 9fd73972d0..771ee8e709 100644
--- a/backend/onyx/background/celery/tasks/indexing/tasks.py
+++ b/backend/onyx/background/celery/tasks/indexing/tasks.py
@@ -1,3 +1,4 @@
+import multiprocessing
 import os
 import sys
 import time
@@ -853,11 +854,14 @@ def connector_indexing_proxy_task(
     search_settings_id: int,
     tenant_id: str | None,
 ) -> None:
-    """celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
+    """celery tasks are forked, but forking is unstable.
+    This is a thread that proxies work to a spawned task."""
+
     task_logger.info(
         f"Indexing watchdog - starting: attempt={index_attempt_id} "
         f"cc_pair={cc_pair_id} "
-        f"search_settings={search_settings_id}"
+        f"search_settings={search_settings_id} "
+        f"mp_start_method={multiprocessing.get_start_method()}"
     )
 
     if not self.request.id:
diff --git a/backend/onyx/background/indexing/job_client.py b/backend/onyx/background/indexing/job_client.py
index 444894f8d6..a679eebe7f 100644
--- a/backend/onyx/background/indexing/job_client.py
+++ b/backend/onyx/background/indexing/job_client.py
@@ -4,9 +4,10 @@ not follow the expected behavior, etc.
 
 NOTE: cannot use Celery directly due to
 https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
+import multiprocessing as mp
 from collections.abc import Callable
 from dataclasses import dataclass
-from multiprocessing import Process
+from multiprocessing.context import SpawnProcess
 from typing import Any
 from typing import Literal
 from typing import Optional
@@ -63,7 +64,7 @@ class SimpleJob:
     """Drop in replacement for `dask.distributed.Future`"""
 
     id: int
-    process: Optional["Process"] = None
+    process: Optional["SpawnProcess"] = None
 
     def cancel(self) -> bool:
         return self.release()
@@ -131,7 +132,10 @@ class SimpleJobClient:
         job_id = self.job_id_counter
         self.job_id_counter += 1
 
-        process = Process(target=_run_in_process, args=(func, args), daemon=True)
+        # this approach allows us to always "spawn" a new process regardless of
+        # get_start_method's current setting
+        ctx = mp.get_context("spawn")
+        process = ctx.Process(target=_run_in_process, args=(func, args), daemon=True)
         job = SimpleJob(id=job_id, process=process)
         process.start()
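The job_client.py change is the piece that makes spawning independent of the global start method: multiprocessing.get_context("spawn") returns a context whose Process class is multiprocessing.context.SpawnProcess (hence the updated annotation on SimpleJob), and it neither reads nor mutates the process-wide setting from set_start_method. A minimal sketch of the pattern (_square is an illustrative stand-in for _run_in_process):

    import multiprocessing as mp

    def _square(n: int) -> None:
        # runs in a freshly spawned interpreter, so it must be importable
        # from the top level of a module
        print(f"child: {n * n}")

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")  # per-call-site start method
        p = ctx.Process(target=_square, args=(7,), daemon=True)
        p.start()
        p.join()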