Mirror of https://github.com/danswer-ai/danswer.git, synced 2025-07-09 06:02:00 +02:00
Merge pull request #3648 from onyx-dot-app/bugfix/light_cpu
figuring out why multiprocessing set_start_method isn't working.
@@ -161,9 +161,34 @@ def on_task_postrun(
     return
 
 
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
     """The first signal sent on celery worker startup"""
-    multiprocessing.set_start_method("spawn")  # fork is unsafe, set to spawn
+
+    # NOTE(rkuo): start method "fork" is unsafe and we really need it to be "spawn".
+    # But something is blocking set_start_method from working in the cloud unless
+    # force=True, so we use force=True as a fallback.
+
+    all_start_methods: list[str] = multiprocessing.get_all_start_methods()
+    logger.info(f"Multiprocessing all start methods: {all_start_methods}")
+
+    try:
+        multiprocessing.set_start_method("spawn")  # fork is unsafe, set to spawn
+    except Exception:
+        logger.info(
+            "Multiprocessing set_start_method exceptioned. Trying force=True..."
+        )
+        try:
+            multiprocessing.set_start_method(
+                "spawn", force=True
+            )  # fork is unsafe, set to spawn
+        except Exception:
+            logger.info(
+                "Multiprocessing set_start_method exceptioned even with force=True."
+            )
+
+    logger.info(
+        f"Multiprocessing selected start method: {multiprocessing.get_start_method()}"
+    )
 
 
 def wait_for_redis(sender: Any, **kwargs: Any) -> None:
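For reference: multiprocessing.set_start_method raises a RuntimeError once a start method has already been fixed for the process (for example by an earlier import or by the runtime itself), and force=True is the documented way to override it, which is what the fallback above relies on. A minimal standalone sketch of that behavior, separate from the commit:

import multiprocessing

if __name__ == "__main__":
    multiprocessing.set_start_method("spawn")
    try:
        # a second call without force fails: "context has already been set"
        multiprocessing.set_start_method("spawn")
    except RuntimeError as e:
        print(f"set_start_method raised: {e}")
        # force=True replaces the already-set method instead of raising
        multiprocessing.set_start_method("spawn", force=True)
    print(multiprocessing.get_start_method())  # -> spawn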
@@ -1,9 +1,9 @@
-import multiprocessing
 from typing import Any
 
 from celery import Celery
 from celery import signals
 from celery import Task
+from celery.apps.worker import Worker
 from celery.signals import celeryd_init
 from celery.signals import worker_init
 from celery.signals import worker_ready
@@ -49,17 +49,16 @@ def on_task_postrun(
 
 
 @celeryd_init.connect
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
     app_base.on_celeryd_init(sender, conf, **kwargs)
 
 
 @worker_init.connect
-def on_worker_init(sender: Any, **kwargs: Any) -> None:
+def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     logger.info("worker_init signal received.")
-    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
 
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
-    SqlEngine.init_engine(pool_size=4, max_overflow=12)
+    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)  # type: ignore
 
     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
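The annotation changes track what Celery actually passes to these signals: celeryd_init sends the worker's nodename (a string), while worker_init sends the Worker instance itself, whose concurrency attribute now drives the pool sizing. A small illustrative sketch (the handler names are made up for the example):

from typing import Any

from celery import Celery
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init

celery_app = Celery(__name__)


@celeryd_init.connect
def log_nodename(sender: str, conf: Any = None, **kwargs: Any) -> None:
    # for celeryd_init, sender is the worker nodename, e.g. "celery@host"
    print(f"celeryd_init received from {sender}")


@worker_init.connect
def log_concurrency(sender: Worker, **kwargs: Any) -> None:
    # for worker_init, sender is the Worker instance; concurrency is the
    # number of execution slots, a natural upper bound for a DB pool size
    print(f"worker_init received, concurrency={sender.concurrency}")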
@@ -1,9 +1,9 @@
-import multiprocessing
 from typing import Any
 
 from celery import Celery
 from celery import signals
 from celery import Task
+from celery.apps.worker import Worker
 from celery.signals import celeryd_init
 from celery.signals import worker_init
 from celery.signals import worker_process_init
@@ -50,22 +50,21 @@ def on_task_postrun(
 
 
 @celeryd_init.connect
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
     app_base.on_celeryd_init(sender, conf, **kwargs)
 
 
 @worker_init.connect
-def on_worker_init(sender: Any, **kwargs: Any) -> None:
+def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     logger.info("worker_init signal received.")
-    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
 
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
 
-    # rkuo: been seeing transient connection exceptions here, so upping the connection count
-    # from just concurrency/concurrency to concurrency/concurrency*2
-    SqlEngine.init_engine(
-        pool_size=sender.concurrency, max_overflow=sender.concurrency * 2
-    )
+    # rkuo: Transient errors keep happening in the indexing watchdog threads,
+    # e.g. "SSL connection has been closed unexpectedly".
+    # Actually setting the spawn method in the cloud fixes 95% of these;
+    # setting pre ping might help even more, but not worrying about that yet.
+    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)  # type: ignore
 
     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
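The "pre ping" the comment refers to is SQLAlchemy's pool_pre_ping engine option: before a pooled connection is handed out, it is tested with a trivial round trip, and connections the server has silently dropped (the "SSL connection has been closed unexpectedly" case) are transparently replaced. A hedged sketch of enabling it, with a placeholder URL and pool numbers rather than the project's actual configuration:

from sqlalchemy import create_engine

# Illustrative only: the URL and pool sizes are placeholders.
engine = create_engine(
    "postgresql+psycopg2://user:pass@localhost:5432/app",
    pool_size=8,         # steady-state pooled connections
    max_overflow=8,      # extra connections allowed under burst load
    pool_pre_ping=True,  # test each connection before reuse, recycling
                         # ones the server has already closed
)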
@@ -1,9 +1,9 @@
-import multiprocessing
 from typing import Any
 
 from celery import Celery
 from celery import signals
 from celery import Task
+from celery.apps.worker import Worker
 from celery.signals import celeryd_init
 from celery.signals import worker_init
 from celery.signals import worker_ready
@@ -15,7 +15,6 @@ from onyx.db.engine import SqlEngine
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import MULTI_TENANT
 
-
 logger = setup_logger()
 
 celery_app = Celery(__name__)
@@ -49,17 +48,18 @@ def on_task_postrun(
 
 
 @celeryd_init.connect
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
     app_base.on_celeryd_init(sender, conf, **kwargs)
 
 
 @worker_init.connect
-def on_worker_init(sender: Any, **kwargs: Any) -> None:
+def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     logger.info("worker_init signal received.")
-    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
+
+    logger.info(f"Concurrency: {sender.concurrency}")  # type: ignore
 
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
-    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
+    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)  # type: ignore
 
     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
@@ -1,5 +1,4 @@
 import logging
-import multiprocessing
 from typing import Any
 from typing import cast
 
@@ -7,6 +6,7 @@ from celery import bootsteps  # type: ignore
 from celery import Celery
 from celery import signals
 from celery import Task
+from celery.apps.worker import Worker
 from celery.exceptions import WorkerShutdown
 from celery.signals import celeryd_init
 from celery.signals import worker_init
@@ -73,14 +73,13 @@ def on_task_postrun(
 
 
 @celeryd_init.connect
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
     app_base.on_celeryd_init(sender, conf, **kwargs)
 
 
 @worker_init.connect
-def on_worker_init(sender: Any, **kwargs: Any) -> None:
+def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     logger.info("worker_init signal received.")
-    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
 
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
     SqlEngine.init_engine(pool_size=8, max_overflow=0)
@@ -135,7 +134,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
         raise WorkerShutdown("Primary worker lock could not be acquired!")
 
     # tacking on our own user data to the sender
-    sender.primary_worker_lock = lock
+    sender.primary_worker_lock = lock  # type: ignore
 
     # As currently designed, when this worker starts as "primary", we reinitialize redis
     # to a clean state (for our purposes, anyway)
@@ -1,3 +1,4 @@
+import multiprocessing
 import os
 import sys
 import time
@@ -853,11 +854,14 @@ def connector_indexing_proxy_task(
     search_settings_id: int,
     tenant_id: str | None,
 ) -> None:
-    """celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
+    """celery tasks are forked, but forking is unstable.
+
+    This is a thread that proxies work to a spawned task."""
     task_logger.info(
         f"Indexing watchdog - starting: attempt={index_attempt_id} "
         f"cc_pair={cc_pair_id} "
-        f"search_settings={search_settings_id}"
+        f"search_settings={search_settings_id} "
+        f"mp_start_method={multiprocessing.get_start_method()}"
     )
 
     if not self.request.id:
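The reworded docstring makes the design explicit: the Celery task thread acts as a watchdog and proxies the real work to a freshly spawned process, sidestepping fork instability. A standalone sketch of that shape (the function names here are illustrative, not the project's):

import multiprocessing as mp
import time


def _heavy_work(attempt_id: int) -> None:
    # stand-in for the real indexing work done in the child process
    print(f"indexing attempt {attempt_id} running in a spawned process")


def watchdog(attempt_id: int) -> None:
    # spawn the child regardless of how the current process was created
    ctx = mp.get_context("spawn")
    proc = ctx.Process(target=_heavy_work, args=(attempt_id,), daemon=True)
    proc.start()
    while proc.is_alive():
        # the watchdog loop is where liveness checks / cancellation would go
        time.sleep(1)
    print(f"child exited with code {proc.exitcode}")


if __name__ == "__main__":
    watchdog(42)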
@@ -4,9 +4,10 @@ not follow the expected behavior, etc.
 
 NOTE: cannot use Celery directly due to
 https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
+import multiprocessing as mp
 from collections.abc import Callable
 from dataclasses import dataclass
-from multiprocessing import Process
+from multiprocessing.context import SpawnProcess
 from typing import Any
 from typing import Literal
 from typing import Optional
@@ -63,7 +64,7 @@ class SimpleJob:
     """Drop in replacement for `dask.distributed.Future`"""
 
     id: int
-    process: Optional["Process"] = None
+    process: Optional["SpawnProcess"] = None
 
     def cancel(self) -> bool:
         return self.release()
@@ -131,7 +132,10 @@ class SimpleJobClient:
         job_id = self.job_id_counter
         self.job_id_counter += 1
 
-        process = Process(target=_run_in_process, args=(func, args), daemon=True)
+        # this approach allows us to always "spawn" a new process regardless of
+        # get_start_method's current setting
+        ctx = mp.get_context("spawn")
+        process = ctx.Process(target=_run_in_process, args=(func, args), daemon=True)
         job = SimpleJob(id=job_id, process=process)
         process.start()
 
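mp.get_context("spawn") returns a context object whose Process class always spawns, and such processes are multiprocessing.context.SpawnProcess instances, which is why the SimpleJob.process annotation changes. A small sketch of the difference between the global start method and a per-context one:

import multiprocessing as mp


def child() -> None:
    print("hello from the child process")


if __name__ == "__main__":
    # the global default might still be "fork" on Linux...
    print(f"global start method: {mp.get_start_method()}")

    # ...but a context pins the start method for the processes it creates,
    # without touching (or depending on) the global setting
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=child)
    p.start()
    p.join()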