"""Custom client that works similarly to Dask, but simpler and more lightweight.
|
|
Dask jobs behaved very strangely - they would die all the time, retries would
|
|
not follow the expected behavior, etc.
|
|
|
|
NOTE: cannot use Celery directly due to
|
|
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""

import multiprocessing as mp
import sys
import traceback
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing.context import SpawnProcess
from typing import Any
from typing import Literal
from typing import Optional

from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.setup import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()


class SimpleJobException(Exception):
    """Lets a job raise an exception that makes the child process exit with a
    specific error code."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        code: int | None = kwargs.pop("code", None)
        self.code = code
        super().__init__(*args, **kwargs)
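
# For example, a task run via SimpleJobClient may raise
# `SimpleJobException("fatal indexing error", code=75)` (the message and code
# here are arbitrary illustrative values); `_initializer` then uses `code` as
# the child process's exit code via sys.exit().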


JobStatusType = (
    Literal["error"]
    | Literal["finished"]
    | Literal["pending"]
    | Literal["running"]
    | Literal["cancelled"]
)
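
# How SimpleJob.status maps the child process state onto these values:
# no process attached yet -> "pending"; process alive -> "running"; exit code
# 0 -> "finished"; nonzero exit code -> "error"; process present but not alive
# and without an exit code -> "cancelled".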


def _initializer(
    func: Callable,
    queue: mp.Queue,
    args: list | tuple,
    kwargs: dict[str, Any] | None = None,
) -> Any:
    """Initialize the child process with a fresh SQLAlchemy Engine.

    Based on SQLAlchemy's recommendations to handle multiprocessing:
    https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
    """
    if kwargs is None:
        kwargs = {}

    logger.info("Initializing spawned worker child process.")

    # 1. Get tenant_id from args or fall back to the default schema
    tenant_id = POSTGRES_DEFAULT_SCHEMA
    for arg in reversed(args):
        if isinstance(arg, str) and arg.startswith(TENANT_ID_PREFIX):
            tenant_id = arg
            break

    # 2. Set the tenant context before running anything
    token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

    # Reset the engine in the child process
    SqlEngine.reset_engine()

    # Optionally set a custom app name for database logging purposes
    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)

    # Initialize a new engine with the desired parameters
    SqlEngine.init_engine(
        pool_size=4, max_overflow=12, pool_recycle=60, pool_pre_ping=True
    )
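    # Note on the parameters above: pool_pre_ping=True makes SQLAlchemy emit a
    # lightweight liveness check before handing out each pooled connection, so
    # connections dropped while the child sat idle are replaced transparently.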

    # Proceed with executing the target function
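    # Failures are reported on two channels: the formatted traceback is pushed
    # onto `queue` (read back by SimpleJob.exception()), and the process exit
    # code (read back by SimpleJob.status) is set via sys.exit().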
    try:
        return func(*args, **kwargs)
    except SimpleJobException as e:
        logger.exception("SimpleJob raised a SimpleJobException")
        error_msg = traceback.format_exc()
        queue.put(error_msg)  # Send the exception to the parent process

        sys.exit(e.code)  # use the given exit code
    except Exception:
        logger.exception("SimpleJob raised an exception")
        error_msg = traceback.format_exc()
        queue.put(error_msg)  # Send the exception to the parent process

        sys.exit(255)  # use 255 to indicate a generic exception
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


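# Top-level target for the spawned process. With the "spawn" start method the
# Process target and its arguments must be picklable, which a module-level
# function satisfies.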
def _run_in_process(
    func: Callable,
    queue: mp.Queue,
    args: list | tuple,
    kwargs: dict[str, Any] | None = None,
) -> None:
    _initializer(func, queue, args, kwargs)


@dataclass
class SimpleJob:
    """Drop-in replacement for `dask.distributed.Future`"""

    id: int
    process: Optional["SpawnProcess"] = None
    queue: Optional[mp.Queue] = None
    _exception: Optional[str] = None

    def cancel(self) -> bool:
        return self.release()

    def release(self) -> bool:
        if self.process is not None and self.process.is_alive():
            self.process.terminate()
            return True
        return False

    @property
    def status(self) -> JobStatusType:
        if not self.process:
            return "pending"
        elif self.process.is_alive():
            return "running"
        elif self.process.exitcode is None:
            return "cancelled"
        elif self.process.exitcode != 0:
            return "error"
        else:
            return "finished"

    def done(self) -> bool:
        return (
            self.status == "finished"
            or self.status == "cancelled"
            or self.status == "error"
        )

    def exception(self) -> str:
        """Retrieve the exception traceback from the multiprocessing queue, if the
        child process reported one. Named to match the Dask Future API."""
        if self._exception is None and self.queue and not self.queue.empty():
            self._exception = self.queue.get()  # Get exception from queue

        if self._exception:
            return self._exception

        return f"Job with ID '{self.id}' did not report an exception."


class SimpleJobClient:
    """Drop-in replacement for `dask.distributed.Client`"""

    def __init__(self, n_workers: int = 1) -> None:
        self.n_workers = n_workers
        self.job_id_counter = 0
        self.jobs: dict[int, SimpleJob] = {}

    def _cleanup_completed_jobs(self) -> None:
        current_job_ids = list(self.jobs.keys())
        for job_id in current_job_ids:
            job = self.jobs.get(job_id)
            if job and job.done():
                logger.debug(f"Cleaning up job with id: '{job.id}'")
                del self.jobs[job.id]

    def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None:
        """NOTE: `pure` arg is needed so this can be a drop-in replacement for Dask"""
        self._cleanup_completed_jobs()
        if len(self.jobs) >= self.n_workers:
            logger.debug(
                f"No available workers to run job. "
                f"Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
            )
            return None

        job_id = self.job_id_counter
        self.job_id_counter += 1

        # this approach allows us to always "spawn" a new process regardless of
        # get_start_method's current setting
        ctx = mp.get_context("spawn")
        queue = ctx.Queue()
        process = ctx.Process(
            target=_run_in_process, args=(func, queue, args), daemon=True
        )
        job = SimpleJob(id=job_id, process=process, queue=queue)
        process.start()

        self.jobs[job_id] = job

        return job
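

# Typical usage (illustrative sketch; `my_task` and `tenant_id` are hypothetical
# placeholders - the task must be a module-level function so the "spawn" start
# method can pickle it, and a string arg prefixed with TENANT_ID_PREFIX sets the
# tenant context in the child):
#
#     client = SimpleJobClient(n_workers=2)
#     job = client.submit(my_task, some_arg, tenant_id, pure=True)
#     if job is None:
#         ...  # all workers are busy; retry later
#     else:
#         while not job.done():
#             time.sleep(1)
#         if job.status == "error":
#             logger.error(job.exception())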