add some type hints

Richard Kuo (Danswer) 2025-01-13 11:31:57 -08:00
parent 9a09222b7d
commit b6dd999c1b
5 changed files with 18 additions and 15 deletions

View File

@@ -161,7 +161,7 @@ def on_task_postrun(
return
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
"""The first signal sent on celery worker startup"""
# NOTE(rkuo): start method "fork" is unsafe and we really need it to be "spawn"
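
Note on the change above: celeryd_init is the first signal Celery sends on worker startup, and its sender is the worker's nodename (a string) rather than a worker object, which is what the tightened sender: str hint reflects. A minimal sketch of a handler wired up the same way (the handler body is illustrative, not from this commit):

from typing import Any

from celery import Celery
from celery.signals import celeryd_init

celery_app = Celery(__name__)


@celeryd_init.connect
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
    # sender is the worker nodename, e.g. "celery@hostname"; conf is the app configuration
    print(f"worker {sender} is starting up")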

View File

@@ -3,6 +3,7 @@ from typing import Any
from celery import Celery
from celery import signals
from celery import Task
+from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
@@ -48,16 +49,16 @@ def on_task_postrun(
@celeryd_init.connect
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
-def on_worker_init(sender: Any, **kwargs: Any) -> None:
+def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
-SqlEngine.init_engine(pool_size=4, max_overflow=12)
+SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
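
A note on the worker_init change above: for this signal the sender is the celery.apps.worker.Worker instance, so the hint becomes Worker and the database pool can be sized from sender.concurrency instead of the old hardcoded pool_size=4, max_overflow=12. The trailing # type: ignore suggests mypy still cannot verify the concurrency attribute, presumably because Celery ships without complete type information. A rough sketch of the same pattern, using a plain SQLAlchemy engine as a stand-in for Onyx's own SqlEngine wrapper (the DSN is a placeholder, not from the commit, and assumes a Postgres driver is installed):

from typing import Any

from celery import Celery
from celery.apps.worker import Worker
from celery.signals import worker_init
from sqlalchemy import create_engine

celery_app = Celery(__name__)


@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
    # size the pool to one persistent connection per worker process,
    # with a small overflow cushion for bursts
    engine = create_engine(
        "postgresql://localhost/placeholder_db",  # placeholder DSN
        pool_size=sender.concurrency,  # type: ignore
        max_overflow=8,
    )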

View File

@@ -3,6 +3,7 @@ from typing import Any
from celery import Celery
from celery import signals
from celery import Task
+from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
@@ -49,19 +50,19 @@ def on_task_postrun(
@celeryd_init.connect
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
-def on_worker_init(sender: Any, **kwargs: Any) -> None:
+def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
# rkuo: been seeing transient connection exceptions here, so upping the connection count
# from just concurrency/concurrency to concurrency/concurrency*2
-SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
+SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
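
For context on the pool sizing above: with SQLAlchemy's default QueuePool, pool_size connections are kept open and up to max_overflow extra connections may be opened under load, so each worker process tops out at sender.concurrency + 8 database connections. A small worked example, with the concurrency value chosen for illustration only:

# assuming SQLAlchemy QueuePool semantics: pool_size held open, max_overflow extra on demand
concurrency = 4               # hypothetical worker concurrency, not from the commit
pool_size = concurrency       # one persistent connection per worker process
max_overflow = 8              # temporary connections allowed beyond the pool
max_connections = pool_size + max_overflow
print(max_connections)        # at most 12 connections for this worker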

View File

@@ -3,6 +3,7 @@ from typing import Any
from celery import Celery
from celery import signals
from celery import Task
+from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
@@ -14,7 +15,6 @@ from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
@@ -48,18 +48,18 @@ def on_task_postrun(
@celeryd_init.connect
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
-def on_worker_init(sender: Any, **kwargs: Any) -> None:
+def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Concurrency: {sender.concurrency}")
logger.info(f"Concurrency: {sender.concurrency}") # type: ignore
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
-SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
+SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -6,6 +6,7 @@ from celery import bootsteps # type: ignore
from celery import Celery
from celery import signals
from celery import Task
+from celery.apps.worker import Worker
from celery.exceptions import WorkerShutdown
from celery.signals import celeryd_init
from celery.signals import worker_init
@@ -72,12 +73,12 @@ def on_task_postrun(
@celeryd_init.connect
-def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
+def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
-def on_worker_init(sender: Any, **kwargs: Any) -> None:
+def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
@@ -133,7 +134,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
raise WorkerShutdown("Primary worker lock could not be acquired!")
# tacking on our own user data to the sender
-sender.primary_worker_lock = lock
+sender.primary_worker_lock = lock # type: ignore
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
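
A note on the primary_worker_lock change: the lock acquired at startup is tacked onto the Worker instance as ad-hoc user data, presumably so later lifecycle hooks can reach it through the sender. Once sender is hinted as Worker, mypy evidently objects to assigning an attribute the class does not declare, which is why the commit adds # type: ignore on that line. A rough sketch of the pattern with a stand-in lock object (the attribute name follows the commit; the lock itself is hypothetical here):

from typing import Any

from celery import Celery
from celery.apps.worker import Worker
from celery.signals import worker_init

celery_app = Celery(__name__)


@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
    lock = object()  # stand-in for the lock acquired during startup
    # the ignore mirrors the commit: mypy may flag assigning an attribute
    # that Worker's class definition does not declare
    sender.primary_worker_lock = lock  # type: ignore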