mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 03:48:14 +02:00
Merge pull request #3948 from onyx-dot-app/feature/beat_rtvar
refactoring and update multiplier in real time
This commit is contained in:
commit
552a0630fe
@ -3,42 +3,44 @@ from typing import Any
|
||||
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
|
||||
beat_system_tasks as base_beat_system_tasks,
|
||||
)
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
tasks_to_schedule as base_tasks_to_schedule,
|
||||
beat_task_templates as base_beat_task_templates,
|
||||
)
|
||||
from onyx.background.celery.tasks.beat_schedule import generate_cloud_tasks
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
get_tasks_to_schedule as base_get_tasks_to_schedule,
|
||||
)
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
ee_cloud_tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(days=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
ee_beat_system_tasks: list[dict] = []
|
||||
|
||||
ee_beat_task_templates: list[dict] = []
|
||||
ee_beat_task_templates.extend(
|
||||
[
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
{
|
||||
"name": "check-ttl-management",
|
||||
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-ttl-management",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
},
|
||||
},
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
ee_tasks_to_schedule: list[dict] = []
|
||||
|
||||
@ -65,9 +67,14 @@ if not MULTI_TENANT:
|
||||
]
|
||||
|
||||
|
||||
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return ee_cloud_tasks_to_schedule + base_cloud_tasks_to_schedule
|
||||
def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
|
||||
beat_system_tasks = ee_beat_system_tasks + base_beat_system_tasks
|
||||
beat_task_templates = ee_beat_task_templates + base_beat_task_templates
|
||||
cloud_tasks = generate_cloud_tasks(
|
||||
beat_system_tasks, beat_task_templates, beat_multiplier
|
||||
)
|
||||
return cloud_tasks
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return ee_tasks_to_schedule + base_tasks_to_schedule
|
||||
return ee_tasks_to_schedule + base_get_tasks_to_schedule()
|
||||
|
@ -1,41 +1,56 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery.beat import PersistentScheduler # type: ignore
|
||||
from celery.signals import beat_init
|
||||
from celery.utils.log import get_task_logger
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.beat")
|
||||
|
||||
|
||||
class DynamicTenantScheduler(PersistentScheduler):
|
||||
"""This scheduler is useful because we can dynamically adjust task generation rates
|
||||
through it."""
|
||||
|
||||
RELOAD_INTERVAL = 60
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
logger.info("Initializing DynamicTenantScheduler")
|
||||
super().__init__(*args, **kwargs)
|
||||
self._reload_interval = timedelta(minutes=2)
|
||||
|
||||
self.last_beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
|
||||
self._reload_interval = timedelta(
|
||||
seconds=DynamicTenantScheduler.RELOAD_INTERVAL
|
||||
)
|
||||
self._last_reload = self.app.now() - self._reload_interval
|
||||
|
||||
# Let the parent class handle store initialization
|
||||
self.setup_schedule()
|
||||
self._try_updating_schedule()
|
||||
logger.info(f"Set reload interval to {self._reload_interval}")
|
||||
task_logger.info(
|
||||
f"DynamicTenantScheduler initialized: reload_interval={self._reload_interval}"
|
||||
)
|
||||
|
||||
def setup_schedule(self) -> None:
|
||||
logger.info("Setting up initial schedule")
|
||||
super().setup_schedule()
|
||||
logger.info("Initial schedule setup complete")
|
||||
|
||||
def tick(self) -> float:
|
||||
retval = super().tick()
|
||||
@ -44,36 +59,35 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
self._last_reload is None
|
||||
or (now - self._last_reload) > self._reload_interval
|
||||
):
|
||||
logger.info("Reload interval reached, initiating task update")
|
||||
task_logger.debug("Reload interval reached, initiating task update")
|
||||
try:
|
||||
self._try_updating_schedule()
|
||||
except (AttributeError, KeyError) as e:
|
||||
logger.exception(f"Failed to process task configuration: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error updating tasks: {str(e)}")
|
||||
except (AttributeError, KeyError):
|
||||
task_logger.exception("Failed to process task configuration")
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected error updating tasks")
|
||||
|
||||
self._last_reload = now
|
||||
logger.info("Task update completed, reset reload timer")
|
||||
|
||||
return retval
|
||||
|
||||
def _generate_schedule(
|
||||
self, tenant_ids: list[str] | list[None]
|
||||
self, tenant_ids: list[str] | list[None], beat_multiplier: float
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Given a list of tenant id's, generates a new beat schedule for celery."""
|
||||
logger.info("Fetching tasks to schedule")
|
||||
|
||||
new_schedule: dict[str, dict[str, Any]] = {}
|
||||
|
||||
if MULTI_TENANT:
|
||||
# cloud tasks only need the single task beat across all tenants
|
||||
# cloud tasks are system wide and thus only need to be on the beat schedule
|
||||
# once for all tenants
|
||||
get_cloud_tasks_to_schedule = fetch_versioned_implementation(
|
||||
"onyx.background.celery.tasks.beat_schedule",
|
||||
"get_cloud_tasks_to_schedule",
|
||||
)
|
||||
|
||||
cloud_tasks_to_schedule: list[
|
||||
dict[str, Any]
|
||||
] = get_cloud_tasks_to_schedule()
|
||||
cloud_tasks_to_schedule: list[dict[str, Any]] = get_cloud_tasks_to_schedule(
|
||||
beat_multiplier
|
||||
)
|
||||
for task in cloud_tasks_to_schedule:
|
||||
task_name = task["name"]
|
||||
cloud_task = {
|
||||
@ -82,11 +96,14 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
"kwargs": task.get("kwargs", {}),
|
||||
}
|
||||
if options := task.get("options"):
|
||||
logger.debug(f"Adding options to task {task_name}: {options}")
|
||||
task_logger.debug(f"Adding options to task {task_name}: {options}")
|
||||
cloud_task["options"] = options
|
||||
new_schedule[task_name] = cloud_task
|
||||
|
||||
# regular task beats are multiplied across all tenants
|
||||
# note that currently this just schedules for a single tenant in self hosted
|
||||
# and doesn't do anything in the cloud because it's much more scalable
|
||||
# to schedule a single cloud beat task to dispatch per tenant tasks.
|
||||
get_tasks_to_schedule = fetch_versioned_implementation(
|
||||
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
|
||||
)
|
||||
@ -95,7 +112,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
|
||||
logger.info(
|
||||
task_logger.debug(
|
||||
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
|
||||
)
|
||||
continue
|
||||
@ -104,14 +121,14 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
task_name = task["name"]
|
||||
tenant_task_name = f"{task['name']}-{tenant_id}"
|
||||
|
||||
logger.debug(f"Creating task configuration for {tenant_task_name}")
|
||||
task_logger.debug(f"Creating task configuration for {tenant_task_name}")
|
||||
tenant_task = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"kwargs": {"tenant_id": tenant_id},
|
||||
}
|
||||
if options := task.get("options"):
|
||||
logger.debug(
|
||||
task_logger.debug(
|
||||
f"Adding options to task {tenant_task_name}: {options}"
|
||||
)
|
||||
tenant_task["options"] = options
|
||||
@ -121,44 +138,57 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
|
||||
def _try_updating_schedule(self) -> None:
|
||||
"""Only updates the actual beat schedule on the celery app when it changes"""
|
||||
do_update = False
|
||||
|
||||
logger.info("_try_updating_schedule starting")
|
||||
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
task_logger.debug("_try_updating_schedule starting")
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
logger.info(f"Found {len(tenant_ids)} IDs")
|
||||
task_logger.debug(f"Found {len(tenant_ids)} IDs")
|
||||
|
||||
# get current schedule and extract current tenants
|
||||
current_schedule = self.schedule.items()
|
||||
|
||||
# there are no more per tenant beat tasks, so comment this out
|
||||
# NOTE: we may not actualy need this scheduler any more and should
|
||||
# test reverting to a regular beat schedule implementation
|
||||
# get potential new state
|
||||
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
|
||||
if beat_multiplier_raw is not None:
|
||||
try:
|
||||
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
|
||||
beat_multiplier = float(beat_multiplier_bytes.decode())
|
||||
except ValueError:
|
||||
task_logger.error(
|
||||
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
|
||||
)
|
||||
|
||||
# current_tenants = set()
|
||||
# for task_name, _ in current_schedule:
|
||||
# task_name = cast(str, task_name)
|
||||
# if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
|
||||
# continue
|
||||
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)
|
||||
|
||||
# if "_" in task_name:
|
||||
# # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
|
||||
# # -> "12345678-abcd-efgh-ijkl-12345678"
|
||||
# current_tenants.add(task_name.split("_")[-1])
|
||||
# logger.info(f"Found {len(current_tenants)} existing items in schedule")
|
||||
# if the schedule or beat multiplier has changed, update
|
||||
while True:
|
||||
if beat_multiplier != self.last_beat_multiplier:
|
||||
do_update = True
|
||||
break
|
||||
|
||||
# for tenant_id in tenant_ids:
|
||||
# if tenant_id not in current_tenants:
|
||||
# logger.info(f"Processing new tenant: {tenant_id}")
|
||||
if not DynamicTenantScheduler._compare_schedules(
|
||||
current_schedule, new_schedule
|
||||
):
|
||||
do_update = True
|
||||
break
|
||||
|
||||
new_schedule = self._generate_schedule(tenant_ids)
|
||||
break
|
||||
|
||||
if DynamicTenantScheduler._compare_schedules(current_schedule, new_schedule):
|
||||
logger.info(
|
||||
"_try_updating_schedule: Current schedule is up to date, no changes needed"
|
||||
if not do_update:
|
||||
# exit early if nothing changed
|
||||
task_logger.info(
|
||||
f"_try_updating_schedule - Schedule unchanged: "
|
||||
f"tasks={len(new_schedule)} "
|
||||
f"beat_multiplier={beat_multiplier}"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
# schedule needs updating
|
||||
task_logger.debug(
|
||||
"Schedule update required",
|
||||
extra={
|
||||
"new_tasks": len(new_schedule),
|
||||
@ -185,11 +215,19 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
# Ensure changes are persisted
|
||||
self.sync()
|
||||
|
||||
logger.info("_try_updating_schedule: Schedule updated successfully")
|
||||
task_logger.info(
|
||||
f"_try_updating_schedule - Schedule updated: "
|
||||
f"prev_num_tasks={len(current_schedule)} "
|
||||
f"prev_beat_multiplier={self.last_beat_multiplier} "
|
||||
f"tasks={len(new_schedule)} "
|
||||
f"beat_multiplier={beat_multiplier}"
|
||||
)
|
||||
|
||||
self.last_beat_multiplier = beat_multiplier
|
||||
|
||||
@staticmethod
|
||||
def _compare_schedules(schedule1: dict, schedule2: dict) -> bool:
|
||||
"""Compare schedules to determine if an update is needed.
|
||||
"""Compare schedules by task name only to determine if an update is needed.
|
||||
True if equivalent, False if not."""
|
||||
current_tasks = set(name for name, _ in schedule1)
|
||||
new_tasks = set(schedule2.keys())
|
||||
@ -201,7 +239,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
|
||||
@beat_init.connect
|
||||
def on_beat_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("beat_init signal received.")
|
||||
task_logger.info("beat_init signal received.")
|
||||
|
||||
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import copy
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
@ -18,7 +19,7 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
|
||||
# hack to slow down task dispatch in the cloud until
|
||||
# we have a better implementation (backpressure, etc)
|
||||
CLOUD_BEAT_SCHEDULE_MULTIPLIER = 8
|
||||
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
|
||||
|
||||
# tasks that run in either self-hosted on cloud
|
||||
beat_task_templates: list[dict] = []
|
||||
@ -121,7 +122,7 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
# constant options for cloud beat task generators
|
||||
task_schedule: timedelta = task["schedule"]
|
||||
cloud_task["schedule"] = task_schedule * CLOUD_BEAT_SCHEDULE_MULTIPLIER
|
||||
cloud_task["schedule"] = task_schedule
|
||||
cloud_task["options"] = {}
|
||||
cloud_task["options"]["priority"] = OnyxCeleryPriority.HIGHEST
|
||||
cloud_task["options"]["expires"] = BEAT_EXPIRES_DEFAULT
|
||||
@ -141,9 +142,9 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
# tasks that only run in the cloud
|
||||
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered
|
||||
# by the DynamicTenantScheduler
|
||||
cloud_tasks_to_schedule: list[dict] = [
|
||||
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be seen
|
||||
# by the DynamicTenantScheduler as system wide task and not a per tenant task
|
||||
beat_system_tasks: list[dict] = [
|
||||
# cloud specific tasks
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
|
||||
@ -157,18 +158,45 @@ cloud_tasks_to_schedule: list[dict] = [
|
||||
},
|
||||
]
|
||||
|
||||
# generate our cloud and self-hosted beat tasks from the templates
|
||||
for beat_task_template in beat_task_templates:
|
||||
cloud_task = make_cloud_generator_task(beat_task_template)
|
||||
cloud_tasks_to_schedule.append(cloud_task)
|
||||
|
||||
tasks_to_schedule: list[dict] = []
|
||||
if not MULTI_TENANT:
|
||||
tasks_to_schedule = beat_task_templates
|
||||
|
||||
|
||||
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return cloud_tasks_to_schedule
|
||||
def generate_cloud_tasks(
|
||||
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
beat_tasks: system wide tasks that can be sent as is
|
||||
beat_templates: task templates that will be transformed into per tenant tasks via
|
||||
the cloud_beat_task_generator
|
||||
beat_multiplier: a multiplier that can be applied on top of the task schedule
|
||||
to speed up or slow down the task generation rate. useful in production.
|
||||
|
||||
Returns a list of cloud tasks, which consists of incoming tasks + tasks generated
|
||||
from incoming templates.
|
||||
"""
|
||||
|
||||
if beat_multiplier <= 0:
|
||||
raise ValueError("beat_multiplier must be positive!")
|
||||
|
||||
# start with the incoming beat tasks
|
||||
cloud_tasks: list[dict] = copy.deepcopy(beat_tasks)
|
||||
|
||||
# generate our cloud tasks from the templates
|
||||
for beat_template in beat_templates:
|
||||
cloud_task = make_cloud_generator_task(beat_template)
|
||||
cloud_tasks.append(cloud_task)
|
||||
|
||||
# factor in the cloud multiplier
|
||||
for cloud_task in cloud_tasks:
|
||||
cloud_task["schedule"] = cloud_task["schedule"] * beat_multiplier
|
||||
|
||||
return cloud_tasks
|
||||
|
||||
|
||||
def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
|
||||
return generate_cloud_tasks(beat_system_tasks, beat_task_templates, beat_multiplier)
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
|
@ -346,6 +346,9 @@ ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud"
|
||||
# the tenant id we use for system level redis operations
|
||||
ONYX_CLOUD_TENANT_ID = "cloud"
|
||||
|
||||
# the redis namespace for runtime variables
|
||||
ONYX_CLOUD_REDIS_RUNTIME = "runtime"
|
||||
|
||||
|
||||
class OnyxCeleryTask:
|
||||
DEFAULT = "celery"
|
||||
|
Loading…
x
Reference in New Issue
Block a user