mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-12 22:09:36 +02:00
* bump fastapi and starlette * bumping llama index and nltk and associated deps * bump to fix python-multipart * bump aiohttp * update package lock for examples/widget * bump black * sentencesplitter has changed namespaces * fix reorder import check, fix missing passlib * update package-lock.json * black formatter updated * reformatted again * change to black compatible reorder * change to black compatible reorder-python-imports fork * fix pytest dependency * black format again * we don't need cdk.txt. update packages to be consistent across all packages --------- Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app> Co-authored-by: Richard Kuo <rkuo@rkuo.com>
117 lines
3.9 KiB
Python
117 lines
3.9 KiB
Python
from collections.abc import Callable
|
|
from functools import wraps
|
|
from typing import Any
|
|
from typing import cast
|
|
from typing import TypeVar
|
|
|
|
from celery import Task
|
|
from celery.result import AsyncResult
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.db.engine import get_sqlalchemy_engine
|
|
from onyx.db.tasks import mark_task_finished
|
|
from onyx.db.tasks import mark_task_start
|
|
from onyx.db.tasks import register_task
|
|
|
|
|
|
# Static type of the function decorated by `build_run_wrapper`; the decorator
# returns `cast(T, ...)` so callers keep the original function's signature.
T = TypeVar("T", bound=Callable)
|
|
|
|
|
|
def build_run_wrapper(build_name_fn: Callable[..., str]) -> Callable[[T], T]:
    """Build a decorator for a Celery task's `run` function that keeps our
    custom `task_queue_jobs` table up to date.

    Each call of the decorated function:

    1. derives the task name via ``build_name_fn(*args, **kwargs)``,
    2. records the task as started (``mark_task_start``),
    3. runs the wrapped function, and
    4. records the task as finished (``mark_task_finished``). It passes
       ``success=True`` if the function returned normally, or
       ``success=False`` if it raised an ``Exception``; in that case the
       exception is re-raised unchanged.

    Args:
        build_name_fn: Called with the same positional/keyword arguments as
            the task; returns the name used to look up the task's row.

    Returns:
        A decorator that preserves the wrapped function's static type.
    """

    def wrap_task_fn(task_fn: T) -> T:
        @wraps(task_fn)
        def wrapped_task_fn(*args: Any, **kwargs: Any) -> Any:
            engine = get_sqlalchemy_engine()

            task_name = build_name_fn(*args, **kwargs)
            with Session(engine) as db_session:
                # mark the task as started
                mark_task_start(task_name=task_name, db_session=db_session)

            def record_finished(success: bool) -> None:
                # Use a fresh session: the task body may run for a long time,
                # so don't hold one open across it.
                with Session(engine) as finish_session:
                    mark_task_finished(
                        task_name=task_name,
                        db_session=finish_session,
                        success=success,
                    )

            try:
                result = task_fn(*args, **kwargs)
            except Exception:
                record_finished(success=False)
                # Bare `raise` keeps the original traceback. If
                # `record_finished` itself fails above, the task's exception
                # is still attached to that error as `__context__`.
                raise

            record_finished(success=True)
            return result

        return cast(T, wrapped_task_fn)

    return wrap_task_fn
|
|
|
|
|
|
# Rough static type of Celery's `apply_async`: any callable returning an
# `AsyncResult`. `build_apply_async_wrapper` returns `cast(AA, ...)` so the
# wrapped method keeps this type for callers.
AA = TypeVar("AA", bound=Callable[..., AsyncResult])
|
|
|
|
|
|
def build_apply_async_wrapper(
    build_name_fn: Callable[..., str],
) -> Callable[[AA], AA]:
    """Build a decorator for a Celery task's `apply_async` that creates an
    entry in our `task_queue_jobs` table when the task is submitted.

    The decorated function:

    1. derives the task name from the task's ``args``/``kwargs``
       via ``build_name_fn``,
    2. registers the task row (``register_task``) *before* submitting,
    3. forwards all arguments to the original ``apply_async``, and
    4. stores the resulting Celery task id on the row and commits.

    Args:
        build_name_fn: Called with the task's positional/keyword arguments;
            returns the name used to identify the task's row.

    Returns:
        A decorator that preserves the wrapped function's static type.
    """

    def wrapper(fn: AA) -> AA:
        @wraps(fn)
        def wrapped_fn(
            args: tuple | None = None,
            kwargs: dict[str, Any] | None = None,
            *other_args: Any,
            **other_kwargs: Any,
        ) -> Any:
            # `apply_async` receives the task's args / kwargs as two explicit
            # parameters (not splatted), so unpack them for `build_name_fn`.
            task_name = build_name_fn(*(args or ()), **(kwargs or {}))

            with Session(get_sqlalchemy_engine()) as db_session:
                # register_task must come before fn = apply_async or else the
                # task might run mark_task_start (and crash) before the task
                # row exists
                db_task = register_task(task_name, db_session)

                # NOTE(review): if `fn` raises (e.g. broker unavailable), the
                # row created above is never given a task id here — confirm
                # whether that row should be cleaned up or marked failed.
                task = fn(args, kwargs, *other_args, **other_kwargs)

                # we update the celery task id for diagnostic purposes
                # but it isn't currently used by any code
                db_task.task_id = task.id
                db_session.commit()

            return task

        return cast(AA, wrapped_fn)

    return wrapper
|
|
|
|
|
|
def build_celery_task_wrapper(
    build_name_fn: Callable[..., str],
) -> Callable[[Task], Task]:
    """Return a decorator that instruments a Celery task so our custom
    `task_queue_jobs` table tracks its lifecycle.

    The decorator patches the task object in place and hands back that same
    object:

    - ``apply_async`` is wrapped via ``build_apply_async_wrapper``. Submitting
      the task (e.g. through `apply_async`) inserts a row with status
      `PENDING`.
    - ``run`` is wrapped via ``build_run_wrapper``. When the task starts, the
      latest row moves to `STARTED`. On a normal return it moves to
      `SUCCESS`; on an unhandled exception it moves to `FAILURE`.

    Args:
        build_name_fn: Maps the task's arguments to the name used to find its
            row; shared by both wrappers so they agree on the name.
    """

    def wrap_task(task: Task) -> Task:
        run_decorator = build_run_wrapper(build_name_fn)
        task.run = run_decorator(task.run)  # type: ignore

        apply_async_decorator = build_apply_async_wrapper(build_name_fn)
        task.apply_async = apply_async_decorator(task.apply_async)  # type: ignore

        return task

    return wrap_task
|