mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-08-29 23:26:24 +02:00
add k8s probes (#4752)
* add file signals to celery workers * improve probe script * cancel tref --------- Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
This commit is contained in:
@@ -2,10 +2,12 @@ import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import sentry_sdk
|
||||
from celery import bootsteps # type: ignore
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
from celery.exceptions import WorkerShutdown
|
||||
@@ -81,6 +83,19 @@ class TenantAwareTask(Task):
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
|
||||
|
||||
|
||||
def _make_probe_path(probe: str, hostname: str) -> Path:
|
||||
hostname_parts = hostname.split("@")
|
||||
if len(hostname_parts) != 2:
|
||||
raise ValueError(f"hostname could not be split! {hostname=}")
|
||||
|
||||
name = hostname_parts[0]
|
||||
if not name:
|
||||
raise ValueError(f"name cannot be empty! {name=}")
|
||||
|
||||
safe_name = "".join(c for c in name if c.isalnum()).rstrip()
|
||||
return Path(f"/tmp/onyx_k8s_{safe_name}_{probe}.txt")
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
@@ -340,10 +355,23 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    """Log that worker_ready was received and touch the readiness probe file.

    Args:
        sender: Object exposing a ``hostname`` attribute in ``name@host``
            form; it is used to derive the probe file path.
        **kwargs: Accepted and ignored.
    """
    task_logger.info("worker_ready signal received.")

    # Background on file-based health checks for celery workers:
    # https://medium.com/ambient-innovation/health-checks-for-celery-in-kubernetes-cf3274a3e106
    # https://github.com/celery/celery/issues/4079#issuecomment-1270085680
    path = _make_probe_path("readiness", cast(str, sender.hostname))
    path.touch()
    logger.info(f"Readiness signal touched at {path}.")
|
||||
|
||||
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
HttpxPool.close_all()
|
||||
|
||||
hostname: str = cast(str, sender.hostname)
|
||||
path = _make_probe_path("readiness", hostname)
|
||||
path.unlink(missing_ok=True)
|
||||
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
@@ -483,3 +511,34 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
|
||||
# File for validating worker liveness
|
||||
class LivenessProbe(bootsteps.StartStopStep):
    """Worker bootstep that keeps a liveness file's mtime fresh.

    While the step is started, a repeating timer touches ``self.path`` every
    15 seconds. When the step is stopped, the timer is cancelled and the file
    is removed.
    """

    requires = {"celery.worker.components:Timer"}

    def __init__(self, worker: Any, **kwargs: Any) -> None:
        """Compute the liveness file path from ``worker.hostname``."""
        super().__init__(worker, **kwargs)
        # NOTE(review): not read anywhere in the visible code; kept so the
        # attribute remains available to anything that may reference it.
        self.requests: list[Any] = []
        # Handle returned by the timer in start(); None while not scheduled.
        self.task_tref: Any = None
        self.path = _make_probe_path("liveness", worker.hostname)

    def start(self, worker: Any) -> None:
        """Schedule update_liveness_file to run every 15 seconds."""
        self.task_tref = worker.timer.call_repeatedly(
            15.0,
            self.update_liveness_file,
            (worker,),
            priority=10,
        )

    def stop(self, worker: Any) -> None:
        """Cancel the repeating timer, then remove the liveness file."""
        # Cancel before unlinking: if the file were removed first, a timer
        # tick landing in between could touch it again and leave a stale
        # liveness file behind after shutdown.
        if self.task_tref is not None:
            self.task_tref.cancel()
            self.task_tref = None

        self.path.unlink(missing_ok=True)

    def update_liveness_file(self, worker: Any) -> None:
        """Touch the liveness file (creates it if it does not exist)."""
        self.path.touch()
|
||||
|
||||
|
||||
def get_bootsteps() -> list[type]:
    """Return the bootstep classes to register on a worker.

    Currently this is just ``LivenessProbe``.
    """
    steps: list[type] = [LivenessProbe]
    return steps
|
||||
|
@@ -91,6 +91,10 @@ def on_setup_logging(
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
# Add every bootstep returned by app_base.get_bootsteps() to this app's "worker" steps.
base_bootsteps = app_base.get_bootsteps()
|
||||
for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
|
@@ -102,6 +102,10 @@ def on_setup_logging(
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
# Add every bootstep returned by app_base.get_bootsteps() to this app's "worker" steps.
base_bootsteps = app_base.get_bootsteps()
|
||||
for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.indexing",
|
||||
|
@@ -105,6 +105,10 @@ def on_setup_logging(
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
# Add every bootstep returned by app_base.get_bootsteps() to this app's "worker" steps.
base_bootsteps = app_base.get_bootsteps()
|
||||
for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.shared",
|
||||
|
@@ -89,6 +89,10 @@ def on_setup_logging(
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
|
||||
# Add every bootstep returned by app_base.get_bootsteps() to this app's "worker" steps.
base_bootsteps = app_base.get_bootsteps()
|
||||
for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
|
@@ -284,6 +284,10 @@ class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
|
||||
celery_app.steps["worker"].add(HubPeriodicTask)
|
||||
|
||||
# Add every bootstep returned by app_base.get_bootsteps() to this app's "worker" steps
# (in addition to HubPeriodicTask, registered just above).
base_bootsteps = app_base.get_bootsteps()
|
||||
for bootstep in base_bootsteps:
|
||||
celery_app.steps["worker"].add(bootstep)
|
||||
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
|
55
backend/onyx/background/celery/celery_k8s_probe.py
Normal file
55
backend/onyx/background/celery/celery_k8s_probe.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# script to use as a kubernetes readiness / liveness probe
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main_readiness(filename: str) -> int:
    """Readiness check: succeed only when ``filename`` is an existing file.

    Returns:
        0 if the path is a regular file, 1 otherwise.
    """
    return 0 if Path(filename).is_file() else 1
|
||||
|
||||
|
||||
def main_liveness(filename: str, max_age_seconds: float = 60) -> int:
    """Liveness check: the file must exist AND have been modified recently.

    Args:
        filename: Path of the liveness file to inspect.
        max_age_seconds: Maximum allowed age of the file's modification time,
            in seconds. Defaults to 60.

    Returns:
        0 if ``filename`` is a regular file whose mtime is at most
        ``max_age_seconds`` old, 1 otherwise (missing, not a regular file,
        unreadable, or stale).
    """
    from stat import S_ISREG

    # One stat() call answers both "is it a file?" and "how old is it?".
    # Checking existence first and stat-ing afterwards would raise if the
    # file is removed in between the two calls.
    try:
        stats = Path(filename).stat()
    except (OSError, ValueError):
        return 1

    if not S_ISREG(stats.st_mode):
        return 1

    time_diff = time.time() - stats.st_mtime
    if time_diff > max_age_seconds:
        return 1

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse the probe type and target file, run the
    # matching check, and exit with its return code (0 = pass, 1 = fail).
    parser = argparse.ArgumentParser(description="k8s readiness/liveness probe")
    parser.add_argument(
        "--probe",
        type=str,
        choices=["readiness", "liveness"],
        help="The type of probe",
        required=True,
    )
    parser.add_argument("--filename", help="The filename to watch", required=True)
    args = parser.parse_args()

    # Dispatch on the probe type; argparse's `choices` already restricts the
    # value, so the error branch below is purely defensive.
    probe_handlers = {
        "readiness": main_readiness,
        "liveness": main_liveness,
    }
    handler = probe_handlers.get(args.probe)
    if handler is None:
        raise ValueError(f"Unknown probe type: {args.probe}")

    exit_code: int = handler(args.filename)
    sys.exit(exit_code)
|
Reference in New Issue
Block a user