mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-30 17:50:27 +02:00
163 lines
5.6 KiB
Python
163 lines
5.6 KiB
Python
# These are helper objects for tracking the keys we need to write in redis
|
|
import json
|
|
from typing import Any
|
|
from typing import cast
|
|
|
|
from celery import Celery
|
|
from redis import Redis
|
|
|
|
from onyx.background.celery.configs.base import CELERY_SEPARATOR
|
|
from onyx.configs.constants import OnyxCeleryPriority
|
|
|
|
|
|
def celery_get_unacked_length(r: Redis) -> int:
    """Return the number of entries in the redis "unacked" hash.

    A non-zero count tells us there may be prefetched tasks. Tasks other than
    indexing tasks can also live in this hash, so the value is mostly useful
    as a zero / non-zero signal rather than as an exact count.

    ref: https://blog.hikaru.run/2022/08/29/get-waiting-tasks-count-in-celery.html

    Args:
        r: redis client used to query the "unacked" hash.

    Returns:
        The field count reported by HLEN for the "unacked" key.
    """
    # cast() only informs the type checker; no runtime conversion happens
    return cast(int, r.hlen("unacked"))
|
|
|
|
|
|
def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
    """Collect the ids of tasks in the "unacked" hash that belong to ``queue``.

    Unacked entries belonging to the indexing queue are "prefetched", so this
    gives us crucial visibility into which tasks are in that state.

    Args:
        queue: name of the queue to match against each unacked entry.
        r: redis client used to scan the "unacked" hash.

    Returns:
        The set of task ids whose unacked entry names ``queue``. Entries
        without a task id in their headers are skipped.
    """
    matching_ids: set[str] = set()

    for _, raw_value in r.hscan_iter("unacked"):
        entry = json.loads(cast(bytes, raw_value).decode("utf-8"))

        # each value is a JSON array: index 0 is read as the task description
        # (a dict with "headers"), index 2 as the queue name
        description, entry_queue = entry[0], entry[2]
        if entry_queue != queue:
            continue

        # the queue matches; record the id only when one is present
        found_id = description.get("headers", {}).get("id")
        if found_id:
            matching_ids.add(found_id)

    return matching_ids
|
|
|
|
|
|
def celery_get_queue_length(queue: str, r: Redis) -> int:
    """Return the total number of tasks waiting in a celery queue in redis.

    The count is priority aware: one redis list is checked per member of
    OnyxCeleryPriority and the lengths are added together. Index 0 maps to the
    bare queue name; every other index ``i`` maps to
    ``<queue><CELERY_SEPARATOR><i>``.

    This operation is not atomic: each list is measured with its own LLEN
    call, so the lists may change between calls.

    Args:
        queue: base name of the celery queue.
        r: redis client used to query the list lengths.

    Returns:
        The summed length of all per-priority lists.
    """
    list_names = [
        queue if priority == 0 else f"{queue}{CELERY_SEPARATOR}{priority}"
        for priority in range(len(OnyxCeleryPriority))
    ]
    return sum(cast(int, r.llen(name)) for name in list_names)
|
|
|
|
|
|
def celery_find_task(task_id: str, queue: str, r: Redis) -> bool:
    """Report whether a task id is present in a celery queue stored in redis.

    The search is priority aware: it looks through each of the redis lists
    used to implement task prioritization (the bare queue name for index 0,
    ``<queue><CELERY_SEPARATOR><i>`` for every other index of
    OnyxCeleryPriority).

    This operation is not atomic: each list is fetched with a separate LRANGE
    call, so a task may move between lists while the search is running.

    This is a linear search, O(n) in the number of queued tasks, and every
    list is read in full, so be careful using it when the queues can be large.

    Args:
        task_id: the celery task id to look for.
        queue: base name of the celery queue.
        r: redis client used to read the queue lists.

    Returns:
        True if a queued message carries ``task_id`` in its headers,
        False otherwise.
    """
    for priority in range(len(OnyxCeleryPriority)):
        queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue

        # 0..-1 fetches the whole list in a single round trip
        tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
        for task in tasks:
            task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
            if task_dict.get("headers", {}).get("id") == task_id:
                return True

    return False
|
|
|
|
|
|
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
    """Return the names of workers that answer a ping, optionally filtered.

    When name_filter is None (or empty) every responding worker is returned;
    otherwise only workers whose name contains name_filter are returned.

    We've empirically discovered that the celery inspect API is potentially
    unstable and may hang or return empty results when celery is under load.
    Suggest using this more to debug and troubleshoot than in production code.

    Args:
        name_filter: substring a worker name must contain, or None for all.
        app: celery application whose control interface is used to ping.

    Returns:
        The matching worker names; an empty list when the ping yields nothing.
    """
    ping_replies: dict[str, Any] = app.control.inspect().ping()  # type: ignore
    if not ping_replies:
        return []

    # a falsy filter accepts every worker; otherwise require a substring match
    return [
        name for name in ping_replies if not name_filter or name_filter in name
    ]
|
|
|
|
|
|
def celery_inspect_get_reserved(worker_names: list[str], app: Celery) -> set[str]:
    """Return the set of ids of tasks reserved on the specified workers.

    We've empirically discovered that the celery inspect API is potentially
    unstable and may hang or return empty results when celery is under load.
    Suggest using this more to debug and troubleshoot than in production code.

    Args:
        worker_names: workers the inspect request is addressed to.
        app: celery application whose control interface is used to inspect.

    Returns:
        The ids of all reserved tasks reported by the workers; an empty set
        when the inspect call yields nothing.
    """
    inspector = app.control.inspect(destination=worker_names)

    # maps worker name -> list of task dicts, or is None/empty on no reply
    reserved_by_worker: dict[str, list] | None = inspector.reserved()  # type: ignore
    if not reserved_by_worker:
        return set()

    task_ids: set[str] = set()
    for worker_tasks in reserved_by_worker.values():
        task_ids.update(task["id"] for task in worker_tasks)
    return task_ids
|
|
|
|
|
|
def celery_inspect_get_active(worker_names: list[str], app: Celery) -> set[str]:
    """Return the set of ids of tasks active on the specified workers.

    We've empirically discovered that the celery inspect API is potentially
    unstable and may hang or return empty results when celery is under load.
    Suggest using this more to debug and troubleshoot than in production code.

    Args:
        worker_names: workers the inspect request is addressed to.
        app: celery application whose control interface is used to inspect.

    Returns:
        The ids of all active tasks reported by the workers; an empty set
        when the inspect call yields nothing.
    """
    inspector = app.control.inspect(destination=worker_names)

    # maps worker name -> list of active task dicts, or is None/empty on no reply
    active_by_worker: dict[str, list] | None = inspector.active()  # type: ignore
    if not active_by_worker:
        return set()

    task_ids: set[str] = set()
    for worker_tasks in active_by_worker.values():
        task_ids.update(task["id"] for task in worker_tasks)
    return task_ids
|