2024-12-20 06:54:00 +00:00

163 lines
5.6 KiB
Python

# Helper functions for inspecting Celery queue and worker state, mostly by
# reading the broker's Redis keys directly.
import json
from typing import Any
from typing import cast
from celery import Celery
from redis import Redis
from onyx.background.celery.configs.base import CELERY_SEPARATOR
from onyx.configs.constants import OnyxCeleryPriority
def celery_get_unacked_length(r: Redis) -> int:
    """Return the number of entries currently held in the "unacked" redis hash.

    A non-zero result tells us there may be prefetched tasks. The hash can hold
    tasks other than indexing tasks, so treat the value mostly as a
    "is anything unacked at all?" signal rather than an exact per-queue count.

    ref: https://blog.hikaru.run/2022/08/29/get-waiting-tasks-count-in-celery.html

    Args:
        r: redis client to query.

    Returns:
        The field count of the "unacked" hash as reported by HLEN.
    """
    # The cast only narrows the client's return type for the type checker.
    return cast(int, r.hlen("unacked"))
def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
    """Return the ids of tasks in the "unacked" hash that match ``queue``.

    Unacked entries belonging to the indexing queue are "prefetched", so this
    gives us crucial visibility as to what tasks are in that state.

    Each hash value is parsed as a JSON array: element 0 is the task message
    (a dict whose ``headers`` may carry an ``id``) and element 2 is the value
    compared against ``queue``.

    Args:
        queue: queue name an unacked entry must match to be included.
        r: redis client to scan.

    Returns:
        The set of task ids found for ``queue``. Entries for other queues and
        entries without a task id are skipped.
    """
    task_ids: set[str] = set()

    for _, raw in r.hscan_iter("unacked"):
        # A client created with decode_responses=True yields str instead of
        # bytes; accept both rather than failing on .decode().
        if isinstance(raw, (bytes, bytearray)):
            payload = raw.decode("utf-8")
        else:
            payload = raw

        entry = json.loads(payload)
        message = entry[0]
        entry_queue = entry[2]

        if entry_queue != queue:
            continue

        task_id = message.get("headers", {}).get("id")
        if task_id:
            # queue matches and the message carries an id
            task_ids.add(task_id)

    return task_ids
def celery_get_queue_length(queue: str, r: Redis) -> int:
    """Count the tasks waiting in a celery queue, across all priority levels.

    This is redis specific: task prioritization is implemented with several
    redis lists, so the total is the sum of their lengths. Priority 0 lives in
    the list named ``queue`` itself; every higher priority lives in
    ``queue + CELERY_SEPARATOR + <priority>``.

    This operation is not atomic; the lists are measured one after another.

    Args:
        queue: base name of the celery queue.
        r: redis client to query.

    Returns:
        The summed length of every per-priority list.
    """
    list_names = (
        queue if level == 0 else f"{queue}{CELERY_SEPARATOR}{level}"
        for level in range(len(OnyxCeleryPriority))
    )
    return sum(cast(int, r.llen(name)) for name in list_names)
def celery_find_task(task_id: str, queue: str, r: Redis) -> bool:
    """Report whether a task id is waiting in a celery queue in redis.

    This is redis specific and priority aware: it looks through each of the
    redis lists used to implement task prioritization (``queue`` for priority
    0, ``queue + CELERY_SEPARATOR + <priority>`` for the rest).

    This operation is not atomic.

    This is a linear search O(n), and each list is fetched in full, so be
    careful using it when the task queues can be large.

    Args:
        task_id: celery task id to look for (matched against ``headers.id``).
        queue: base name of the celery queue.
        r: redis client to query.

    Returns:
        True if the id is in the queue, False if not.
    """
    for priority in range(len(OnyxCeleryPriority)):
        queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue

        messages = cast(list[bytes | str], r.lrange(queue_name, 0, -1))
        for raw in messages:
            # A client created with decode_responses=True yields str instead
            # of bytes; accept both rather than failing on .decode().
            if isinstance(raw, (bytes, bytearray)):
                payload = raw.decode("utf-8")
            else:
                payload = raw

            task_dict: dict[str, Any] = json.loads(payload)
            if task_dict.get("headers", {}).get("id") == task_id:
                return True

    return False
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
"""Returns a list of current workers containing name_filter, or all workers if
name_filter is None.
We've empirically discovered that the celery inspect API is potentially unstable
and may hang or return empty results when celery is under load. Suggest using this
more to debug and troubleshoot than in production code.
"""
worker_names: list[str] = []
# filter for and create an indexing specific inspect object
inspect = app.control.inspect()
workers: dict[str, Any] = inspect.ping() # type: ignore
if workers:
for worker_name in list(workers.keys()):
# if the name filter not set, return all worker names
if not name_filter:
worker_names.append(worker_name)
continue
# if the name filter is set, return only worker names that contain the name filter
if name_filter not in worker_name:
continue
worker_names.append(worker_name)
return worker_names
def celery_inspect_get_reserved(worker_names: list[str], app: Celery) -> set[str]:
    """Collect the ids of tasks reserved on the specified workers.

    We've empirically discovered that the celery inspect API is potentially
    unstable and may hang or return empty results when celery is under load.
    Suggest using this more to debug and troubleshoot than in production code.

    Args:
        worker_names: workers to direct the inspect request at.
        app: celery app whose control interface is queried.

    Returns:
        The set of reserved task ids across all replying workers. Empty when
        the inspect call yields nothing.
    """
    per_worker: dict[str, list] | None = app.control.inspect(
        destination=worker_names
    ).reserved()  # type: ignore
    if not per_worker:
        return set()

    # Flatten every worker's task list down to the task ids.
    return {task["id"] for task_list in per_worker.values() for task in task_list}
def celery_inspect_get_active(worker_names: list[str], app: Celery) -> set[str]:
    """Collect the ids of tasks currently active on the specified workers.

    We've empirically discovered that the celery inspect API is potentially
    unstable and may hang or return empty results when celery is under load.
    Suggest using this more to debug and troubleshoot than in production code.

    Args:
        worker_names: workers to direct the inspect request at.
        app: celery app whose control interface is queried.

    Returns:
        The set of active task ids across all replying workers. Empty when the
        inspect call yields nothing.
    """
    per_worker: dict[str, list] | None = app.control.inspect(
        destination=worker_names
    ).active()  # type: ignore
    if not per_worker:
        return set()

    # Flatten every worker's task list down to the task ids.
    return {task["id"] for task_list in per_worker.values() for task in task_list}