diff --git a/backend/onyx/background/celery/apps/app_base.py b/backend/onyx/background/celery/apps/app_base.py
index 3f0d50950..9b320aae4 100644
--- a/backend/onyx/background/celery/apps/app_base.py
+++ b/backend/onyx/background/celery/apps/app_base.py
@@ -161,7 +161,7 @@ def on_task_postrun(
     return
 
 
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
     """The first signal sent on celery worker startup"""
 
     # NOTE(rkuo): start method "fork" is unsafe and we really need it to be "spawn"
diff --git a/backend/onyx/background/celery/apps/heavy.py b/backend/onyx/background/celery/apps/heavy.py
index 7216e858d..4854940fd 100644
--- a/backend/onyx/background/celery/apps/heavy.py
+++ b/backend/onyx/background/celery/apps/heavy.py
@@ -3,6 +3,7 @@ from typing import Any
 from celery import Celery
 from celery import signals
 from celery import Task
+from celery.apps.worker import Worker
 from celery.signals import celeryd_init
 from celery.signals import worker_init
 from celery.signals import worker_ready
@@ -48,16 +49,16 @@ def on_task_postrun(
 
 
 @celeryd_init.connect
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
     app_base.on_celeryd_init(sender, conf, **kwargs)
 
 
 @worker_init.connect
-def on_worker_init(sender: Any, **kwargs: Any) -> None:
+def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     logger.info("worker_init signal received.")
 
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
-    SqlEngine.init_engine(pool_size=4, max_overflow=12)
+    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)  # type: ignore
 
     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
diff --git a/backend/onyx/background/celery/apps/indexing.py b/backend/onyx/background/celery/apps/indexing.py
index 0c116984f..89681ea74 100644
--- a/backend/onyx/background/celery/apps/indexing.py
+++ b/backend/onyx/background/celery/apps/indexing.py
@@ -3,6 +3,7 @@ from typing import Any
 from celery import Celery
 from celery import signals
 from celery import Task
+from celery.apps.worker import Worker
 from celery.signals import celeryd_init
 from celery.signals import worker_init
 from celery.signals import worker_process_init
@@ -49,19 +50,19 @@ def on_task_postrun(
 
 
 @celeryd_init.connect
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
     app_base.on_celeryd_init(sender, conf, **kwargs)
 
 
 @worker_init.connect
-def on_worker_init(sender: Any, **kwargs: Any) -> None:
+def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     logger.info("worker_init signal received.")
 
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
 
     # rkuo: been seeing transient connection exceptions here, so upping the connection count
     # from just concurrency/concurrency to concurrency/concurrency*2
-    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
+    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)  # type: ignore
 
     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
diff --git a/backend/onyx/background/celery/apps/light.py b/backend/onyx/background/celery/apps/light.py
index 695bda69c..abc2cfab1 100644
--- a/backend/onyx/background/celery/apps/light.py
+++ b/backend/onyx/background/celery/apps/light.py
@@ -3,6 +3,7 @@ from typing import Any
 from celery import Celery
 from celery import signals
 from celery import Task
+from celery.apps.worker import Worker
 from celery.signals import celeryd_init
 from celery.signals import worker_init
 from celery.signals import worker_ready
@@ -14,7 +15,6 @@ from onyx.db.engine import SqlEngine
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import MULTI_TENANT
 
-
 logger = setup_logger()
 
 celery_app = Celery(__name__)
@@ -48,18 +48,18 @@ def on_task_postrun(
 
 
 @celeryd_init.connect
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
     app_base.on_celeryd_init(sender, conf, **kwargs)
 
 
 @worker_init.connect
-def on_worker_init(sender: Any, **kwargs: Any) -> None:
+def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     logger.info("worker_init signal received.")
-    logger.info(f"Concurrency: {sender.concurrency}")
+    logger.info(f"Concurrency: {sender.concurrency}")  # type: ignore
 
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
 
-    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
+    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)  # type: ignore
 
     app_base.wait_for_redis(sender, **kwargs)
     app_base.wait_for_db(sender, **kwargs)
diff --git a/backend/onyx/background/celery/apps/primary.py b/backend/onyx/background/celery/apps/primary.py
index b4f9868ac..8056e3d5e 100644
--- a/backend/onyx/background/celery/apps/primary.py
+++ b/backend/onyx/background/celery/apps/primary.py
@@ -6,6 +6,7 @@ from celery import bootsteps  # type: ignore
 from celery import Celery
 from celery import signals
 from celery import Task
+from celery.apps.worker import Worker
 from celery.exceptions import WorkerShutdown
 from celery.signals import celeryd_init
 from celery.signals import worker_init
@@ -72,12 +73,12 @@ def on_task_postrun(
 
 
 @celeryd_init.connect
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
     app_base.on_celeryd_init(sender, conf, **kwargs)
 
 
 @worker_init.connect
-def on_worker_init(sender: Any, **kwargs: Any) -> None:
+def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     logger.info("worker_init signal received.")
 
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
@@ -133,7 +134,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
         raise WorkerShutdown("Primary worker lock could not be acquired!")
 
     # tacking on our own user data to the sender
-    sender.primary_worker_lock = lock
+    sender.primary_worker_lock = lock  # type: ignore
 
     # As currently designed, when this worker starts as "primary", we reinitialize redis
     # to a clean state (for our purposes, anyway)