Merge pull request #3948 from onyx-dot-app/feature/beat_rtvar

Refactoring and updating the multiplier in real time
This commit is contained in:
rkuo-danswer
2025-02-11 14:05:14 -08:00
committed by GitHub
4 changed files with 168 additions and 92 deletions

View File

@ -3,42 +3,44 @@ from typing import Any
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.beat_schedule import ( from onyx.background.celery.tasks.beat_schedule import (
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule, beat_system_tasks as base_beat_system_tasks,
) )
from onyx.background.celery.tasks.beat_schedule import ( from onyx.background.celery.tasks.beat_schedule import (
tasks_to_schedule as base_tasks_to_schedule, beat_task_templates as base_beat_task_templates,
)
from onyx.background.celery.tasks.beat_schedule import generate_cloud_tasks
from onyx.background.celery.tasks.beat_schedule import (
get_tasks_to_schedule as base_get_tasks_to_schedule,
) )
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT from shared_configs.configs import MULTI_TENANT
ee_cloud_tasks_to_schedule = [ ee_beat_system_tasks: list[dict] = []
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_autogenerate-usage-report", ee_beat_task_templates: list[dict] = []
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR, ee_beat_task_templates.extend(
"schedule": timedelta(days=30), [
"options": { {
"priority": OnyxCeleryPriority.HIGHEST, "name": "autogenerate-usage-report",
"expires": BEAT_EXPIRES_DEFAULT, "task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
}, },
"kwargs": { {
"task_name": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK, "name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
}, },
}, ]
{ )
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-ttl-management",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
},
},
]
ee_tasks_to_schedule: list[dict] = [] ee_tasks_to_schedule: list[dict] = []
@ -65,9 +67,14 @@ if not MULTI_TENANT:
] ]
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]: def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
return ee_cloud_tasks_to_schedule + base_cloud_tasks_to_schedule beat_system_tasks = ee_beat_system_tasks + base_beat_system_tasks
beat_task_templates = ee_beat_task_templates + base_beat_task_templates
cloud_tasks = generate_cloud_tasks(
beat_system_tasks, beat_task_templates, beat_multiplier
)
return cloud_tasks
def get_tasks_to_schedule() -> list[dict[str, Any]]: def get_tasks_to_schedule() -> list[dict[str, Any]]:
return ee_tasks_to_schedule + base_tasks_to_schedule return ee_tasks_to_schedule + base_get_tasks_to_schedule()

View File

@ -1,41 +1,56 @@
from datetime import timedelta from datetime import timedelta
from typing import Any from typing import Any
from typing import cast
from celery import Celery from celery import Celery
from celery import signals from celery import signals
from celery.beat import PersistentScheduler # type: ignore from celery.beat import PersistentScheduler # type: ignore
from celery.signals import beat_init from celery.signals import beat_init
from celery.utils.log import get_task_logger
import onyx.background.celery.apps.app_base as app_base import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.variable_functionality import fetch_versioned_implementation from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__) task_logger = get_task_logger(__name__)
celery_app = Celery(__name__) celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.beat") celery_app.config_from_object("onyx.background.celery.configs.beat")
class DynamicTenantScheduler(PersistentScheduler): class DynamicTenantScheduler(PersistentScheduler):
"""This scheduler is useful because we can dynamically adjust task generation rates
through it."""
RELOAD_INTERVAL = 60
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
logger.info("Initializing DynamicTenantScheduler")
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._reload_interval = timedelta(minutes=2)
self.last_beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
self._reload_interval = timedelta(
seconds=DynamicTenantScheduler.RELOAD_INTERVAL
)
self._last_reload = self.app.now() - self._reload_interval self._last_reload = self.app.now() - self._reload_interval
# Let the parent class handle store initialization # Let the parent class handle store initialization
self.setup_schedule() self.setup_schedule()
self._try_updating_schedule() self._try_updating_schedule()
logger.info(f"Set reload interval to {self._reload_interval}") task_logger.info(
f"DynamicTenantScheduler initialized: reload_interval={self._reload_interval}"
)
def setup_schedule(self) -> None: def setup_schedule(self) -> None:
logger.info("Setting up initial schedule")
super().setup_schedule() super().setup_schedule()
logger.info("Initial schedule setup complete")
def tick(self) -> float: def tick(self) -> float:
retval = super().tick() retval = super().tick()
@ -44,36 +59,35 @@ class DynamicTenantScheduler(PersistentScheduler):
self._last_reload is None self._last_reload is None
or (now - self._last_reload) > self._reload_interval or (now - self._last_reload) > self._reload_interval
): ):
logger.info("Reload interval reached, initiating task update") task_logger.debug("Reload interval reached, initiating task update")
try: try:
self._try_updating_schedule() self._try_updating_schedule()
except (AttributeError, KeyError) as e: except (AttributeError, KeyError):
logger.exception(f"Failed to process task configuration: {str(e)}") task_logger.exception("Failed to process task configuration")
except Exception as e: except Exception:
logger.exception(f"Unexpected error updating tasks: {str(e)}") task_logger.exception("Unexpected error updating tasks")
self._last_reload = now self._last_reload = now
logger.info("Task update completed, reset reload timer")
return retval return retval
def _generate_schedule( def _generate_schedule(
self, tenant_ids: list[str] | list[None] self, tenant_ids: list[str] | list[None], beat_multiplier: float
) -> dict[str, dict[str, Any]]: ) -> dict[str, dict[str, Any]]:
"""Given a list of tenant id's, generates a new beat schedule for celery.""" """Given a list of tenant id's, generates a new beat schedule for celery."""
logger.info("Fetching tasks to schedule")
new_schedule: dict[str, dict[str, Any]] = {} new_schedule: dict[str, dict[str, Any]] = {}
if MULTI_TENANT: if MULTI_TENANT:
# cloud tasks only need the single task beat across all tenants # cloud tasks are system wide and thus only need to be on the beat schedule
# once for all tenants
get_cloud_tasks_to_schedule = fetch_versioned_implementation( get_cloud_tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule", "onyx.background.celery.tasks.beat_schedule",
"get_cloud_tasks_to_schedule", "get_cloud_tasks_to_schedule",
) )
cloud_tasks_to_schedule: list[ cloud_tasks_to_schedule: list[dict[str, Any]] = get_cloud_tasks_to_schedule(
dict[str, Any] beat_multiplier
] = get_cloud_tasks_to_schedule() )
for task in cloud_tasks_to_schedule: for task in cloud_tasks_to_schedule:
task_name = task["name"] task_name = task["name"]
cloud_task = { cloud_task = {
@ -82,11 +96,14 @@ class DynamicTenantScheduler(PersistentScheduler):
"kwargs": task.get("kwargs", {}), "kwargs": task.get("kwargs", {}),
} }
if options := task.get("options"): if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}") task_logger.debug(f"Adding options to task {task_name}: {options}")
cloud_task["options"] = options cloud_task["options"] = options
new_schedule[task_name] = cloud_task new_schedule[task_name] = cloud_task
# regular task beats are multiplied across all tenants # regular task beats are multiplied across all tenants
# note that currently this just schedules for a single tenant in self hosted
# and doesn't do anything in the cloud because it's much more scalable
# to schedule a single cloud beat task to dispatch per tenant tasks.
get_tasks_to_schedule = fetch_versioned_implementation( get_tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule" "onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
) )
@ -95,7 +112,7 @@ class DynamicTenantScheduler(PersistentScheduler):
for tenant_id in tenant_ids: for tenant_id in tenant_ids:
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST: if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
logger.info( task_logger.debug(
f"Skipping tenant {tenant_id} as it is in the ignored syncing list" f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
) )
continue continue
@ -104,14 +121,14 @@ class DynamicTenantScheduler(PersistentScheduler):
task_name = task["name"] task_name = task["name"]
tenant_task_name = f"{task['name']}-{tenant_id}" tenant_task_name = f"{task['name']}-{tenant_id}"
logger.debug(f"Creating task configuration for {tenant_task_name}") task_logger.debug(f"Creating task configuration for {tenant_task_name}")
tenant_task = { tenant_task = {
"task": task["task"], "task": task["task"],
"schedule": task["schedule"], "schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id}, "kwargs": {"tenant_id": tenant_id},
} }
if options := task.get("options"): if options := task.get("options"):
logger.debug( task_logger.debug(
f"Adding options to task {tenant_task_name}: {options}" f"Adding options to task {tenant_task_name}: {options}"
) )
tenant_task["options"] = options tenant_task["options"] = options
@ -121,44 +138,57 @@ class DynamicTenantScheduler(PersistentScheduler):
def _try_updating_schedule(self) -> None: def _try_updating_schedule(self) -> None:
"""Only updates the actual beat schedule on the celery app when it changes""" """Only updates the actual beat schedule on the celery app when it changes"""
do_update = False
logger.info("_try_updating_schedule starting") r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
task_logger.debug("_try_updating_schedule starting")
tenant_ids = get_all_tenant_ids() tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} IDs") task_logger.debug(f"Found {len(tenant_ids)} IDs")
# get current schedule and extract current tenants # get current schedule and extract current tenants
current_schedule = self.schedule.items() current_schedule = self.schedule.items()
# there are no more per tenant beat tasks, so comment this out # get potential new state
# NOTE: we may not actually need this scheduler any more and should beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
# test reverting to a regular beat schedule implementation beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
if beat_multiplier_raw is not None:
try:
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
beat_multiplier = float(beat_multiplier_bytes.decode())
except ValueError:
task_logger.error(
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
)
# current_tenants = set() new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)
# for task_name, _ in current_schedule:
# task_name = cast(str, task_name)
# if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
# continue
# if "_" in task_name: # if the schedule or beat multiplier has changed, update
# # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678" while True:
# # -> "12345678-abcd-efgh-ijkl-12345678" if beat_multiplier != self.last_beat_multiplier:
# current_tenants.add(task_name.split("_")[-1]) do_update = True
# logger.info(f"Found {len(current_tenants)} existing items in schedule") break
# for tenant_id in tenant_ids: if not DynamicTenantScheduler._compare_schedules(
# if tenant_id not in current_tenants: current_schedule, new_schedule
# logger.info(f"Processing new tenant: {tenant_id}") ):
do_update = True
break
new_schedule = self._generate_schedule(tenant_ids) break
if DynamicTenantScheduler._compare_schedules(current_schedule, new_schedule): if not do_update:
logger.info( # exit early if nothing changed
"_try_updating_schedule: Current schedule is up to date, no changes needed" task_logger.info(
f"_try_updating_schedule - Schedule unchanged: "
f"tasks={len(new_schedule)} "
f"beat_multiplier={beat_multiplier}"
) )
return return
logger.info( # schedule needs updating
task_logger.debug(
"Schedule update required", "Schedule update required",
extra={ extra={
"new_tasks": len(new_schedule), "new_tasks": len(new_schedule),
@ -185,11 +215,19 @@ class DynamicTenantScheduler(PersistentScheduler):
# Ensure changes are persisted # Ensure changes are persisted
self.sync() self.sync()
logger.info("_try_updating_schedule: Schedule updated successfully") task_logger.info(
f"_try_updating_schedule - Schedule updated: "
f"prev_num_tasks={len(current_schedule)} "
f"prev_beat_multiplier={self.last_beat_multiplier} "
f"tasks={len(new_schedule)} "
f"beat_multiplier={beat_multiplier}"
)
self.last_beat_multiplier = beat_multiplier
@staticmethod @staticmethod
def _compare_schedules(schedule1: dict, schedule2: dict) -> bool: def _compare_schedules(schedule1: dict, schedule2: dict) -> bool:
"""Compare schedules to determine if an update is needed. """Compare schedules by task name only to determine if an update is needed.
True if equivalent, False if not.""" True if equivalent, False if not."""
current_tasks = set(name for name, _ in schedule1) current_tasks = set(name for name, _ in schedule1)
new_tasks = set(schedule2.keys()) new_tasks = set(schedule2.keys())
@ -201,7 +239,7 @@ class DynamicTenantScheduler(PersistentScheduler):
@beat_init.connect @beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None: def on_beat_init(sender: Any, **kwargs: Any) -> None:
logger.info("beat_init signal received.") task_logger.info("beat_init signal received.")
# Celery beat shouldn't touch the db at all. But just setting a low minimum here. # Celery beat shouldn't touch the db at all. But just setting a low minimum here.
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME) SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)

View File

@ -1,3 +1,4 @@
import copy
from datetime import timedelta from datetime import timedelta
from typing import Any from typing import Any
@ -18,7 +19,7 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# hack to slow down task dispatch in the cloud until # hack to slow down task dispatch in the cloud until
# we have a better implementation (backpressure, etc) # we have a better implementation (backpressure, etc)
CLOUD_BEAT_SCHEDULE_MULTIPLIER = 8 CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
# tasks that run in either self-hosted or cloud # tasks that run in either self-hosted or cloud
beat_task_templates: list[dict] = [] beat_task_templates: list[dict] = []
@ -121,7 +122,7 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
# constant options for cloud beat task generators # constant options for cloud beat task generators
task_schedule: timedelta = task["schedule"] task_schedule: timedelta = task["schedule"]
cloud_task["schedule"] = task_schedule * CLOUD_BEAT_SCHEDULE_MULTIPLIER cloud_task["schedule"] = task_schedule
cloud_task["options"] = {} cloud_task["options"] = {}
cloud_task["options"]["priority"] = OnyxCeleryPriority.HIGHEST cloud_task["options"]["priority"] = OnyxCeleryPriority.HIGHEST
cloud_task["options"]["expires"] = BEAT_EXPIRES_DEFAULT cloud_task["options"]["expires"] = BEAT_EXPIRES_DEFAULT
@ -141,9 +142,9 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
# tasks that only run in the cloud # tasks that only run in the cloud
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered # the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be seen
# by the DynamicTenantScheduler # by the DynamicTenantScheduler as a system-wide task and not a per-tenant task
cloud_tasks_to_schedule: list[dict] = [ beat_system_tasks: list[dict] = [
# cloud specific tasks # cloud specific tasks
{ {
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic", "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
@ -157,18 +158,45 @@ cloud_tasks_to_schedule: list[dict] = [
}, },
] ]
# generate our cloud and self-hosted beat tasks from the templates
for beat_task_template in beat_task_templates:
cloud_task = make_cloud_generator_task(beat_task_template)
cloud_tasks_to_schedule.append(cloud_task)
tasks_to_schedule: list[dict] = [] tasks_to_schedule: list[dict] = []
if not MULTI_TENANT: if not MULTI_TENANT:
tasks_to_schedule = beat_task_templates tasks_to_schedule = beat_task_templates
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]: def generate_cloud_tasks(
return cloud_tasks_to_schedule beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float
) -> list[dict[str, Any]]:
"""
beat_tasks: system wide tasks that can be sent as is
beat_templates: task templates that will be transformed into per tenant tasks via
the cloud_beat_task_generator
beat_multiplier: a multiplier that can be applied on top of the task schedule
to speed up or slow down the task generation rate. useful in production.
Returns a list of cloud tasks, which consists of incoming tasks + tasks generated
from incoming templates.
"""
if beat_multiplier <= 0:
raise ValueError("beat_multiplier must be positive!")
# start with the incoming beat tasks
cloud_tasks: list[dict] = copy.deepcopy(beat_tasks)
# generate our cloud tasks from the templates
for beat_template in beat_templates:
cloud_task = make_cloud_generator_task(beat_template)
cloud_tasks.append(cloud_task)
# factor in the cloud multiplier
for cloud_task in cloud_tasks:
cloud_task["schedule"] = cloud_task["schedule"] * beat_multiplier
return cloud_tasks
def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
return generate_cloud_tasks(beat_system_tasks, beat_task_templates, beat_multiplier)
def get_tasks_to_schedule() -> list[dict[str, Any]]: def get_tasks_to_schedule() -> list[dict[str, Any]]:

View File

@ -346,6 +346,9 @@ ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud"
# the tenant id we use for system level redis operations # the tenant id we use for system level redis operations
ONYX_CLOUD_TENANT_ID = "cloud" ONYX_CLOUD_TENANT_ID = "cloud"
# the redis namespace for runtime variables
ONYX_CLOUD_REDIS_RUNTIME = "runtime"
class OnyxCeleryTask: class OnyxCeleryTask:
DEFAULT = "celery" DEFAULT = "celery"