mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-04 00:40:44 +02:00
Feature/celery refactor (#2813)
* fresh indexing feature branch * cherry pick test * Revert "cherry pick test" This reverts commit 2a624220687affdda3de347e30f2011136f64bda. * set multitenant so that vespa fields match when indexing * cleanup pass * mypy * pass through env var to control celery indexing concurrency * comments on task kickoff and some logging improvements * disentangle configuration for different workers and beats. * use get_session_with_tenant * comment out all of update.py * rename to RedisConnectorIndexingFenceData * first check num_indexing_workers * refactor RedisConnectorIndexingFenceData * comment out on_worker_process_init * missed a file * scope db sessions to short lengths * update launch.json template * fix types * code review
This commit is contained in:
parent
eccec6ab7c
commit
9105f95d13
300
.vscode/launch.template.jsonc
vendored
300
.vscode/launch.template.jsonc
vendored
@ -6,19 +6,69 @@
|
|||||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
"version": "0.2.0",
|
"version": "0.2.0",
|
||||||
"compounds": [
|
"compounds": [
|
||||||
|
{
|
||||||
|
// Dummy entry used to label the group
|
||||||
|
"name": "--- Compound ---",
|
||||||
|
"configurations": [
|
||||||
|
"--- Individual ---"
|
||||||
|
],
|
||||||
|
"presentation": {
|
||||||
|
"group": "1",
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Run All Danswer Services",
|
"name": "Run All Danswer Services",
|
||||||
"configurations": [
|
"configurations": [
|
||||||
"Web Server",
|
"Web Server",
|
||||||
"Model Server",
|
"Model Server",
|
||||||
"API Server",
|
"API Server",
|
||||||
"Indexing",
|
"Slack Bot",
|
||||||
"Background Jobs",
|
"Celery primary",
|
||||||
"Slack Bot"
|
"Celery light",
|
||||||
]
|
"Celery heavy",
|
||||||
}
|
"Celery indexing",
|
||||||
|
"Celery beat",
|
||||||
|
],
|
||||||
|
"presentation": {
|
||||||
|
"group": "1",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Web / Model / API",
|
||||||
|
"configurations": [
|
||||||
|
"Web Server",
|
||||||
|
"Model Server",
|
||||||
|
"API Server",
|
||||||
|
],
|
||||||
|
"presentation": {
|
||||||
|
"group": "1",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Celery (all)",
|
||||||
|
"configurations": [
|
||||||
|
"Celery primary",
|
||||||
|
"Celery light",
|
||||||
|
"Celery heavy",
|
||||||
|
"Celery indexing",
|
||||||
|
"Celery beat"
|
||||||
|
],
|
||||||
|
"presentation": {
|
||||||
|
"group": "1",
|
||||||
|
}
|
||||||
|
}
|
||||||
],
|
],
|
||||||
"configurations": [
|
"configurations": [
|
||||||
|
{
|
||||||
|
// Dummy entry used to label the group
|
||||||
|
"name": "--- Individual ---",
|
||||||
|
"type": "node",
|
||||||
|
"request": "launch",
|
||||||
|
"presentation": {
|
||||||
|
"group": "2",
|
||||||
|
"order": 0
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Web Server",
|
"name": "Web Server",
|
||||||
"type": "node",
|
"type": "node",
|
||||||
@ -29,7 +79,11 @@
|
|||||||
"runtimeArgs": [
|
"runtimeArgs": [
|
||||||
"run", "dev"
|
"run", "dev"
|
||||||
],
|
],
|
||||||
"console": "integratedTerminal"
|
"presentation": {
|
||||||
|
"group": "2",
|
||||||
|
},
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"consoleTitle": "Web Server Console"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Model Server",
|
"name": "Model Server",
|
||||||
@ -48,7 +102,11 @@
|
|||||||
"--reload",
|
"--reload",
|
||||||
"--port",
|
"--port",
|
||||||
"9000"
|
"9000"
|
||||||
]
|
],
|
||||||
|
"presentation": {
|
||||||
|
"group": "2",
|
||||||
|
},
|
||||||
|
"consoleTitle": "Model Server Console"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "API Server",
|
"name": "API Server",
|
||||||
@ -68,43 +126,13 @@
|
|||||||
"--reload",
|
"--reload",
|
||||||
"--port",
|
"--port",
|
||||||
"8080"
|
"8080"
|
||||||
]
|
],
|
||||||
|
"presentation": {
|
||||||
|
"group": "2",
|
||||||
|
},
|
||||||
|
"consoleTitle": "API Server Console"
|
||||||
},
|
},
|
||||||
{
|
// For the listener to access the Slack API,
|
||||||
"name": "Indexing",
|
|
||||||
"consoleName": "Indexing",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "danswer/background/update.py",
|
|
||||||
"cwd": "${workspaceFolder}/backend",
|
|
||||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
|
||||||
"env": {
|
|
||||||
"ENABLE_MULTIPASS_INDEXING": "false",
|
|
||||||
"LOG_LEVEL": "DEBUG",
|
|
||||||
"PYTHONUNBUFFERED": "1",
|
|
||||||
"PYTHONPATH": "."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
|
|
||||||
{
|
|
||||||
"name": "Background Jobs",
|
|
||||||
"consoleName": "Background Jobs",
|
|
||||||
"type": "debugpy",
|
|
||||||
"request": "launch",
|
|
||||||
"program": "scripts/dev_run_background_jobs.py",
|
|
||||||
"cwd": "${workspaceFolder}/backend",
|
|
||||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
|
||||||
"env": {
|
|
||||||
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
|
||||||
"LOG_LEVEL": "DEBUG",
|
|
||||||
"PYTHONUNBUFFERED": "1",
|
|
||||||
"PYTHONPATH": "."
|
|
||||||
},
|
|
||||||
"args": [
|
|
||||||
"--no-indexing"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
// For the listner to access the Slack API,
|
|
||||||
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
|
||||||
{
|
{
|
||||||
"name": "Slack Bot",
|
"name": "Slack Bot",
|
||||||
@ -118,7 +146,151 @@
|
|||||||
"LOG_LEVEL": "DEBUG",
|
"LOG_LEVEL": "DEBUG",
|
||||||
"PYTHONUNBUFFERED": "1",
|
"PYTHONUNBUFFERED": "1",
|
||||||
"PYTHONPATH": "."
|
"PYTHONPATH": "."
|
||||||
}
|
},
|
||||||
|
"presentation": {
|
||||||
|
"group": "2",
|
||||||
|
},
|
||||||
|
"consoleTitle": "Slack Bot Console"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Celery primary",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"module": "celery",
|
||||||
|
"cwd": "${workspaceFolder}/backend",
|
||||||
|
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||||
|
"env": {
|
||||||
|
"LOG_LEVEL": "INFO",
|
||||||
|
"PYTHONUNBUFFERED": "1",
|
||||||
|
"PYTHONPATH": "."
|
||||||
|
},
|
||||||
|
"args": [
|
||||||
|
"-A",
|
||||||
|
"danswer.background.celery.versioned_apps.primary",
|
||||||
|
"worker",
|
||||||
|
"--pool=threads",
|
||||||
|
"--concurrency=4",
|
||||||
|
"--prefetch-multiplier=1",
|
||||||
|
"--loglevel=INFO",
|
||||||
|
"--hostname=primary@%n",
|
||||||
|
"-Q",
|
||||||
|
"celery",
|
||||||
|
],
|
||||||
|
"presentation": {
|
||||||
|
"group": "2",
|
||||||
|
},
|
||||||
|
"consoleTitle": "Celery primary Console"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Celery light",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"module": "celery",
|
||||||
|
"cwd": "${workspaceFolder}/backend",
|
||||||
|
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||||
|
"env": {
|
||||||
|
"LOG_LEVEL": "INFO",
|
||||||
|
"PYTHONUNBUFFERED": "1",
|
||||||
|
"PYTHONPATH": "."
|
||||||
|
},
|
||||||
|
"args": [
|
||||||
|
"-A",
|
||||||
|
"danswer.background.celery.versioned_apps.light",
|
||||||
|
"worker",
|
||||||
|
"--pool=threads",
|
||||||
|
"--concurrency=64",
|
||||||
|
"--prefetch-multiplier=8",
|
||||||
|
"--loglevel=INFO",
|
||||||
|
"--hostname=light@%n",
|
||||||
|
"-Q",
|
||||||
|
"vespa_metadata_sync,connector_deletion",
|
||||||
|
],
|
||||||
|
"presentation": {
|
||||||
|
"group": "2",
|
||||||
|
},
|
||||||
|
"consoleTitle": "Celery light Console"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Celery heavy",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"module": "celery",
|
||||||
|
"cwd": "${workspaceFolder}/backend",
|
||||||
|
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||||
|
"env": {
|
||||||
|
"LOG_LEVEL": "INFO",
|
||||||
|
"PYTHONUNBUFFERED": "1",
|
||||||
|
"PYTHONPATH": "."
|
||||||
|
},
|
||||||
|
"args": [
|
||||||
|
"-A",
|
||||||
|
"danswer.background.celery.versioned_apps.heavy",
|
||||||
|
"worker",
|
||||||
|
"--pool=threads",
|
||||||
|
"--concurrency=4",
|
||||||
|
"--prefetch-multiplier=1",
|
||||||
|
"--loglevel=INFO",
|
||||||
|
"--hostname=heavy@%n",
|
||||||
|
"-Q",
|
||||||
|
"connector_pruning",
|
||||||
|
],
|
||||||
|
"presentation": {
|
||||||
|
"group": "2",
|
||||||
|
},
|
||||||
|
"consoleTitle": "Celery heavy Console"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Celery indexing",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"module": "celery",
|
||||||
|
"cwd": "${workspaceFolder}/backend",
|
||||||
|
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||||
|
"env": {
|
||||||
|
"ENABLE_MULTIPASS_INDEXING": "false",
|
||||||
|
"LOG_LEVEL": "DEBUG",
|
||||||
|
"PYTHONUNBUFFERED": "1",
|
||||||
|
"PYTHONPATH": "."
|
||||||
|
},
|
||||||
|
"args": [
|
||||||
|
"-A",
|
||||||
|
"danswer.background.celery.versioned_apps.indexing",
|
||||||
|
"worker",
|
||||||
|
"--pool=threads",
|
||||||
|
"--concurrency=1",
|
||||||
|
"--prefetch-multiplier=1",
|
||||||
|
"--loglevel=INFO",
|
||||||
|
"--hostname=indexing@%n",
|
||||||
|
"-Q",
|
||||||
|
"connector_indexing",
|
||||||
|
],
|
||||||
|
"presentation": {
|
||||||
|
"group": "2",
|
||||||
|
},
|
||||||
|
"consoleTitle": "Celery indexing Console"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Celery beat",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"module": "celery",
|
||||||
|
"cwd": "${workspaceFolder}/backend",
|
||||||
|
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||||
|
"env": {
|
||||||
|
"LOG_LEVEL": "DEBUG",
|
||||||
|
"PYTHONUNBUFFERED": "1",
|
||||||
|
"PYTHONPATH": "."
|
||||||
|
},
|
||||||
|
"args": [
|
||||||
|
"-A",
|
||||||
|
"danswer.background.celery.versioned_apps.beat",
|
||||||
|
"beat",
|
||||||
|
"--loglevel=INFO",
|
||||||
|
],
|
||||||
|
"presentation": {
|
||||||
|
"group": "2",
|
||||||
|
},
|
||||||
|
"consoleTitle": "Celery beat Console"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Pytest",
|
"name": "Pytest",
|
||||||
@ -137,8 +309,22 @@
|
|||||||
"-v"
|
"-v"
|
||||||
// Specify a specific module/test to run or provide nothing to run all tests
|
// Specify a specific module/test to run or provide nothing to run all tests
|
||||||
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
|
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
|
||||||
]
|
],
|
||||||
|
"presentation": {
|
||||||
|
"group": "2",
|
||||||
|
},
|
||||||
|
"consoleTitle": "Pytest Console"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
// Dummy entry used to label the group
|
||||||
|
"name": "--- Tasks ---",
|
||||||
|
"type": "node",
|
||||||
|
"request": "launch",
|
||||||
|
"presentation": {
|
||||||
|
"group": "3",
|
||||||
|
"order": 0
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "Clear and Restart External Volumes and Containers",
|
"name": "Clear and Restart External Volumes and Containers",
|
||||||
"type": "node",
|
"type": "node",
|
||||||
@ -147,7 +333,27 @@
|
|||||||
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
|
"runtimeArgs": ["${workspaceFolder}/backend/scripts/restart_containers.sh"],
|
||||||
"cwd": "${workspaceFolder}",
|
"cwd": "${workspaceFolder}",
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"stopOnEntry": true
|
"stopOnEntry": true,
|
||||||
}
|
"presentation": {
|
||||||
|
"group": "3",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Celery jobs launched through a single background script (legacy)
|
||||||
|
// Recommend using the "Celery (all)" compound launch instead.
|
||||||
|
"name": "Background Jobs",
|
||||||
|
"consoleName": "Background Jobs",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "scripts/dev_run_background_jobs.py",
|
||||||
|
"cwd": "${workspaceFolder}/backend",
|
||||||
|
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||||
|
"env": {
|
||||||
|
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
|
||||||
|
"LOG_LEVEL": "DEBUG",
|
||||||
|
"PYTHONUNBUFFERED": "1",
|
||||||
|
"PYTHONPATH": "."
|
||||||
|
},
|
||||||
|
},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
256
backend/danswer/background/celery/apps/app_base.py
Normal file
256
backend/danswer/background/celery/apps/app_base.py
Normal file
@ -0,0 +1,256 @@
|
|||||||
|
import logging
|
||||||
|
import multiprocessing
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import sentry_sdk
|
||||||
|
from celery import Task
|
||||||
|
from celery.exceptions import WorkerShutdown
|
||||||
|
from celery.states import READY_STATES
|
||||||
|
from celery.utils.log import get_task_logger
|
||||||
|
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||||
|
|
||||||
|
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
|
||||||
|
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||||
|
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||||
|
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||||
|
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||||
|
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||||
|
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||||
|
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||||
|
from danswer.configs.constants import DanswerRedisLocks
|
||||||
|
from danswer.redis.redis_pool import get_redis_client
|
||||||
|
from danswer.utils.logger import ColoredFormatter
|
||||||
|
from danswer.utils.logger import PlainFormatter
|
||||||
|
from danswer.utils.logger import setup_logger
|
||||||
|
from shared_configs.configs import SENTRY_DSN
|
||||||
|
|
||||||
|
logger = setup_logger()
|
||||||
|
|
||||||
|
task_logger = get_task_logger(__name__)
|
||||||
|
|
||||||
|
# Error reporting is opt-in: Sentry is only initialized when a DSN is configured.
if not SENTRY_DSN:
    logger.debug("Sentry DSN not provided, skipping Sentry initialization")
else:
    sentry_sdk.init(
        dsn=SENTRY_DSN,
        integrations=[CeleryIntegration()],
        traces_sample_rate=0.5,
    )
    logger.info("Sentry initialized")
|
||||||
|
|
||||||
|
|
||||||
|
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    """Shared pre-run hook for the worker apps; currently a deliberate no-op."""
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    """Remove a finished task's id from the Redis taskset that tracks it.

    Tasksets are how progress of document set / user group syncs (and the
    other tracked job kinds below) is observed, so every task that reaches a
    ready state must be dropped from its set.

    Runs after a task completes, whether it succeeded or failed. It does not
    run for a task that failed and is going to be retried, nor when a worker
    with acks_late=False crashes.

    Nothing happens when ``task`` or ``task_id`` is missing, or when ``state``
    is not one of celery's ``READY_STATES``.
    """
    if not task:
        return

    task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")

    # Only terminal states count, and without an id there is nothing to remove.
    if state not in READY_STATES or not task_id:
        return

    redis_client = get_redis_client()

    # Credential-pair tasks live in a single taskset obtained from the class itself.
    if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
        redis_client.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
        return

    # Every other kind has one taskset per object; the object id is parsed
    # back out of the task id.
    per_object_kinds = (
        (RedisDocumentSet.PREFIX, RedisDocumentSet),
        (RedisUserGroup.PREFIX, RedisUserGroup),
        (RedisConnectorDeletion.PREFIX, RedisConnectorDeletion),
        (RedisConnectorPruning.SUBTASK_PREFIX, RedisConnectorPruning),
    )
    for prefix, tracker_cls in per_object_kinds:
        if not task_id.startswith(prefix):
            continue

        object_id = tracker_cls.get_id_from_task_id(task_id)
        if object_id is not None:
            redis_client.srem(tracker_cls(int(object_id)).taskset_key, task_id)
        return
|
||||||
|
|
||||||
|
|
||||||
|
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
    """The first signal sent on celery worker startup.

    Switches multiprocessing to the "spawn" start method because fork is unsafe.

    ``force=True`` makes the call idempotent. Without it,
    ``multiprocessing.set_start_method`` raises ``RuntimeError`` if the start
    method has already been fixed, for example when this handler runs more
    than once in the same process.
    """
    multiprocessing.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
    """Block until Redis answers a ping, or abort startup on timeout.

    Redis is pinged every 5 seconds. A ping that raises is treated the same as
    one that returns a falsy value. Once more than 60 seconds have elapsed
    without success, an error is logged and ``WorkerShutdown`` is raised.
    """
    poll_seconds = 5
    limit_seconds = 60

    client = get_redis_client()

    started_at = time.monotonic()
    logger.info("Redis: Readiness check starting.")

    while True:
        try:
            reachable = client.ping()
        except Exception:
            # Best effort: any ping error just means "not ready yet".
            reachable = False

        if reachable:
            break

        waited = time.monotonic() - started_at
        logger.info(
            f"Redis: Ping failed. elapsed={waited:.1f} timeout={limit_seconds:.1f}"
        )

        if waited > limit_seconds:
            failure = (
                f"Redis: Readiness check did not succeed within the timeout "
                f"({limit_seconds} seconds). Exiting..."
            )
            logger.error(failure)
            raise WorkerShutdown(failure)

        time.sleep(poll_seconds)

    logger.info("Redis: Readiness check succeeded. Continuing...")
|
||||||
|
|
||||||
|
|
||||||
|
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
    """Hold a secondary worker at startup until the primary worker's key exists.

    Checks Redis for ``DanswerRedisLocks.PRIMARY_WORKER`` every 5 seconds.
    Once more than 60 seconds have elapsed without the key appearing, an
    error is logged and ``WorkerShutdown`` is raised.
    """
    r = get_redis_client()

    WAIT_INTERVAL = 5
    WAIT_LIMIT = 60

    logger.info("Running as a secondary celery worker.")
    logger.info("Waiting for primary worker to be ready...")
    time_start = time.monotonic()
    while True:
        if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
            break

        # The elapsed time is measured once per iteration; the original also
        # made a bare time.monotonic() call here whose result was discarded.
        time_elapsed = time.monotonic() - time_start
        logger.info(
            f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
        )
        if time_elapsed > WAIT_LIMIT:
            msg = (
                "Primary worker was not ready within the timeout. "
                f"({WAIT_LIMIT} seconds). Exiting..."
            )
            logger.error(msg)
            raise WorkerShutdown(msg)

        time.sleep(WAIT_INTERVAL)

    logger.info("Wait for primary worker completed successfully. Continuing...")
    return
|
||||||
|
|
||||||
|
|
||||||
|
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    """Shared ``worker_ready`` hook: only records that the signal arrived."""
    task_logger.info("worker_ready signal received.")
|
||||||
|
|
||||||
|
|
||||||
|
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    """Release the primary worker lock stored on ``sender`` during shutdown.

    Does nothing for non-primary workers or when ``sender`` holds no lock.
    """
    # Both guards must pass: the worker is primary and it actually has a lock.
    if not celery_is_worker_primary(sender) or not sender.primary_worker_lock:
        return

    logger.info("Releasing primary worker lock.")
    held_lock = sender.primary_worker_lock
    if held_lock.owned():
        held_lock.release()
        # NOTE(review): indentation was lost in the reviewed copy; this assumes
        # the reference is cleared only when the lock was owned - confirm.
        sender.primary_worker_lock = None
|
||||||
|
|
||||||
|
|
||||||
|
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    """Attach Danswer's formatters to the root logger and celery's task logger.

    Each logger gets a colored stream handler and, when ``logfile`` is truthy,
    a plain-text file handler writing to that path. Both loggers are set to
    ``loglevel``; the task logger stops propagating to the root logger.
    ``format`` and ``colorize`` are accepted but not used.
    """
    # TODO: could unhardcode format and colorize and accept these as options from
    # celery's config
    line_format = "%(asctime)s %(filename)30s %(lineno)4s: %(message)s"
    date_format = "%m/%d/%Y %I:%M:%S %p"

    # --- root logger ---
    root_logger = logging.getLogger()

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(ColoredFormatter(line_format, datefmt=date_format))
    root_logger.addHandler(console_handler)

    if logfile:
        file_handler = logging.FileHandler(logfile)
        file_handler.setFormatter(PlainFormatter(line_format, datefmt=date_format))
        root_logger.addHandler(file_handler)

    root_logger.setLevel(loglevel)

    # --- celery's task logger ---
    task_console_handler = logging.StreamHandler()
    task_console_handler.setFormatter(
        CeleryTaskColoredFormatter(line_format, datefmt=date_format)
    )
    task_logger.addHandler(task_console_handler)

    if logfile:
        task_file_handler = logging.FileHandler(logfile)
        task_file_handler.setFormatter(
            CeleryTaskPlainFormatter(line_format, datefmt=date_format)
        )
        task_logger.addHandler(task_file_handler)

    task_logger.setLevel(loglevel)
    task_logger.propagate = False
|
99
backend/danswer/background/celery/apps/beat.py
Normal file
99
backend/danswer/background/celery/apps/beat.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
from datetime import timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from celery import Celery
|
||||||
|
from celery import signals
|
||||||
|
from celery.signals import beat_init
|
||||||
|
|
||||||
|
import danswer.background.celery.apps.app_base as app_base
|
||||||
|
from danswer.configs.constants import DanswerCeleryPriority
|
||||||
|
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||||
|
from danswer.db.engine import get_all_tenant_ids
|
||||||
|
from danswer.db.engine import SqlEngine
|
||||||
|
from danswer.utils.logger import setup_logger
|
||||||
|
|
||||||
|
logger = setup_logger()
|
||||||
|
|
||||||
|
celery_app = Celery(__name__)
|
||||||
|
celery_app.config_from_object("danswer.background.celery.configs.beat")
|
||||||
|
|
||||||
|
|
||||||
|
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
    """Handle ``beat_init``: set up the DB engine, then wait for Redis."""
    logger.info("beat_init signal received.")

    # Name the Postgres connections for the beat process and create its engine.
    SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
    SqlEngine.init_engine(pool_size=2, max_overflow=0)

    # Startup does not proceed until Redis responds (or the wait times out).
    app_base.wait_for_redis(sender, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    """Forward celery's ``setup_logging`` signal to the shared handler in app_base."""
    app_base.on_setup_logging(
        loglevel=loglevel,
        logfile=logfile,
        format=format,
        colorize=colorize,
        **kwargs,
    )
|
||||||
|
|
||||||
|
|
||||||
|
#####
|
||||||
|
# Celery Beat (Periodic Tasks) Settings
|
||||||
|
#####
|
||||||
|
|
||||||
|
tenant_ids = get_all_tenant_ids()

# Template entries; each one is instantiated once per tenant below.
tasks_to_schedule = [
    {
        "name": "check-for-vespa-sync",
        "task": "check_for_vespa_sync_task",
        "schedule": timedelta(seconds=5),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "check-for-connector-deletion",
        "task": "check_for_connector_deletion_task",
        "schedule": timedelta(seconds=60),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "check-for-indexing",
        "task": "check_for_indexing",
        "schedule": timedelta(seconds=10),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "check-for-prune",
        "task": "check_for_pruning",
        "schedule": timedelta(seconds=10),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
    {
        "name": "kombu-message-cleanup",
        "task": "kombu_message_cleanup_task",
        "schedule": timedelta(seconds=3600),
        "options": {"priority": DanswerCeleryPriority.LOWEST},
    },
    {
        "name": "monitor-vespa-sync",
        "task": "monitor_vespa_sync",
        "schedule": timedelta(seconds=5),
        "options": {"priority": DanswerCeleryPriority.HIGH},
    },
]

# Build the celery beat schedule dynamically: one entry per (tenant, task),
# keyed "<task name>-<tenant id>" so every scheduled entry has a unique name.
beat_schedule = {
    f"{task['name']}-{tenant_id}": {
        "task": task["task"],
        "schedule": task["schedule"],
        "options": task["options"],
        "args": (tenant_id,),  # Must pass tenant_id as an argument
    }
    for tenant_id in tenant_ids
    for task in tasks_to_schedule
}

# Entries already present in the loaded config are merged last, so they win
# on a name clash.
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)

# Update the Celery app configuration once
celery_app.conf.beat_schedule = beat_schedule
|
88
backend/danswer/background/celery/apps/heavy.py
Normal file
88
backend/danswer/background/celery/apps/heavy.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
import multiprocessing
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from celery import Celery
|
||||||
|
from celery import signals
|
||||||
|
from celery import Task
|
||||||
|
from celery.signals import celeryd_init
|
||||||
|
from celery.signals import worker_init
|
||||||
|
from celery.signals import worker_ready
|
||||||
|
from celery.signals import worker_shutdown
|
||||||
|
|
||||||
|
import danswer.background.celery.apps.app_base as app_base
|
||||||
|
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
|
||||||
|
from danswer.db.engine import SqlEngine
|
||||||
|
from danswer.utils.logger import setup_logger
|
||||||
|
|
||||||
|
|
||||||
|
logger = setup_logger()
|
||||||
|
|
||||||
|
celery_app = Celery(__name__)
|
||||||
|
celery_app.config_from_object("danswer.background.celery.configs.heavy")
|
||||||
|
|
||||||
|
|
||||||
|
@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    """Forward celery's ``task_prerun`` signal to the shared handler in app_base."""
    app_base.on_task_prerun(
        sender=sender,
        task_id=task_id,
        task=task,
        args=args,
        kwargs=kwargs,
        **kwds,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    """Forward celery's ``task_postrun`` signal to the shared handler in app_base."""
    app_base.on_task_postrun(
        sender=sender,
        task_id=task_id,
        task=task,
        args=args,
        kwargs=kwargs,
        retval=retval,
        state=state,
        **kwds,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
    """Forward celery's ``celeryd_init`` signal to the shared handler in app_base."""
    app_base.on_celeryd_init(sender=sender, conf=conf, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
    """Handle ``worker_init`` for the heavy worker.

    Logs the multiprocessing start method, creates the DB engine, then waits
    for Redis and for the primary worker before returning.
    """
    logger.info("worker_init signal received.")
    start_method = multiprocessing.get_start_method()
    logger.info(f"Multiprocessing start method: {start_method}")

    # Name the Postgres connections for this worker and create its engine.
    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
    SqlEngine.init_engine(pool_size=8, max_overflow=0)

    # Redis must be reachable before the primary worker's key can be checked.
    app_base.wait_for_redis(sender, **kwargs)
    app_base.on_secondary_worker_init(sender, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    """Forward celery's ``worker_ready`` signal to the shared handler in app_base."""
    app_base.on_worker_ready(sender, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    """Forward celery's ``worker_shutdown`` signal to the shared handler in app_base."""
    app_base.on_worker_shutdown(sender, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    """Forward celery's ``setup_logging`` signal to the shared handler in app_base."""
    app_base.on_setup_logging(
        loglevel=loglevel,
        logfile=logfile,
        format=format,
        colorize=colorize,
        **kwargs,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# The heavy worker app only registers tasks from the pruning package.
celery_app.autodiscover_tasks(["danswer.background.celery.tasks.pruning"])
|
116
backend/danswer/background/celery/apps/indexing.py
Normal file
116
backend/danswer/background/celery/apps/indexing.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
import multiprocessing
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from celery import Celery
|
||||||
|
from celery import signals
|
||||||
|
from celery import Task
|
||||||
|
from celery.signals import celeryd_init
|
||||||
|
from celery.signals import worker_init
|
||||||
|
from celery.signals import worker_ready
|
||||||
|
from celery.signals import worker_shutdown
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
import danswer.background.celery.apps.app_base as app_base
|
||||||
|
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
|
||||||
|
from danswer.db.engine import SqlEngine
|
||||||
|
from danswer.db.search_settings import get_current_search_settings
|
||||||
|
from danswer.db.swap_index import check_index_swap
|
||||||
|
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||||
|
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||||
|
from danswer.utils.logger import setup_logger
|
||||||
|
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||||
|
from shared_configs.configs import MODEL_SERVER_PORT
|
||||||
|
|
||||||
|
|
||||||
|
logger = setup_logger()
|
||||||
|
|
||||||
|
celery_app = Celery(__name__)
|
||||||
|
celery_app.config_from_object("danswer.background.celery.configs.indexing")
|
||||||
|
|
||||||
|
|
||||||
|
@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    """Forward celery's ``task_prerun`` signal to the shared handler in app_base."""
    app_base.on_task_prerun(
        sender=sender,
        task_id=task_id,
        task=task,
        args=args,
        kwargs=kwargs,
        **kwds,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    """Handle Celery's ``task_postrun`` signal.

    Every argument is passed through unchanged to the shared implementation
    in ``app_base``.
    """
    app_base.on_task_postrun(
        sender, task_id, task, args, kwargs, retval, state, **kwds
    )
|
||||||
|
|
||||||
|
|
||||||
|
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
    """Handle Celery's ``celeryd_init`` signal by delegating to ``app_base``."""
    app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
    """Handle Celery's ``worker_init`` signal for the indexing worker.

    Initializes the SQL engine, runs the index-swap check, warms up the
    embedding model when the current search settings have no provider type,
    and finally runs the shared Redis-wait and secondary-worker startup
    helpers from ``app_base``.
    """
    logger.info("worker_init signal received.")
    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
    SqlEngine.init_engine(pool_size=8, max_overflow=0)

    # TODO: why is this necessary for the indexer to do?
    # NOTE(review): the nesting under this `with` was reconstructed from an
    # unindented rendering of the code; confirm the warm-up belongs inside it.
    with Session(SqlEngine.get_engine()) as db_session:
        check_index_swap(db_session=db_session)
        current_settings = get_current_search_settings(db_session)

        # Warm up now so that first-time users aren't surprised by a really
        # slow first batch of indexed documents.
        if current_settings.provider_type is None:
            logger.notice("Running a first inference to warm up embedding model")
            warmup_model = EmbeddingModel.from_db_model(
                search_settings=current_settings,
                server_host=INDEXING_MODEL_SERVER_HOST,
                server_port=MODEL_SERVER_PORT,
            )
            warm_up_bi_encoder(embedding_model=warmup_model)
            logger.notice("First inference complete.")

    app_base.wait_for_redis(sender, **kwargs)
    app_base.on_secondary_worker_init(sender, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    """Handle Celery's ``worker_ready`` signal by delegating to ``app_base``."""
    app_base.on_worker_ready(sender, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    """Handle Celery's ``worker_shutdown`` signal by delegating to ``app_base``."""
    app_base.on_worker_shutdown(sender, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any,
    logfile: Any,
    format: Any,
    colorize: Any,
    **kwargs: Any,
) -> None:
    """Handle Celery's ``setup_logging`` signal by delegating to ``app_base``."""
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Task packages registered with this worker's Celery app.
celery_app.autodiscover_tasks(["danswer.background.celery.tasks.indexing"])
|
89
backend/danswer/background/celery/apps/light.py
Normal file
89
backend/danswer/background/celery/apps/light.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
import multiprocessing
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from celery import Celery
|
||||||
|
from celery import signals
|
||||||
|
from celery import Task
|
||||||
|
from celery.signals import celeryd_init
|
||||||
|
from celery.signals import worker_init
|
||||||
|
from celery.signals import worker_ready
|
||||||
|
from celery.signals import worker_shutdown
|
||||||
|
|
||||||
|
import danswer.background.celery.apps.app_base as app_base
|
||||||
|
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||||
|
from danswer.db.engine import SqlEngine
|
||||||
|
from danswer.utils.logger import setup_logger
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level logger used by the worker lifecycle handlers below.
logger = setup_logger()

# Celery application for this module; its settings are read from the
# light config module.
celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.light")
|
||||||
|
|
||||||
|
|
||||||
|
@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    """Handle Celery's ``task_prerun`` signal.

    Every argument is passed through unchanged to the shared implementation
    in ``app_base``.
    """
    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||||
|
|
||||||
|
|
||||||
|
@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    """Handle Celery's ``task_postrun`` signal.

    Every argument is passed through unchanged to the shared implementation
    in ``app_base``.
    """
    app_base.on_task_postrun(
        sender, task_id, task, args, kwargs, retval, state, **kwds
    )
|
||||||
|
|
||||||
|
|
||||||
|
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
    """Handle Celery's ``celeryd_init`` signal by delegating to ``app_base``."""
    app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
    """Handle Celery's ``worker_init`` signal for the light worker.

    Initializes the SQL engine with a pool sized to the worker's concurrency,
    then runs the shared Redis-wait and secondary-worker startup helpers from
    ``app_base``.
    """
    logger.info("worker_init signal received.")
    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
    # Base pool size follows the worker's concurrency; up to 8 overflow
    # connections are allowed on top of that.
    worker_concurrency = sender.concurrency
    SqlEngine.init_engine(pool_size=worker_concurrency, max_overflow=8)

    app_base.wait_for_redis(sender, **kwargs)
    app_base.on_secondary_worker_init(sender, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    """Handle Celery's ``worker_ready`` signal by delegating to ``app_base``."""
    app_base.on_worker_ready(sender, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    """Handle Celery's ``worker_shutdown`` signal by delegating to ``app_base``."""
    app_base.on_worker_shutdown(sender, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any,
    logfile: Any,
    format: Any,
    colorize: Any,
    **kwargs: Any,
) -> None:
    """Handle Celery's ``setup_logging`` signal by delegating to ``app_base``."""
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# Task packages registered with this worker's Celery app.
celery_app.autodiscover_tasks(
    ["danswer.background.celery.tasks.shared", "danswer.background.celery.tasks.vespa"]
)
|
278
backend/danswer/background/celery/apps/primary.py
Normal file
278
backend/danswer/background/celery/apps/primary.py
Normal file
@ -0,0 +1,278 @@
|
|||||||
|
import multiprocessing
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import redis
|
||||||
|
from celery import bootsteps # type: ignore
|
||||||
|
from celery import Celery
|
||||||
|
from celery import signals
|
||||||
|
from celery import Task
|
||||||
|
from celery.exceptions import WorkerShutdown
|
||||||
|
from celery.signals import celeryd_init
|
||||||
|
from celery.signals import worker_init
|
||||||
|
from celery.signals import worker_ready
|
||||||
|
from celery.signals import worker_shutdown
|
||||||
|
from celery.utils.log import get_task_logger
|
||||||
|
|
||||||
|
import danswer.background.celery.apps.app_base as app_base
|
||||||
|
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||||
|
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||||
|
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||||
|
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||||
|
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||||
|
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||||
|
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||||
|
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||||
|
from danswer.configs.constants import DanswerRedisLocks
|
||||||
|
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
|
||||||
|
from danswer.db.engine import SqlEngine
|
||||||
|
from danswer.redis.redis_pool import get_redis_client
|
||||||
|
from danswer.utils.logger import setup_logger
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level logger used by the worker lifecycle handlers below.
logger = setup_logger()

# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)

# Celery application for this module; its settings are read from the
# primary config module.
celery_app = Celery(__name__)
celery_app.config_from_object("danswer.background.celery.configs.primary")
|
||||||
|
|
||||||
|
|
||||||
|
@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    """Handle Celery's ``task_prerun`` signal.

    Every argument is passed through unchanged to the shared implementation
    in ``app_base``.
    """
    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||||
|
|
||||||
|
|
||||||
|
@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    """Handle Celery's ``task_postrun`` signal.

    Every argument is passed through unchanged to the shared implementation
    in ``app_base``.
    """
    app_base.on_task_postrun(
        sender, task_id, task, args, kwargs, retval, state, **kwds
    )
|
||||||
|
|
||||||
|
|
||||||
|
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
    """Handle Celery's ``celeryd_init`` signal by delegating to ``app_base``."""
    app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
    """Handle Celery's ``worker_init`` signal for the primary worker.

    Initializes the SQL engine, waits for Redis, takes the primary worker
    lock (stored on ``sender.primary_worker_lock``) and then deletes the
    beat locks, taskset, fence and generator keys left in Redis.

    Raises:
        WorkerShutdown: if the primary worker lock cannot be acquired within
            half of ``CELERY_PRIMARY_WORKER_LOCK_TIMEOUT``.
    """
    logger.info("worker_init signal received.")
    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
    SqlEngine.init_engine(pool_size=8, max_overflow=0)

    app_base.wait_for_redis(sender, **kwargs)

    logger.info("Running as the primary celery worker.")

    # This is singleton work that should be done on startup exactly once
    # by the primary worker
    r = get_redis_client()

    # For the moment, we're assuming that we are the only primary worker
    # that should be running.
    # TODO: maybe check for or clean up another zombie primary worker if we detect it
    r.delete(DanswerRedisLocks.PRIMARY_WORKER)

    # this process wide lock is taken to help other workers start up in order.
    # it is planned to use this lock to enforce singleton behavior on the primary
    # worker, since the primary worker does redis cleanup on startup, but this isn't
    # implemented yet.
    primary_lock = r.lock(
        DanswerRedisLocks.PRIMARY_WORKER,
        timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
    )

    logger.info("Primary worker lock: Acquire starting.")
    if not primary_lock.acquire(
        blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
    ):
        logger.error("Primary worker lock: Acquire failed!")
        raise WorkerShutdown("Primary worker lock could not be acquired!")
    logger.info("Primary worker lock: Acquire succeeded.")

    sender.primary_worker_lock = primary_lock

    # As currently designed, when this worker starts as "primary", we reinitialize redis
    # to a clean state (for our purposes, anyway)
    r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
    r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)

    r.delete(RedisConnectorCredentialPair.get_taskset_key())
    r.delete(RedisConnectorCredentialPair.get_fence_key())

    # Key prefixes whose matching keys are scanned and deleted one by one,
    # in this order.
    stale_key_prefixes = (
        RedisDocumentSet.TASKSET_PREFIX,
        RedisDocumentSet.FENCE_PREFIX,
        RedisUserGroup.TASKSET_PREFIX,
        RedisUserGroup.FENCE_PREFIX,
        RedisConnectorDeletion.TASKSET_PREFIX,
        RedisConnectorDeletion.FENCE_PREFIX,
        RedisConnectorPruning.TASKSET_PREFIX,
        RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX,
        RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX,
        RedisConnectorPruning.FENCE_PREFIX,
        RedisConnectorIndexing.TASKSET_PREFIX,
        RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX,
        RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX,
        RedisConnectorIndexing.FENCE_PREFIX,
    )
    for key_prefix in stale_key_prefixes:
        for key in r.scan_iter(key_prefix + "*"):
            r.delete(key)
|
||||||
|
|
||||||
|
|
||||||
|
# @worker_process_init.connect
|
||||||
|
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
|
||||||
|
# """This only runs inside child processes when the worker is in pool=prefork mode.
|
||||||
|
# This may be technically unnecessary since we're finding prefork pools to be
|
||||||
|
# unstable and currently aren't planning on using them."""
|
||||||
|
# logger.info("worker_process_init signal received.")
|
||||||
|
# SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
|
||||||
|
# SqlEngine.init_engine(pool_size=5, max_overflow=0)
|
||||||
|
|
||||||
|
# # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error
|
||||||
|
# SqlEngine.get_engine().dispose(close=False)
|
||||||
|
|
||||||
|
|
||||||
|
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    """Handle Celery's ``worker_ready`` signal by delegating to ``app_base``."""
    app_base.on_worker_ready(sender, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    """Handle Celery's ``worker_shutdown`` signal by delegating to ``app_base``."""
    app_base.on_worker_shutdown(sender, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any,
    logfile: Any,
    format: Any,
    colorize: Any,
    **kwargs: Any,
) -> None:
    """Handle Celery's ``setup_logging`` signal by delegating to ``app_base``."""
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class HubPeriodicTask(bootsteps.StartStopStep):
    """Regularly reacquires the primary worker lock outside of the task queue.
    Use the task_logger in this class to avoid double logging.

    This cannot be done inside a regular beat task because it must run on schedule and
    a queue of existing work would starve the task from running.
    """

    # it's unclear to me whether using the hub's timer or the bootstep timer is better
    requires = {"celery.worker.components:Hub"}

    def __init__(self, worker: Any, **kwargs: Any) -> None:
        """Set the refresh interval; nothing is scheduled until ``start``."""
        self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8  # Interval in seconds
        self.task_tref = None

    def start(self, worker: Any) -> None:
        """Schedule ``run_periodic_task`` on the worker's hub (primary worker only)."""
        if not celery_is_worker_primary(worker):
            return

        # Access the worker's event loop (hub)
        hub = worker.consumer.controller.hub

        # Schedule the periodic task
        self.task_tref = hub.call_repeatedly(
            self.interval, self.run_periodic_task, worker
        )
        task_logger.info("Scheduled periodic task with hub.")

    def run_periodic_task(self, worker: Any) -> None:
        """Extend the primary worker lock, or take it again if it is no longer owned.

        Any exception (including the TimeoutError raised when the lock cannot
        be taken again) is logged and swallowed so the timer keeps running.
        """
        try:
            # getattr covers both "attribute never set" and "attribute is None".
            # Reading worker.primary_worker_lock before checking for its
            # existence would raise AttributeError on every tick instead of
            # returning quietly.
            current_lock = getattr(worker, "primary_worker_lock", None)
            if not current_lock:
                return

            r = get_redis_client()

            lock: redis.lock.Lock = current_lock

            if lock.owned():
                task_logger.debug("Reacquiring primary worker lock.")
                lock.reacquire()
            else:
                task_logger.warning(
                    "Full acquisition of primary worker lock. "
                    "Reasons could be computer sleep or a clock change."
                )
                lock = r.lock(
                    DanswerRedisLocks.PRIMARY_WORKER,
                    timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
                )

                task_logger.info("Primary worker lock: Acquire starting.")
                acquired = lock.acquire(
                    blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
                )
                if acquired:
                    task_logger.info("Primary worker lock: Acquire succeeded.")
                else:
                    task_logger.error("Primary worker lock: Acquire failed!")
                    raise TimeoutError("Primary worker lock could not be acquired!")

                # NOTE(review): nesting reconstructed from an unindented
                # rendering; the new lock is stored only after a full acquire.
                worker.primary_worker_lock = lock
        except Exception:
            task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.")

    def stop(self, worker: Any) -> None:
        """Cancel the scheduled refresh, if one was scheduled."""
        # Cancel the scheduled task when the worker stops
        if self.task_tref:
            self.task_tref.cancel()
            task_logger.info("Canceled periodic task with hub.")
|
||||||
|
|
||||||
|
|
||||||
|
# Add the lock-refresh bootstep to the worker's step set.
celery_app.steps["worker"].add(HubPeriodicTask)

# Task packages registered with this worker's Celery app.
celery_app.autodiscover_tasks(
    [
        f"danswer.background.celery.tasks.{package}"
        for package in (
            "connector_deletion",
            "indexing",
            "periodic",
            "pruning",
            "shared",
            "vespa",
        )
    ]
)
|
26
backend/danswer/background/celery/apps/task_formatters.py
Normal file
26
backend/danswer/background/celery/apps/task_formatters.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from celery import current_task
|
||||||
|
|
||||||
|
from danswer.utils.logger import ColoredFormatter
|
||||||
|
from danswer.utils.logger import PlainFormatter
|
||||||
|
|
||||||
|
|
||||||
|
class CeleryTaskPlainFormatter(PlainFormatter):
    """Plain formatter that tags records emitted while a Celery task is running.

    When ``current_task`` has a request, the record gains ``task_id`` and
    ``task_name`` attributes and the formatted message is prefixed with
    ``[task_name(task_id)]``.
    """

    def format(self, record: logging.LogRecord) -> str:
        """Return the formatted record, task-prefixed when inside a task."""
        task = current_task
        if not (task and task.request):
            return super().format(record)

        record.__dict__.update(task_id=task.request.id, task_name=task.name)

        # The same LogRecord object is passed to every handler. The prefix is
        # therefore applied only for the duration of this call; leaving it on
        # record.msg would make each further handler that uses one of these
        # formatters stack another prefix on top.
        unprefixed_msg = record.msg
        record.msg = f"[{task.name}({task.request.id})] {unprefixed_msg}"
        try:
            return super().format(record)
        finally:
            record.msg = unprefixed_msg
|
||||||
|
|
||||||
|
|
||||||
|
class CeleryTaskColoredFormatter(ColoredFormatter):
    """Colored formatter that tags records emitted while a Celery task is running.

    When ``current_task`` has a request, the record gains ``task_id`` and
    ``task_name`` attributes and the formatted message is prefixed with
    ``[task_name(task_id)]``.
    """

    def format(self, record: logging.LogRecord) -> str:
        """Return the formatted record, task-prefixed when inside a task."""
        task = current_task
        if not (task and task.request):
            return super().format(record)

        record.__dict__.update(task_id=task.request.id, task_name=task.name)

        # The same LogRecord object is passed to every handler. The prefix is
        # therefore applied only for the duration of this call; leaving it on
        # record.msg would make each further handler that uses one of these
        # formatters stack another prefix on top.
        unprefixed_msg = record.msg
        record.msg = f"[{task.name}({task.request.id})] {unprefixed_msg}"
        try:
            return super().format(record)
        finally:
            record.msg = unprefixed_msg
|
@ -1,601 +0,0 @@
|
|||||||
import logging
|
|
||||||
import multiprocessing
|
|
||||||
import time
|
|
||||||
from datetime import timedelta
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import redis
|
|
||||||
import sentry_sdk
|
|
||||||
from celery import bootsteps # type: ignore
|
|
||||||
from celery import Celery
|
|
||||||
from celery import current_task
|
|
||||||
from celery import signals
|
|
||||||
from celery import Task
|
|
||||||
from celery.exceptions import WorkerShutdown
|
|
||||||
from celery.signals import beat_init
|
|
||||||
from celery.signals import celeryd_init
|
|
||||||
from celery.signals import worker_init
|
|
||||||
from celery.signals import worker_ready
|
|
||||||
from celery.signals import worker_shutdown
|
|
||||||
from celery.states import READY_STATES
|
|
||||||
from celery.utils.log import get_task_logger
|
|
||||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
|
||||||
|
|
||||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
|
||||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
|
||||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
|
||||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
|
||||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
|
||||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
|
||||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
|
||||||
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
|
||||||
from danswer.configs.constants import DanswerCeleryPriority
|
|
||||||
from danswer.configs.constants import DanswerRedisLocks
|
|
||||||
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
|
||||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
|
|
||||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
|
|
||||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
|
||||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
|
|
||||||
from danswer.db.engine import get_all_tenant_ids
|
|
||||||
from danswer.db.engine import get_session_with_tenant
|
|
||||||
from danswer.db.engine import SqlEngine
|
|
||||||
from danswer.db.search_settings import get_current_search_settings
|
|
||||||
from danswer.db.swap_index import check_index_swap
|
|
||||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
|
||||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
|
||||||
from danswer.redis.redis_pool import get_redis_client
|
|
||||||
from danswer.utils.logger import ColoredFormatter
|
|
||||||
from danswer.utils.logger import PlainFormatter
|
|
||||||
from danswer.utils.logger import setup_logger
|
|
||||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
|
||||||
from shared_configs.configs import MODEL_SERVER_PORT
|
|
||||||
from shared_configs.configs import SENTRY_DSN
|
|
||||||
|
|
||||||
# Module-level logger used by the worker lifecycle handlers below.
logger = setup_logger()

# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)

# Sentry is initialized only when a DSN is configured.
if not SENTRY_DSN:
    logger.debug("Sentry DSN not provided, skipping Sentry initialization")
else:
    sentry_sdk.init(
        dsn=SENTRY_DSN,
        integrations=[CeleryIntegration()],
        traces_sample_rate=0.5,
    )
    logger.info("Sentry initialized")

celery_app = Celery(__name__)
# Load configuration from 'celeryconfig.py'
celery_app.config_from_object("danswer.background.celery.celeryconfig")
|
|
||||||
|
|
||||||
|
|
||||||
@signals.task_prerun.connect
def on_task_prerun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    **kwds: Any,
) -> None:
    """Handle Celery's ``task_prerun`` signal; intentionally does nothing."""
|
|
||||||
|
|
||||||
|
|
||||||
@signals.task_postrun.connect
def on_task_postrun(
    sender: Any | None = None,
    task_id: str | None = None,
    task: Task | None = None,
    args: tuple | None = None,
    kwargs: dict | None = None,
    retval: Any | None = None,
    state: str | None = None,
    **kwds: Any,
) -> None:
    """Remove a completed task's id from the Redis taskset that tracks it.

    Tasksets are what let us follow the progress of document set and user
    group syncs. This handler runs after any task completes (both success
    and failure). It does not fire for a task that failed and is going to be
    retried, nor when a worker with acks_late=False crashes (which all of
    our long running workers are).
    """
    if not task:
        return

    task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")

    if state not in READY_STATES or not task_id:
        return

    r = get_redis_client()

    # The credential-pair taskset has a single key, with no per-object id.
    if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
        r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
        return

    # (task id prefix, helper class) pairs, tried in this order. The helper
    # class parses the object id out of the task id and exposes the taskset
    # key for that object.
    id_scoped_tasksets = (
        (RedisDocumentSet.PREFIX, RedisDocumentSet),
        (RedisUserGroup.PREFIX, RedisUserGroup),
        (RedisConnectorDeletion.PREFIX, RedisConnectorDeletion),
        (RedisConnectorPruning.SUBTASK_PREFIX, RedisConnectorPruning),
    )
    for prefix, redis_cls in id_scoped_tasksets:
        if not task_id.startswith(prefix):
            continue

        object_id = redis_cls.get_id_from_task_id(task_id)
        if object_id is not None:
            r.srem(redis_cls(int(object_id)).taskset_key, task_id)
        # NOTE(review): nesting reconstructed from an unindented rendering;
        # a matching prefix ends handling even when no id could be parsed.
        return
|
|
||||||
|
|
||||||
|
|
||||||
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
    """First signal sent on celery worker startup.

    Switches multiprocessing to the "spawn" start method.
    """
    # fork is unsafe, set to spawn
    multiprocessing.set_start_method("spawn")
|
|
||||||
|
|
||||||
|
|
||||||
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
    """Handle Celery's ``beat_init`` signal: name and initialize the SQL engine
    with a pool of 2 connections and no overflow."""
    SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
    SqlEngine.init_engine(pool_size=2, max_overflow=0)
|
|
||||||
|
|
||||||
|
|
||||||
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
    """Handle Celery's ``worker_init`` signal for every worker type.

    Steps, in order:
      1. Configure the SQL engine from the worker's hostname prefix
         (``light``, ``heavy``, ``indexing``; anything else gets the primary
         worker settings). Indexing workers also run the index-swap check
         and may warm up the embedding model.
      2. Poll Redis until it answers a ping.
      3. Non-primary workers wait until the primary worker's lock key exists
         in Redis, then return.
      4. The primary worker takes the primary worker lock, stores it on
         ``sender.primary_worker_lock`` and deletes the beat locks, taskset,
         fence and generator keys left in Redis.

    Raises:
        WorkerShutdown: if Redis or the primary worker is not ready within
            ``WAIT_LIMIT`` seconds, or the primary lock cannot be acquired.
    """
    logger.info("worker_init signal received.")
    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

    # decide some initial startup settings based on the celery worker's hostname
    # (set at the command line)
    hostname = sender.hostname
    if hostname.startswith("light"):
        SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
        SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
    elif hostname.startswith("heavy"):
        SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
        SqlEngine.init_engine(pool_size=8, max_overflow=0)
    elif hostname.startswith("indexing"):
        SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
        SqlEngine.init_engine(pool_size=8, max_overflow=0)

        # TODO: why is this necessary for the indexer to do?
        # NOTE(review): no definition of `tenant_id` is visible in this
        # function or in the module-level code above it; unless it is defined
        # elsewhere in the module this line raises NameError for indexing
        # workers. Confirm which tenant is intended before relying on it.
        with get_session_with_tenant(tenant_id) as db_session:
            check_index_swap(db_session=db_session)
            search_settings = get_current_search_settings(db_session)

            # So that the first time users aren't surprised by really slow speed of first
            # batch of documents indexed

            if search_settings.provider_type is None:
                logger.notice("Running a first inference to warm up embedding model")
                embedding_model = EmbeddingModel.from_db_model(
                    search_settings=search_settings,
                    server_host=INDEXING_MODEL_SERVER_HOST,
                    server_port=MODEL_SERVER_PORT,
                )

                warm_up_bi_encoder(
                    embedding_model=embedding_model,
                )
                logger.notice("First inference complete.")
    else:
        SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
        SqlEngine.init_engine(pool_size=8, max_overflow=0)

    r = get_redis_client()

    WAIT_INTERVAL = 5
    WAIT_LIMIT = 60

    time_start = time.monotonic()
    logger.info("Redis: Readiness check starting.")
    while True:
        try:
            if r.ping():
                break
        except Exception:
            # Best effort: a failed ping is retried until WAIT_LIMIT elapses.
            pass

        time_elapsed = time.monotonic() - time_start
        logger.info(
            f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
        )
        if time_elapsed > WAIT_LIMIT:
            msg = (
                f"Redis: Readiness check did not succeed within the timeout "
                f"({WAIT_LIMIT} seconds). Exiting..."
            )
            logger.error(msg)
            raise WorkerShutdown(msg)

        time.sleep(WAIT_INTERVAL)

    logger.info("Redis: Readiness check succeeded. Continuing...")

    if not celery_is_worker_primary(sender):
        logger.info("Running as a secondary celery worker.")
        logger.info("Waiting for primary worker to be ready...")
        time_start = time.monotonic()
        while True:
            if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
                break

            # (a discarded, no-op time.monotonic() call used to sit here)
            time_elapsed = time.monotonic() - time_start
            logger.info(
                f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
            )
            if time_elapsed > WAIT_LIMIT:
                msg = (
                    f"Primary worker was not ready within the timeout. "
                    f"({WAIT_LIMIT} seconds). Exiting..."
                )
                logger.error(msg)
                raise WorkerShutdown(msg)

            time.sleep(WAIT_INTERVAL)

        logger.info("Wait for primary worker completed successfully. Continuing...")
        return

    logger.info("Running as the primary celery worker.")

    # This is singleton work that should be done on startup exactly once
    # by the primary worker
    r = get_redis_client()

    # For the moment, we're assuming that we are the only primary worker
    # that should be running.
    # TODO: maybe check for or clean up another zombie primary worker if we detect it
    r.delete(DanswerRedisLocks.PRIMARY_WORKER)

    # this process wide lock is taken to help other workers start up in order.
    # it is planned to use this lock to enforce singleton behavior on the primary
    # worker, since the primary worker does redis cleanup on startup, but this isn't
    # implemented yet.
    lock = r.lock(
        DanswerRedisLocks.PRIMARY_WORKER,
        timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
    )

    logger.info("Primary worker lock: Acquire starting.")
    acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
    if acquired:
        logger.info("Primary worker lock: Acquire succeeded.")
    else:
        logger.error("Primary worker lock: Acquire failed!")
        raise WorkerShutdown("Primary worker lock could not be acquired!")

    sender.primary_worker_lock = lock

    # As currently designed, when this worker starts as "primary", we reinitialize redis
    # to a clean state (for our purposes, anyway)
    r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
    r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)

    r.delete(RedisConnectorCredentialPair.get_taskset_key())
    r.delete(RedisConnectorCredentialPair.get_fence_key())

    # Key prefixes whose matching keys are scanned and deleted one by one,
    # in this order.
    stale_key_prefixes = (
        RedisDocumentSet.TASKSET_PREFIX,
        RedisDocumentSet.FENCE_PREFIX,
        RedisUserGroup.TASKSET_PREFIX,
        RedisUserGroup.FENCE_PREFIX,
        RedisConnectorDeletion.TASKSET_PREFIX,
        RedisConnectorDeletion.FENCE_PREFIX,
        RedisConnectorPruning.TASKSET_PREFIX,
        RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX,
        RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX,
        RedisConnectorPruning.FENCE_PREFIX,
        RedisConnectorIndexing.TASKSET_PREFIX,
        RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX,
        RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX,
        RedisConnectorIndexing.FENCE_PREFIX,
    )
    for key_prefix in stale_key_prefixes:
        for key in r.scan_iter(key_prefix + "*"):
            r.delete(key)
|
|
||||||
|
|
||||||
|
|
||||||
# @worker_process_init.connect
|
|
||||||
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
|
|
||||||
# """This only runs inside child processes when the worker is in pool=prefork mode.
|
|
||||||
# This may be technically unnecessary since we're finding prefork pools to be
|
|
||||||
# unstable and currently aren't planning on using them."""
|
|
||||||
# logger.info("worker_process_init signal received.")
|
|
||||||
# SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
|
|
||||||
# SqlEngine.init_engine(pool_size=5, max_overflow=0)
|
|
||||||
|
|
||||||
# # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error
|
|
||||||
# SqlEngine.get_engine().dispose(close=False)
|
|
||||||
|
|
||||||
|
|
||||||
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    """Log that the ``worker_ready`` signal was delivered to this handler.

    Args:
        sender: Object passed by the signal; not used here.
        **kwargs: Extra signal arguments; not used here.
    """
    ready_message = "worker_ready signal received."
    task_logger.info(ready_message)
|
|
||||||
|
|
||||||
|
|
||||||
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    """Release the primary worker lock when a primary worker shuts down.

    Does nothing for workers that are not primary, or when the worker holds
    no lock object.

    Args:
        sender: The worker delivering the ``worker_shutdown`` signal.
        **kwargs: Extra signal arguments; not used here.
    """
    if not celery_is_worker_primary(sender):
        return

    # Read the attribute defensively: nothing in this handler guarantees that
    # `primary_worker_lock` was ever assigned on the worker, and a plain
    # attribute access would raise AttributeError during shutdown.
    lock = getattr(sender, "primary_worker_lock", None)
    if not lock:
        return

    logger.info("Releasing primary worker lock.")
    # Only release a lock this worker still owns.
    if lock.owned():
        lock.release()
        sender.primary_worker_lock = None
|
|
||||||
|
|
||||||
|
|
||||||
class CeleryTaskPlainFormatter(PlainFormatter):
    """Plain formatter that tags log lines with the currently running task.

    While a task is executing, the record gains ``task_id`` and ``task_name``
    attributes and the rendered message is prefixed with ``[name(id)]``.
    """

    def format(self, record: logging.LogRecord) -> str:
        """Format ``record``, prefixing the message with the current task.

        Args:
            record: The log record to render.

        Returns:
            The formatted log line.
        """
        task = current_task
        if not (task and task.request):
            return super().format(record)

        record.__dict__.update(task_id=task.request.id, task_name=task.name)

        # The same LogRecord instance is passed to every handler attached to
        # a logger. Leaving the prefix on record.msg would make each later
        # task-aware formatter stack another prefix on top, so the original
        # message is restored once this formatter has rendered the line.
        original_msg = record.msg
        record.msg = f"[{task.name}({task.request.id})] {original_msg}"
        try:
            return super().format(record)
        finally:
            record.msg = original_msg
|
|
||||||
|
|
||||||
|
|
||||||
class CeleryTaskColoredFormatter(ColoredFormatter):
    """Colored formatter that tags log lines with the currently running task.

    While a task is executing, the record gains ``task_id`` and ``task_name``
    attributes and the rendered message is prefixed with ``[name(id)]``.
    """

    def format(self, record: logging.LogRecord) -> str:
        """Format ``record``, prefixing the message with the current task.

        Args:
            record: The log record to render.

        Returns:
            The formatted log line.
        """
        task = current_task
        if not (task and task.request):
            return super().format(record)

        record.__dict__.update(task_id=task.request.id, task_name=task.name)

        # The same LogRecord instance is passed to every handler attached to
        # a logger. Leaving the prefix on record.msg would make each later
        # task-aware formatter stack another prefix on top, so the original
        # message is restored once this formatter has rendered the line.
        original_msg = record.msg
        record.msg = f"[{task.name}({task.request.id})] {original_msg}"
        try:
            return super().format(record)
        finally:
            record.msg = original_msg
|
|
||||||
|
|
||||||
|
|
||||||
@signals.setup_logging.connect
def on_setup_logging(
    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
    """Attach handlers and formatters to the root logger and the task logger.

    Each logger receives a stream handler and, when ``logfile`` is truthy, an
    additional file handler writing to ``logfile``. Both loggers are set to
    ``loglevel``, and propagation is switched off on the task logger.

    Args:
        loglevel: Level applied to both loggers.
        logfile: Path for the optional file handlers; falsy disables them.
        format: Accepted from the signal but currently unused.
        colorize: Accepted from the signal but currently unused.
        **kwargs: Extra signal arguments; not used here.
    """
    # TODO: could unhardcode format and colorize and accept these as options from
    # celery's config
    line_format = "%(asctime)s %(filename)30s %(lineno)4s: %(message)s"
    date_format = "%m/%d/%Y %I:%M:%S %p"

    # (logger, formatter class for the stream handler, formatter class for the file handler)
    # The root logger is reformatted first, then celery's task logger.
    targets = (
        (logging.getLogger(), ColoredFormatter, PlainFormatter),
        (task_logger, CeleryTaskColoredFormatter, CeleryTaskPlainFormatter),
    )

    for target, stream_formatter_cls, file_formatter_cls in targets:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(
            stream_formatter_cls(line_format, datefmt=date_format)
        )
        target.addHandler(stream_handler)

        if logfile:
            file_handler = logging.FileHandler(logfile)
            file_handler.setFormatter(
                file_formatter_cls(line_format, datefmt=date_format)
            )
            target.addHandler(file_handler)

        target.setLevel(loglevel)

    task_logger.propagate = False
|
|
||||||
|
|
||||||
|
|
||||||
class HubPeriodicTask(bootsteps.StartStopStep):
    """Regularly reacquires the primary worker lock outside of the task queue.

    Use the task_logger in this class to avoid double logging.

    This cannot be done inside a regular beat task because it must run on schedule and
    a queue of existing work would starve the task from running.
    """

    # it's unclear to me whether using the hub's timer or the bootstep timer is better
    requires = {"celery.worker.components:Hub"}

    def __init__(self, worker: Any, **kwargs: Any) -> None:
        """Set the refresh interval and clear the timer handle."""
        self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8  # Interval in seconds
        self.task_tref = None

    def start(self, worker: Any) -> None:
        """Schedule the periodic lock refresh on the hub of a primary worker."""
        if not celery_is_worker_primary(worker):
            return

        # Access the worker's event loop (hub)
        hub = worker.consumer.controller.hub

        # Schedule the periodic task
        self.task_tref = hub.call_repeatedly(
            self.interval, self.run_periodic_task, worker
        )
        task_logger.info("Scheduled periodic task with hub.")

    def run_periodic_task(self, worker: Any) -> None:
        """Extend the primary worker lock, or acquire it again if it was lost.

        Any exception raised along the way is logged and swallowed here.
        """
        try:
            # Check for the attribute and for a falsy value in one step. The
            # attribute must not be dereferenced before its existence is
            # checked, otherwise a missing attribute raises AttributeError
            # instead of returning quietly.
            lock: redis.lock.Lock | None = getattr(
                worker, "primary_worker_lock", None
            )
            if not lock:
                return

            r = get_redis_client()

            if lock.owned():
                task_logger.debug("Reacquiring primary worker lock.")
                lock.reacquire()
            else:
                task_logger.warning(
                    "Full acquisition of primary worker lock. "
                    "Reasons could be computer sleep or a clock change."
                )
                lock = r.lock(
                    DanswerRedisLocks.PRIMARY_WORKER,
                    timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
                )

                task_logger.info("Primary worker lock: Acquire starting.")
                acquired = lock.acquire(
                    blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
                )
                if acquired:
                    task_logger.info("Primary worker lock: Acquire succeeded.")
                else:
                    task_logger.error("Primary worker lock: Acquire failed!")
                    raise TimeoutError("Primary worker lock could not be acquired!")

                # Only reached after a successful acquire.
                worker.primary_worker_lock = lock
        except Exception:
            task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.")

    def stop(self, worker: Any) -> None:
        """Cancel the scheduled refresh, if one was scheduled."""
        # Cancel the scheduled task when the worker stops
        if self.task_tref:
            self.task_tref.cancel()
            task_logger.info("Canceled periodic task with hub.")
|
|
||||||
|
|
||||||
|
|
||||||
# Register the lock-refresh bootstep with the worker blueprint.
celery_app.steps["worker"].add(HubPeriodicTask)

# Packages searched for celery task modules.
celery_app.autodiscover_tasks(
    [
        "danswer.background.celery.tasks.connector_deletion",
        "danswer.background.celery.tasks.indexing",
        "danswer.background.celery.tasks.periodic",
        "danswer.background.celery.tasks.pruning",
        "danswer.background.celery.tasks.shared",
        "danswer.background.celery.tasks.vespa",
    ]
)

#####
# Celery Beat (Periodic Tasks) Settings
#####

tenant_ids = get_all_tenant_ids()

# One template per periodic task; a copy is scheduled for every tenant below.
tasks_to_schedule = [
    dict(
        name="check-for-vespa-sync",
        task="check_for_vespa_sync_task",
        schedule=timedelta(seconds=5),
        options={"priority": DanswerCeleryPriority.HIGH},
    ),
    dict(
        name="check-for-connector-deletion",
        task="check_for_connector_deletion_task",
        schedule=timedelta(seconds=60),
        options={"priority": DanswerCeleryPriority.HIGH},
    ),
    dict(
        name="check-for-indexing",
        task="check_for_indexing",
        schedule=timedelta(seconds=10),
        options={"priority": DanswerCeleryPriority.HIGH},
    ),
    dict(
        name="check-for-prune",
        task="check_for_pruning",
        schedule=timedelta(seconds=10),
        options={"priority": DanswerCeleryPriority.HIGH},
    ),
    dict(
        name="kombu-message-cleanup",
        task="kombu_message_cleanup_task",
        schedule=timedelta(seconds=3600),
        options={"priority": DanswerCeleryPriority.LOWEST},
    ),
    dict(
        name="monitor-vespa-sync",
        task="monitor_vespa_sync",
        schedule=timedelta(seconds=5),
        options={"priority": DanswerCeleryPriority.HIGH},
    ),
]

# Build the celery beat schedule dynamically
beat_schedule = {}

for tenant_id in tenant_ids:
    for task in tasks_to_schedule:
        # Unique name for each scheduled task
        task_name = "-".join((task["name"], str(tenant_id)))
        beat_schedule[task_name] = dict(
            task=task["task"],
            schedule=task["schedule"],
            options=task["options"],
            args=(tenant_id,),  # Must pass tenant_id as an argument
        )

# Include any existing beat schedules; on a name clash the existing entry wins
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)

# Update the Celery app configuration once
celery_app.conf.beat_schedule = beat_schedule
|
|
@ -10,7 +10,7 @@ from celery import Celery
|
|||||||
from redis import Redis
|
from redis import Redis
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from danswer.background.celery.celeryconfig import CELERY_SEPARATOR
|
from danswer.background.celery.configs.base import CELERY_SEPARATOR
|
||||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||||
from danswer.configs.constants import DanswerCeleryPriority
|
from danswer.configs.constants import DanswerCeleryPriority
|
||||||
from danswer.configs.constants import DanswerCeleryQueues
|
from danswer.configs.constants import DanswerCeleryQueues
|
||||||
|
@ -3,13 +3,10 @@ from datetime import datetime
|
|||||||
from datetime import timezone
|
from datetime import timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import text
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||||
from danswer.configs.app_configs import MULTI_TENANT
|
|
||||||
from danswer.configs.constants import TENANT_ID_PREFIX
|
|
||||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||||
rate_limit_builder,
|
rate_limit_builder,
|
||||||
)
|
)
|
||||||
@ -19,7 +16,6 @@ from danswer.connectors.interfaces import PollConnector
|
|||||||
from danswer.connectors.interfaces import SlimConnector
|
from danswer.connectors.interfaces import SlimConnector
|
||||||
from danswer.connectors.models import Document
|
from danswer.connectors.models import Document
|
||||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||||
from danswer.db.engine import get_session_with_tenant
|
|
||||||
from danswer.db.enums import TaskStatus
|
from danswer.db.enums import TaskStatus
|
||||||
from danswer.db.models import TaskQueueState
|
from danswer.db.models import TaskQueueState
|
||||||
from danswer.redis.redis_pool import get_redis_client
|
from danswer.redis.redis_pool import get_redis_client
|
||||||
@ -129,33 +125,10 @@ def celery_is_listening_to_queue(worker: Any, name: str) -> bool:
|
|||||||
def celery_is_worker_primary(worker: Any) -> bool:
|
def celery_is_worker_primary(worker: Any) -> bool:
|
||||||
"""There are multiple approaches that could be taken to determine if a celery worker
|
"""There are multiple approaches that could be taken to determine if a celery worker
|
||||||
is 'primary', as defined by us. But the way we do it is to check the hostname set
|
is 'primary', as defined by us. But the way we do it is to check the hostname set
|
||||||
for the celery worker, which can be done either in celeryconfig.py or on the
|
for the celery worker, which can be done on the
|
||||||
command line with '--hostname'."""
|
command line with '--hostname'."""
|
||||||
hostname = worker.hostname
|
hostname = worker.hostname
|
||||||
if hostname.startswith("primary"):
|
if hostname.startswith("primary"):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_all_tenant_ids() -> list[str] | list[None]:
|
|
||||||
if not MULTI_TENANT:
|
|
||||||
return [None]
|
|
||||||
with get_session_with_tenant(tenant_id="public") as session:
|
|
||||||
result = session.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
SELECT schema_name
|
|
||||||
FROM information_schema.schemata
|
|
||||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
|
|
||||||
)
|
|
||||||
)
|
|
||||||
tenant_ids = [row[0] for row in result]
|
|
||||||
|
|
||||||
valid_tenants = [
|
|
||||||
tenant
|
|
||||||
for tenant in tenant_ids
|
|
||||||
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
|
|
||||||
]
|
|
||||||
|
|
||||||
return valid_tenants
|
|
||||||
|
@ -31,21 +31,10 @@ if REDIS_SSL:
|
|||||||
if REDIS_SSL_CA_CERTS:
|
if REDIS_SSL_CA_CERTS:
|
||||||
SSL_QUERY_PARAMS += f"&ssl_ca_certs={REDIS_SSL_CA_CERTS}"
|
SSL_QUERY_PARAMS += f"&ssl_ca_certs={REDIS_SSL_CA_CERTS}"
|
||||||
|
|
||||||
|
# region Broker settings
|
||||||
# example celery_broker_url: "redis://:password@localhost:6379/15"
|
# example celery_broker_url: "redis://:password@localhost:6379/15"
|
||||||
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
|
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
|
||||||
|
|
||||||
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
|
|
||||||
|
|
||||||
# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks
|
|
||||||
# however, prefetching is bad when tasks are lengthy as those tasks
|
|
||||||
# can stall other tasks.
|
|
||||||
worker_prefetch_multiplier = 4
|
|
||||||
|
|
||||||
# Leaving this to the default of True may cause double logging since both our own app
|
|
||||||
# and celery think they are controlling the logger.
|
|
||||||
# TODO: Configure celery's logger entirely manually and set this to False
|
|
||||||
# worker_hijack_root_logger = False
|
|
||||||
|
|
||||||
broker_connection_retry_on_startup = True
|
broker_connection_retry_on_startup = True
|
||||||
broker_pool_limit = CELERY_BROKER_POOL_LIMIT
|
broker_pool_limit = CELERY_BROKER_POOL_LIMIT
|
||||||
|
|
||||||
@ -60,6 +49,7 @@ broker_transport_options = {
|
|||||||
"socket_keepalive": True,
|
"socket_keepalive": True,
|
||||||
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
"socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
|
||||||
}
|
}
|
||||||
|
# endregion
|
||||||
|
|
||||||
# redis backend settings
|
# redis backend settings
|
||||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
|
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
|
||||||
@ -73,10 +63,19 @@ redis_backend_health_check_interval = REDIS_HEALTH_CHECK_INTERVAL
|
|||||||
task_default_priority = DanswerCeleryPriority.MEDIUM
|
task_default_priority = DanswerCeleryPriority.MEDIUM
|
||||||
task_acks_late = True
|
task_acks_late = True
|
||||||
|
|
||||||
|
# region Task result backend settings
|
||||||
# It's possible we don't even need celery's result backend, in which case all of the optimization below
|
# It's possible we don't even need celery's result backend, in which case all of the optimization below
|
||||||
# might be irrelevant
|
# might be irrelevant
|
||||||
|
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
|
||||||
result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
|
result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# Leaving this to the default of True may cause double logging since both our own app
|
||||||
|
# and celery think they are controlling the logger.
|
||||||
|
# TODO: Configure celery's logger entirely manually and set this to False
|
||||||
|
# worker_hijack_root_logger = False
|
||||||
|
|
||||||
|
# region Notes on serialization performance
|
||||||
# Option 0: Defaults (json serializer, no compression)
|
# Option 0: Defaults (json serializer, no compression)
|
||||||
# about 1.5 KB per queued task. 1KB in queue, 400B for result, 100 as a child entry in generator result
|
# about 1.5 KB per queued task. 1KB in queue, 400B for result, 100 as a child entry in generator result
|
||||||
|
|
||||||
@ -102,3 +101,4 @@ result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
|
|||||||
# task_serializer = "pickle-bzip2"
|
# task_serializer = "pickle-bzip2"
|
||||||
# result_serializer = "pickle-bzip2"
|
# result_serializer = "pickle-bzip2"
|
||||||
# accept_content=["pickle", "pickle-bzip2"]
|
# accept_content=["pickle", "pickle-bzip2"]
|
||||||
|
# endregion
|
14
backend/danswer/background/celery/configs/beat.py
Normal file
14
backend/danswer/background/celery/configs/beat.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
|
||||||
|
import danswer.background.celery.configs.base as shared_config
|
||||||
|
|
||||||
|
broker_url = shared_config.broker_url
|
||||||
|
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||||
|
broker_pool_limit = shared_config.broker_pool_limit
|
||||||
|
broker_transport_options = shared_config.broker_transport_options
|
||||||
|
|
||||||
|
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||||
|
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||||
|
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||||
|
|
||||||
|
result_backend = shared_config.result_backend
|
||||||
|
result_expires = shared_config.result_expires # 86400 seconds is the default
|
20
backend/danswer/background/celery/configs/heavy.py
Normal file
20
backend/danswer/background/celery/configs/heavy.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import danswer.background.celery.configs.base as shared_config
|
||||||
|
|
||||||
|
broker_url = shared_config.broker_url
|
||||||
|
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||||
|
broker_pool_limit = shared_config.broker_pool_limit
|
||||||
|
broker_transport_options = shared_config.broker_transport_options
|
||||||
|
|
||||||
|
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||||
|
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||||
|
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||||
|
|
||||||
|
result_backend = shared_config.result_backend
|
||||||
|
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||||
|
|
||||||
|
task_default_priority = shared_config.task_default_priority
|
||||||
|
task_acks_late = shared_config.task_acks_late
|
||||||
|
|
||||||
|
worker_concurrency = 4
|
||||||
|
worker_pool = "threads"
|
||||||
|
worker_prefetch_multiplier = 1
|
21
backend/danswer/background/celery/configs/indexing.py
Normal file
21
backend/danswer/background/celery/configs/indexing.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import danswer.background.celery.configs.base as shared_config
|
||||||
|
from danswer.configs.app_configs import CELERY_WORKER_INDEXING_CONCURRENCY
|
||||||
|
|
||||||
|
broker_url = shared_config.broker_url
|
||||||
|
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||||
|
broker_pool_limit = shared_config.broker_pool_limit
|
||||||
|
broker_transport_options = shared_config.broker_transport_options
|
||||||
|
|
||||||
|
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||||
|
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||||
|
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||||
|
|
||||||
|
result_backend = shared_config.result_backend
|
||||||
|
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||||
|
|
||||||
|
task_default_priority = shared_config.task_default_priority
|
||||||
|
task_acks_late = shared_config.task_acks_late
|
||||||
|
|
||||||
|
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
|
||||||
|
worker_pool = "threads"
|
||||||
|
worker_prefetch_multiplier = 1
|
22
backend/danswer/background/celery/configs/light.py
Normal file
22
backend/danswer/background/celery/configs/light.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
import danswer.background.celery.configs.base as shared_config
|
||||||
|
from danswer.configs.app_configs import CELERY_WORKER_LIGHT_CONCURRENCY
|
||||||
|
from danswer.configs.app_configs import CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER
|
||||||
|
|
||||||
|
broker_url = shared_config.broker_url
|
||||||
|
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||||
|
broker_pool_limit = shared_config.broker_pool_limit
|
||||||
|
broker_transport_options = shared_config.broker_transport_options
|
||||||
|
|
||||||
|
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||||
|
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||||
|
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||||
|
|
||||||
|
result_backend = shared_config.result_backend
|
||||||
|
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||||
|
|
||||||
|
task_default_priority = shared_config.task_default_priority
|
||||||
|
task_acks_late = shared_config.task_acks_late
|
||||||
|
|
||||||
|
worker_concurrency = CELERY_WORKER_LIGHT_CONCURRENCY
|
||||||
|
worker_pool = "threads"
|
||||||
|
worker_prefetch_multiplier = CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER
|
20
backend/danswer/background/celery/configs/primary.py
Normal file
20
backend/danswer/background/celery/configs/primary.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import danswer.background.celery.configs.base as shared_config
|
||||||
|
|
||||||
|
broker_url = shared_config.broker_url
|
||||||
|
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
|
||||||
|
broker_pool_limit = shared_config.broker_pool_limit
|
||||||
|
broker_transport_options = shared_config.broker_transport_options
|
||||||
|
|
||||||
|
redis_socket_keepalive = shared_config.redis_socket_keepalive
|
||||||
|
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
|
||||||
|
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
|
||||||
|
|
||||||
|
result_backend = shared_config.result_backend
|
||||||
|
result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||||
|
|
||||||
|
task_default_priority = shared_config.task_default_priority
|
||||||
|
task_acks_late = shared_config.task_acks_late
|
||||||
|
|
||||||
|
worker_concurrency = 4
|
||||||
|
worker_pool = "threads"
|
||||||
|
worker_prefetch_multiplier = 1
|
@ -1,20 +1,20 @@
|
|||||||
import redis
|
import redis
|
||||||
|
from celery import Celery
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
|
from celery import Task
|
||||||
from celery.exceptions import SoftTimeLimitExceeded
|
from celery.exceptions import SoftTimeLimitExceeded
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
|
||||||
|
|
||||||
from danswer.background.celery.celery_app import celery_app
|
from danswer.background.celery.apps.app_base import task_logger
|
||||||
from danswer.background.celery.celery_app import task_logger
|
|
||||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||||
from danswer.configs.constants import DanswerRedisLocks
|
from danswer.configs.constants import DanswerRedisLocks
|
||||||
|
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||||
from danswer.db.engine import get_session_with_tenant
|
from danswer.db.engine import get_session_with_tenant
|
||||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||||
from danswer.db.models import ConnectorCredentialPair
|
|
||||||
from danswer.redis.redis_pool import get_redis_client
|
from danswer.redis.redis_pool import get_redis_client
|
||||||
|
|
||||||
|
|
||||||
@ -22,8 +22,9 @@ from danswer.redis.redis_pool import get_redis_client
|
|||||||
name="check_for_connector_deletion_task",
|
name="check_for_connector_deletion_task",
|
||||||
soft_time_limit=JOB_TIMEOUT,
|
soft_time_limit=JOB_TIMEOUT,
|
||||||
trail=False,
|
trail=False,
|
||||||
|
bind=True,
|
||||||
)
|
)
|
||||||
def check_for_connector_deletion_task(tenant_id: str | None) -> None:
|
def check_for_connector_deletion_task(self: Task, tenant_id: str | None) -> None:
|
||||||
r = get_redis_client()
|
r = get_redis_client()
|
||||||
|
|
||||||
lock_beat = r.lock(
|
lock_beat = r.lock(
|
||||||
@ -36,11 +37,16 @@ def check_for_connector_deletion_task(tenant_id: str | None) -> None:
|
|||||||
if not lock_beat.acquire(blocking=False):
|
if not lock_beat.acquire(blocking=False):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
cc_pair_ids: list[int] = []
|
||||||
with get_session_with_tenant(tenant_id) as db_session:
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
cc_pairs = get_connector_credential_pairs(db_session)
|
cc_pairs = get_connector_credential_pairs(db_session)
|
||||||
for cc_pair in cc_pairs:
|
for cc_pair in cc_pairs:
|
||||||
|
cc_pair_ids.append(cc_pair.id)
|
||||||
|
|
||||||
|
for cc_pair_id in cc_pair_ids:
|
||||||
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
try_generate_document_cc_pair_cleanup_tasks(
|
try_generate_document_cc_pair_cleanup_tasks(
|
||||||
cc_pair, db_session, r, lock_beat, tenant_id
|
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
|
||||||
)
|
)
|
||||||
except SoftTimeLimitExceeded:
|
except SoftTimeLimitExceeded:
|
||||||
task_logger.info(
|
task_logger.info(
|
||||||
@ -54,7 +60,8 @@ def check_for_connector_deletion_task(tenant_id: str | None) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def try_generate_document_cc_pair_cleanup_tasks(
|
def try_generate_document_cc_pair_cleanup_tasks(
|
||||||
cc_pair: ConnectorCredentialPair,
|
app: Celery,
|
||||||
|
cc_pair_id: int,
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
r: Redis,
|
r: Redis,
|
||||||
lock_beat: redis.lock.Lock,
|
lock_beat: redis.lock.Lock,
|
||||||
@ -67,18 +74,17 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
|||||||
|
|
||||||
lock_beat.reacquire()
|
lock_beat.reacquire()
|
||||||
|
|
||||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||||
|
|
||||||
# don't generate sync tasks if tasks are still pending
|
# don't generate sync tasks if tasks are still pending
|
||||||
if r.exists(rcd.fence_key):
|
if r.exists(rcd.fence_key):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# we need to refresh the state of the object inside the fence
|
# we need to load the state of the object inside the fence
|
||||||
# to avoid a race condition with db.commit/fence deletion
|
# to avoid a race condition with db.commit/fence deletion
|
||||||
# at the end of this taskset
|
# at the end of this taskset
|
||||||
try:
|
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||||
db_session.refresh(cc_pair)
|
if not cc_pair:
|
||||||
except ObjectDeletedError:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||||
@ -91,9 +97,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
|||||||
task_logger.info(
|
task_logger.info(
|
||||||
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
|
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
|
||||||
)
|
)
|
||||||
tasks_generated = rcd.generate_tasks(
|
tasks_generated = rcd.generate_tasks(app, db_session, r, lock_beat, tenant_id)
|
||||||
celery_app, db_session, r, lock_beat, tenant_id
|
|
||||||
)
|
|
||||||
if tasks_generated is None:
|
if tasks_generated is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -5,15 +5,18 @@ from time import sleep
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from celery import Celery
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
|
from celery import Task
|
||||||
from celery.exceptions import SoftTimeLimitExceeded
|
from celery.exceptions import SoftTimeLimitExceeded
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from danswer.background.celery.celery_app import celery_app
|
from danswer.background.celery.apps.app_base import task_logger
|
||||||
from danswer.background.celery.celery_app import task_logger
|
|
||||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||||
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
|
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
|
||||||
|
RedisConnectorIndexingFenceData,
|
||||||
|
)
|
||||||
from danswer.background.indexing.job_client import SimpleJobClient
|
from danswer.background.indexing.job_client import SimpleJobClient
|
||||||
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||||
@ -50,8 +53,9 @@ logger = setup_logger()
|
|||||||
@shared_task(
|
@shared_task(
|
||||||
name="check_for_indexing",
|
name="check_for_indexing",
|
||||||
soft_time_limit=300,
|
soft_time_limit=300,
|
||||||
|
bind=True,
|
||||||
)
|
)
|
||||||
def check_for_indexing(tenant_id: str | None) -> int | None:
|
def check_for_indexing(self: Task, tenant_id: str | None) -> int | None:
|
||||||
tasks_created = 0
|
tasks_created = 0
|
||||||
|
|
||||||
r = get_redis_client()
|
r = get_redis_client()
|
||||||
@ -66,26 +70,37 @@ def check_for_indexing(tenant_id: str | None) -> int | None:
|
|||||||
if not lock_beat.acquire(blocking=False):
|
if not lock_beat.acquire(blocking=False):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
cc_pair_ids: list[int] = []
|
||||||
with get_session_with_tenant(tenant_id) as db_session:
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
# Get the primary search settings
|
|
||||||
primary_search_settings = get_current_search_settings(db_session)
|
|
||||||
search_settings = [primary_search_settings]
|
|
||||||
|
|
||||||
# Check for secondary search settings
|
|
||||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
|
||||||
if secondary_search_settings is not None:
|
|
||||||
# If secondary settings exist, add them to the list
|
|
||||||
search_settings.append(secondary_search_settings)
|
|
||||||
|
|
||||||
cc_pairs = fetch_connector_credential_pairs(db_session)
|
cc_pairs = fetch_connector_credential_pairs(db_session)
|
||||||
for cc_pair in cc_pairs:
|
for cc_pair_entry in cc_pairs:
|
||||||
|
cc_pair_ids.append(cc_pair_entry.id)
|
||||||
|
|
||||||
|
for cc_pair_id in cc_pair_ids:
|
||||||
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
|
# Get the primary search settings
|
||||||
|
primary_search_settings = get_current_search_settings(db_session)
|
||||||
|
search_settings = [primary_search_settings]
|
||||||
|
|
||||||
|
# Check for secondary search settings
|
||||||
|
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||||
|
if secondary_search_settings is not None:
|
||||||
|
# If secondary settings exist, add them to the list
|
||||||
|
search_settings.append(secondary_search_settings)
|
||||||
|
|
||||||
for search_settings_instance in search_settings:
|
for search_settings_instance in search_settings:
|
||||||
rci = RedisConnectorIndexing(
|
rci = RedisConnectorIndexing(
|
||||||
cc_pair.id, search_settings_instance.id
|
cc_pair_id, search_settings_instance.id
|
||||||
)
|
)
|
||||||
if r.exists(rci.fence_key):
|
if r.exists(rci.fence_key):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
cc_pair = get_connector_credential_pair_from_id(
|
||||||
|
cc_pair_id, db_session
|
||||||
|
)
|
||||||
|
if not cc_pair:
|
||||||
|
continue
|
||||||
|
|
||||||
last_attempt = get_last_attempt_for_cc_pair(
|
last_attempt = get_last_attempt_for_cc_pair(
|
||||||
cc_pair.id, search_settings_instance.id, db_session
|
cc_pair.id, search_settings_instance.id, db_session
|
||||||
)
|
)
|
||||||
@ -101,6 +116,7 @@ def check_for_indexing(tenant_id: str | None) -> int | None:
|
|||||||
# using a task queue and only allowing one task per cc_pair/search_setting
|
# using a task queue and only allowing one task per cc_pair/search_setting
|
||||||
# prevents us from starving out certain attempts
|
# prevents us from starving out certain attempts
|
||||||
attempt_id = try_creating_indexing_task(
|
attempt_id = try_creating_indexing_task(
|
||||||
|
self.app,
|
||||||
cc_pair,
|
cc_pair,
|
||||||
search_settings_instance,
|
search_settings_instance,
|
||||||
False,
|
False,
|
||||||
@ -210,6 +226,7 @@ def _should_index(
|
|||||||
|
|
||||||
|
|
||||||
def try_creating_indexing_task(
|
def try_creating_indexing_task(
|
||||||
|
celery_app: Celery,
|
||||||
cc_pair: ConnectorCredentialPair,
|
cc_pair: ConnectorCredentialPair,
|
||||||
search_settings: SearchSettings,
|
search_settings: SearchSettings,
|
||||||
reindex: bool,
|
reindex: bool,
|
||||||
|
@ -11,7 +11,7 @@ from sqlalchemy import inspect
|
|||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from danswer.background.celery.celery_app import task_logger
|
from danswer.background.celery.apps.app_base import task_logger
|
||||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||||
from danswer.configs.constants import PostgresAdvisoryLocks
|
from danswer.configs.constants import PostgresAdvisoryLocks
|
||||||
from danswer.db.engine import get_session_with_tenant
|
from danswer.db.engine import get_session_with_tenant
|
||||||
|
@ -3,13 +3,14 @@ from datetime import timedelta
|
|||||||
from datetime import timezone
|
from datetime import timezone
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from celery import Celery
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
|
from celery import Task
|
||||||
from celery.exceptions import SoftTimeLimitExceeded
|
from celery.exceptions import SoftTimeLimitExceeded
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from danswer.background.celery.celery_app import celery_app
|
from danswer.background.celery.apps.app_base import task_logger
|
||||||
from danswer.background.celery.celery_app import task_logger
|
|
||||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||||
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
|
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||||
@ -23,6 +24,7 @@ from danswer.configs.constants import DanswerRedisLocks
|
|||||||
from danswer.connectors.factory import instantiate_connector
|
from danswer.connectors.factory import instantiate_connector
|
||||||
from danswer.connectors.models import InputType
|
from danswer.connectors.models import InputType
|
||||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||||
|
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||||
from danswer.db.engine import get_session_with_tenant
|
from danswer.db.engine import get_session_with_tenant
|
||||||
@ -37,8 +39,9 @@ logger = setup_logger()
|
|||||||
@shared_task(
|
@shared_task(
|
||||||
name="check_for_pruning",
|
name="check_for_pruning",
|
||||||
soft_time_limit=JOB_TIMEOUT,
|
soft_time_limit=JOB_TIMEOUT,
|
||||||
|
bind=True,
|
||||||
)
|
)
|
||||||
def check_for_pruning(tenant_id: str | None) -> None:
|
def check_for_pruning(self: Task, tenant_id: str | None) -> None:
|
||||||
r = get_redis_client()
|
r = get_redis_client()
|
||||||
|
|
||||||
lock_beat = r.lock(
|
lock_beat = r.lock(
|
||||||
@ -51,15 +54,24 @@ def check_for_pruning(tenant_id: str | None) -> None:
|
|||||||
if not lock_beat.acquire(blocking=False):
|
if not lock_beat.acquire(blocking=False):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
cc_pair_ids: list[int] = []
|
||||||
with get_session_with_tenant(tenant_id) as db_session:
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
cc_pairs = get_connector_credential_pairs(db_session)
|
cc_pairs = get_connector_credential_pairs(db_session)
|
||||||
for cc_pair in cc_pairs:
|
for cc_pair_entry in cc_pairs:
|
||||||
lock_beat.reacquire()
|
cc_pair_ids.append(cc_pair_entry.id)
|
||||||
|
|
||||||
|
for cc_pair_id in cc_pair_ids:
|
||||||
|
lock_beat.reacquire()
|
||||||
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
|
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||||
|
if not cc_pair:
|
||||||
|
continue
|
||||||
|
|
||||||
if not is_pruning_due(cc_pair, db_session, r):
|
if not is_pruning_due(cc_pair, db_session, r):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tasks_created = try_creating_prune_generator_task(
|
tasks_created = try_creating_prune_generator_task(
|
||||||
cc_pair, db_session, r, tenant_id
|
self.app, cc_pair, db_session, r, tenant_id
|
||||||
)
|
)
|
||||||
if not tasks_created:
|
if not tasks_created:
|
||||||
continue
|
continue
|
||||||
@ -118,6 +130,7 @@ def is_pruning_due(
|
|||||||
|
|
||||||
|
|
||||||
def try_creating_prune_generator_task(
|
def try_creating_prune_generator_task(
|
||||||
|
celery_app: Celery,
|
||||||
cc_pair: ConnectorCredentialPair,
|
cc_pair: ConnectorCredentialPair,
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
r: Redis,
|
r: Redis,
|
||||||
@ -196,9 +209,14 @@ def try_creating_prune_generator_task(
|
|||||||
soft_time_limit=JOB_TIMEOUT,
|
soft_time_limit=JOB_TIMEOUT,
|
||||||
track_started=True,
|
track_started=True,
|
||||||
trail=False,
|
trail=False,
|
||||||
|
bind=True,
|
||||||
)
|
)
|
||||||
def connector_pruning_generator_task(
|
def connector_pruning_generator_task(
|
||||||
cc_pair_id: int, connector_id: int, credential_id: int, tenant_id: str | None
|
self: Task,
|
||||||
|
cc_pair_id: int,
|
||||||
|
connector_id: int,
|
||||||
|
credential_id: int,
|
||||||
|
tenant_id: str | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
|
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
|
||||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||||
@ -278,7 +296,7 @@ def connector_pruning_generator_task(
|
|||||||
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
|
f"RedisConnectorPruning.generate_tasks starting. cc_pair_id={cc_pair.id}"
|
||||||
)
|
)
|
||||||
tasks_generated = rcp.generate_tasks(
|
tasks_generated = rcp.generate_tasks(
|
||||||
celery_app, db_session, r, None, tenant_id
|
self.app, db_session, r, None, tenant_id
|
||||||
)
|
)
|
||||||
if tasks_generated is None:
|
if tasks_generated is None:
|
||||||
return None
|
return None
|
||||||
|
@ -0,0 +1,10 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class RedisConnectorIndexingFenceData(BaseModel):
|
||||||
|
index_attempt_id: int | None
|
||||||
|
started: datetime | None
|
||||||
|
submitted: datetime
|
||||||
|
celery_task_id: str | None
|
@ -6,7 +6,7 @@ from celery.exceptions import SoftTimeLimitExceeded
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from danswer.access.access import get_access_for_document
|
from danswer.access.access import get_access_for_document
|
||||||
from danswer.background.celery.celery_app import task_logger
|
from danswer.background.celery.apps.app_base import task_logger
|
||||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||||
from danswer.db.document import delete_documents_complete__no_commit
|
from danswer.db.document import delete_documents_complete__no_commit
|
||||||
from danswer.db.document import get_document
|
from danswer.db.document import get_document
|
||||||
|
@ -5,6 +5,7 @@ from http import HTTPStatus
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
|
from celery import Celery
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from celery import Task
|
from celery import Task
|
||||||
from celery.exceptions import SoftTimeLimitExceeded
|
from celery.exceptions import SoftTimeLimitExceeded
|
||||||
@ -14,8 +15,7 @@ from redis import Redis
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from danswer.access.access import get_access_for_document
|
from danswer.access.access import get_access_for_document
|
||||||
from danswer.background.celery.celery_app import celery_app
|
from danswer.background.celery.apps.app_base import task_logger
|
||||||
from danswer.background.celery.celery_app import task_logger
|
|
||||||
from danswer.background.celery.celery_redis import celery_get_queue_length
|
from danswer.background.celery.celery_redis import celery_get_queue_length
|
||||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||||
@ -23,7 +23,9 @@ from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
|||||||
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
from danswer.background.celery.celery_redis import RedisConnectorPruning
|
||||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||||
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
|
from danswer.background.celery.tasks.shared.RedisConnectorIndexingFenceData import (
|
||||||
|
RedisConnectorIndexingFenceData,
|
||||||
|
)
|
||||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||||
from danswer.configs.constants import DanswerCeleryQueues
|
from danswer.configs.constants import DanswerCeleryQueues
|
||||||
@ -54,7 +56,6 @@ from danswer.db.index_attempt import get_index_attempt
|
|||||||
from danswer.db.index_attempt import mark_attempt_failed
|
from danswer.db.index_attempt import mark_attempt_failed
|
||||||
from danswer.db.models import DocumentSet
|
from danswer.db.models import DocumentSet
|
||||||
from danswer.db.models import IndexAttempt
|
from danswer.db.models import IndexAttempt
|
||||||
from danswer.db.models import UserGroup
|
|
||||||
from danswer.document_index.document_index_utils import get_both_index_names
|
from danswer.document_index.document_index_utils import get_both_index_names
|
||||||
from danswer.document_index.factory import get_default_document_index
|
from danswer.document_index.factory import get_default_document_index
|
||||||
from danswer.document_index.interfaces import VespaDocumentFields
|
from danswer.document_index.interfaces import VespaDocumentFields
|
||||||
@ -73,8 +74,9 @@ from danswer.utils.variable_functionality import noop_fallback
|
|||||||
name="check_for_vespa_sync_task",
|
name="check_for_vespa_sync_task",
|
||||||
soft_time_limit=JOB_TIMEOUT,
|
soft_time_limit=JOB_TIMEOUT,
|
||||||
trail=False,
|
trail=False,
|
||||||
|
bind=True,
|
||||||
)
|
)
|
||||||
def check_for_vespa_sync_task(tenant_id: str | None) -> None:
|
def check_for_vespa_sync_task(self: Task, tenant_id: str | None) -> None:
|
||||||
"""Runs periodically to check if any document needs syncing.
|
"""Runs periodically to check if any document needs syncing.
|
||||||
Generates sets of tasks for Celery if syncing is needed."""
|
Generates sets of tasks for Celery if syncing is needed."""
|
||||||
|
|
||||||
@ -91,35 +93,53 @@ def check_for_vespa_sync_task(tenant_id: str | None) -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
with get_session_with_tenant(tenant_id) as db_session:
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
try_generate_stale_document_sync_tasks(db_session, r, lock_beat, tenant_id)
|
try_generate_stale_document_sync_tasks(
|
||||||
|
self.app, db_session, r, lock_beat, tenant_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# region document set scan
|
||||||
|
document_set_ids: list[int] = []
|
||||||
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
# check if any document sets are not synced
|
# check if any document sets are not synced
|
||||||
document_set_info = fetch_document_sets(
|
document_set_info = fetch_document_sets(
|
||||||
user_id=None, db_session=db_session, include_outdated=True
|
user_id=None, db_session=db_session, include_outdated=True
|
||||||
)
|
)
|
||||||
|
|
||||||
for document_set, _ in document_set_info:
|
for document_set, _ in document_set_info:
|
||||||
|
document_set_ids.append(document_set.id)
|
||||||
|
|
||||||
|
for document_set_id in document_set_ids:
|
||||||
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
try_generate_document_set_sync_tasks(
|
try_generate_document_set_sync_tasks(
|
||||||
document_set, db_session, r, lock_beat, tenant_id
|
self.app, document_set_id, db_session, r, lock_beat, tenant_id
|
||||||
)
|
)
|
||||||
|
# endregion
|
||||||
|
|
||||||
# check if any user groups are not synced
|
# check if any user groups are not synced
|
||||||
if global_version.is_ee_version():
|
if global_version.is_ee_version():
|
||||||
try:
|
try:
|
||||||
fetch_user_groups = fetch_versioned_implementation(
|
fetch_user_groups = fetch_versioned_implementation(
|
||||||
"danswer.db.user_group", "fetch_user_groups"
|
"danswer.db.user_group", "fetch_user_groups"
|
||||||
)
|
)
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
# Always exceptions on the MIT version, which is expected
|
||||||
|
# We shouldn't actually get here if the ee version check works
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
usergroup_ids: list[int] = []
|
||||||
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
user_groups = fetch_user_groups(
|
user_groups = fetch_user_groups(
|
||||||
db_session=db_session, only_up_to_date=False
|
db_session=db_session, only_up_to_date=False
|
||||||
)
|
)
|
||||||
|
|
||||||
for usergroup in user_groups:
|
for usergroup in user_groups:
|
||||||
|
usergroup_ids.append(usergroup.id)
|
||||||
|
|
||||||
|
for usergroup_id in usergroup_ids:
|
||||||
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
try_generate_user_group_sync_tasks(
|
try_generate_user_group_sync_tasks(
|
||||||
usergroup, db_session, r, lock_beat, tenant_id
|
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
|
||||||
)
|
)
|
||||||
except ModuleNotFoundError:
|
|
||||||
# Always exceptions on the MIT version, which is expected
|
|
||||||
# We shouldn't actually get here if the ee version check works
|
|
||||||
pass
|
|
||||||
|
|
||||||
except SoftTimeLimitExceeded:
|
except SoftTimeLimitExceeded:
|
||||||
task_logger.info(
|
task_logger.info(
|
||||||
@ -133,7 +153,11 @@ def check_for_vespa_sync_task(tenant_id: str | None) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def try_generate_stale_document_sync_tasks(
|
def try_generate_stale_document_sync_tasks(
|
||||||
db_session: Session, r: Redis, lock_beat: redis.lock.Lock, tenant_id: str | None
|
celery_app: Celery,
|
||||||
|
db_session: Session,
|
||||||
|
r: Redis,
|
||||||
|
lock_beat: redis.lock.Lock,
|
||||||
|
tenant_id: str | None,
|
||||||
) -> int | None:
|
) -> int | None:
|
||||||
# the fence is up, do nothing
|
# the fence is up, do nothing
|
||||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||||
@ -184,7 +208,8 @@ def try_generate_stale_document_sync_tasks(
|
|||||||
|
|
||||||
|
|
||||||
def try_generate_document_set_sync_tasks(
|
def try_generate_document_set_sync_tasks(
|
||||||
document_set: DocumentSet,
|
celery_app: Celery,
|
||||||
|
document_set_id: int,
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
r: Redis,
|
r: Redis,
|
||||||
lock_beat: redis.lock.Lock,
|
lock_beat: redis.lock.Lock,
|
||||||
@ -192,7 +217,7 @@ def try_generate_document_set_sync_tasks(
|
|||||||
) -> int | None:
|
) -> int | None:
|
||||||
lock_beat.reacquire()
|
lock_beat.reacquire()
|
||||||
|
|
||||||
rds = RedisDocumentSet(document_set.id)
|
rds = RedisDocumentSet(document_set_id)
|
||||||
|
|
||||||
# don't generate document set sync tasks if tasks are still pending
|
# don't generate document set sync tasks if tasks are still pending
|
||||||
if r.exists(rds.fence_key):
|
if r.exists(rds.fence_key):
|
||||||
@ -200,7 +225,10 @@ def try_generate_document_set_sync_tasks(
|
|||||||
|
|
||||||
# don't generate sync tasks if we're up to date
|
# don't generate sync tasks if we're up to date
|
||||||
# race condition with the monitor/cleanup function if we use a cached result!
|
# race condition with the monitor/cleanup function if we use a cached result!
|
||||||
db_session.refresh(document_set)
|
document_set = get_document_set_by_id(db_session, document_set_id)
|
||||||
|
if not document_set:
|
||||||
|
return None
|
||||||
|
|
||||||
if document_set.is_up_to_date:
|
if document_set.is_up_to_date:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -235,7 +263,8 @@ def try_generate_document_set_sync_tasks(
|
|||||||
|
|
||||||
|
|
||||||
def try_generate_user_group_sync_tasks(
|
def try_generate_user_group_sync_tasks(
|
||||||
usergroup: UserGroup,
|
celery_app: Celery,
|
||||||
|
usergroup_id: int,
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
r: Redis,
|
r: Redis,
|
||||||
lock_beat: redis.lock.Lock,
|
lock_beat: redis.lock.Lock,
|
||||||
@ -243,14 +272,21 @@ def try_generate_user_group_sync_tasks(
|
|||||||
) -> int | None:
|
) -> int | None:
|
||||||
lock_beat.reacquire()
|
lock_beat.reacquire()
|
||||||
|
|
||||||
rug = RedisUserGroup(usergroup.id)
|
rug = RedisUserGroup(usergroup_id)
|
||||||
|
|
||||||
# don't generate sync tasks if tasks are still pending
|
# don't generate sync tasks if tasks are still pending
|
||||||
if r.exists(rug.fence_key):
|
if r.exists(rug.fence_key):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# race condition with the monitor/cleanup function if we use a cached result!
|
# race condition with the monitor/cleanup function if we use a cached result!
|
||||||
db_session.refresh(usergroup)
|
fetch_user_group = fetch_versioned_implementation(
|
||||||
|
"danswer.db.user_group", "fetch_user_group"
|
||||||
|
)
|
||||||
|
|
||||||
|
usergroup = fetch_user_group(db_session, usergroup_id)
|
||||||
|
if not usergroup:
|
||||||
|
return None
|
||||||
|
|
||||||
if usergroup.is_up_to_date:
|
if usergroup.is_up_to_date:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -680,36 +716,9 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
|||||||
f"pruning={n_pruning}"
|
f"pruning={n_pruning}"
|
||||||
)
|
)
|
||||||
|
|
||||||
lock_beat.reacquire()
|
# do some cleanup before clearing fences
|
||||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
# check the db for any outstanding index attempts
|
||||||
monitor_connector_taskset(r)
|
|
||||||
|
|
||||||
lock_beat.reacquire()
|
|
||||||
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
|
||||||
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
|
|
||||||
|
|
||||||
with get_session_with_tenant(tenant_id) as db_session:
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
lock_beat.reacquire()
|
|
||||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
|
||||||
monitor_document_set_taskset(key_bytes, r, db_session)
|
|
||||||
|
|
||||||
lock_beat.reacquire()
|
|
||||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
|
||||||
monitor_usergroup_taskset = (
|
|
||||||
fetch_versioned_implementation_with_fallback(
|
|
||||||
"danswer.background.celery.tasks.vespa.tasks",
|
|
||||||
"monitor_usergroup_taskset",
|
|
||||||
noop_fallback,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
monitor_usergroup_taskset(key_bytes, r, db_session)
|
|
||||||
|
|
||||||
lock_beat.reacquire()
|
|
||||||
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
|
||||||
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
|
|
||||||
|
|
||||||
# do some cleanup before clearing fences
|
|
||||||
# check the db for any outstanding index attempts
|
|
||||||
attempts: list[IndexAttempt] = []
|
attempts: list[IndexAttempt] = []
|
||||||
attempts.extend(
|
attempts.extend(
|
||||||
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
|
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
|
||||||
@ -727,8 +736,42 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
|||||||
if not r.exists(rci.fence_key):
|
if not r.exists(rci.fence_key):
|
||||||
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
|
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
|
||||||
|
|
||||||
|
lock_beat.reacquire()
|
||||||
|
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||||
|
monitor_connector_taskset(r)
|
||||||
|
|
||||||
|
lock_beat.reacquire()
|
||||||
|
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||||
lock_beat.reacquire()
|
lock_beat.reacquire()
|
||||||
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
|
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
|
||||||
|
|
||||||
|
lock_beat.reacquire()
|
||||||
|
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||||
|
lock_beat.reacquire()
|
||||||
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
|
monitor_document_set_taskset(key_bytes, r, db_session)
|
||||||
|
|
||||||
|
lock_beat.reacquire()
|
||||||
|
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||||
|
lock_beat.reacquire()
|
||||||
|
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
|
||||||
|
"danswer.background.celery.tasks.vespa.tasks",
|
||||||
|
"monitor_usergroup_taskset",
|
||||||
|
noop_fallback,
|
||||||
|
)
|
||||||
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
|
monitor_usergroup_taskset(key_bytes, r, db_session)
|
||||||
|
|
||||||
|
lock_beat.reacquire()
|
||||||
|
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
|
||||||
|
lock_beat.reacquire()
|
||||||
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
|
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
|
||||||
|
|
||||||
|
lock_beat.reacquire()
|
||||||
|
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
|
||||||
|
lock_beat.reacquire()
|
||||||
|
with get_session_with_tenant(tenant_id) as db_session:
|
||||||
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
|
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
|
||||||
|
|
||||||
# uncomment for debugging if needed
|
# uncomment for debugging if needed
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
"""Entry point for running celery worker / celery beat."""
|
"""Factory stub for running celery worker / celery beat."""
|
||||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||||
|
|
||||||
|
|
||||||
set_is_ee_based_on_env_variable()
|
set_is_ee_based_on_env_variable()
|
||||||
celery_app = fetch_versioned_implementation(
|
app = fetch_versioned_implementation(
|
||||||
"danswer.background.celery.celery_app", "celery_app"
|
"danswer.background.celery.apps.beat", "celery_app"
|
||||||
)
|
)
|
17
backend/danswer/background/celery/versioned_apps/heavy.py
Normal file
17
backend/danswer/background/celery/versioned_apps/heavy.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
"""Factory stub for running celery worker / celery beat.
|
||||||
|
This code is different from the primary/beat stubs because there is no EE version to
|
||||||
|
fetch. Port over the code in those files if we add an EE version of this worker."""
|
||||||
|
from celery import Celery
|
||||||
|
|
||||||
|
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||||
|
|
||||||
|
set_is_ee_based_on_env_variable()
|
||||||
|
|
||||||
|
|
||||||
|
def get_app() -> Celery:
|
||||||
|
from danswer.background.celery.apps.heavy import celery_app
|
||||||
|
|
||||||
|
return celery_app
|
||||||
|
|
||||||
|
|
||||||
|
app = get_app()
|
17
backend/danswer/background/celery/versioned_apps/indexing.py
Normal file
17
backend/danswer/background/celery/versioned_apps/indexing.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
"""Factory stub for running celery worker / celery beat.
|
||||||
|
This code is different from the primary/beat stubs because there is no EE version to
|
||||||
|
fetch. Port over the code in those files if we add an EE version of this worker."""
|
||||||
|
from celery import Celery
|
||||||
|
|
||||||
|
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||||
|
|
||||||
|
set_is_ee_based_on_env_variable()
|
||||||
|
|
||||||
|
|
||||||
|
def get_app() -> Celery:
|
||||||
|
from danswer.background.celery.apps.indexing import celery_app
|
||||||
|
|
||||||
|
return celery_app
|
||||||
|
|
||||||
|
|
||||||
|
app = get_app()
|
17
backend/danswer/background/celery/versioned_apps/light.py
Normal file
17
backend/danswer/background/celery/versioned_apps/light.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
"""Factory stub for running celery worker / celery beat.
|
||||||
|
This code is different from the primary/beat stubs because there is no EE version to
|
||||||
|
fetch. Port over the code in those files if we add an EE version of this worker."""
|
||||||
|
from celery import Celery
|
||||||
|
|
||||||
|
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||||
|
|
||||||
|
set_is_ee_based_on_env_variable()
|
||||||
|
|
||||||
|
|
||||||
|
def get_app() -> Celery:
|
||||||
|
from danswer.background.celery.apps.light import celery_app
|
||||||
|
|
||||||
|
return celery_app
|
||||||
|
|
||||||
|
|
||||||
|
app = get_app()
|
@ -0,0 +1,8 @@
|
|||||||
|
"""Factory stub for running celery worker / celery beat."""
|
||||||
|
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||||
|
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||||
|
|
||||||
|
set_is_ee_based_on_env_variable()
|
||||||
|
app = fetch_versioned_implementation(
|
||||||
|
"danswer.background.celery.apps.primary", "celery_app"
|
||||||
|
)
|
@ -198,6 +198,41 @@ try:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
CELERY_BROKER_POOL_LIMIT = CELERY_BROKER_POOL_LIMIT_DEFAULT
|
CELERY_BROKER_POOL_LIMIT = CELERY_BROKER_POOL_LIMIT_DEFAULT
|
||||||
|
|
||||||
|
CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT = 24
|
||||||
|
try:
|
||||||
|
CELERY_WORKER_LIGHT_CONCURRENCY = int(
|
||||||
|
os.environ.get(
|
||||||
|
"CELERY_WORKER_LIGHT_CONCURRENCY", CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
CELERY_WORKER_LIGHT_CONCURRENCY = CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT
|
||||||
|
|
||||||
|
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT = 8
|
||||||
|
try:
|
||||||
|
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER = int(
|
||||||
|
os.environ.get(
|
||||||
|
"CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER",
|
||||||
|
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER = (
|
||||||
|
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
|
||||||
|
)
|
||||||
|
|
||||||
|
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 1
|
||||||
|
try:
|
||||||
|
env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
|
||||||
|
if not env_value:
|
||||||
|
env_value = os.environ.get("NUM_INDEXING_WORKERS")
|
||||||
|
|
||||||
|
if not env_value:
|
||||||
|
env_value = str(CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT)
|
||||||
|
CELERY_WORKER_INDEXING_CONCURRENCY = int(env_value)
|
||||||
|
except ValueError:
|
||||||
|
CELERY_WORKER_INDEXING_CONCURRENCY = CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT
|
||||||
|
|
||||||
#####
|
#####
|
||||||
# Connector Configs
|
# Connector Configs
|
||||||
#####
|
#####
|
||||||
|
@ -16,6 +16,7 @@ from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
|
|||||||
from danswer.background.celery.tasks.pruning.tasks import (
|
from danswer.background.celery.tasks.pruning.tasks import (
|
||||||
try_creating_prune_generator_task,
|
try_creating_prune_generator_task,
|
||||||
)
|
)
|
||||||
|
from danswer.background.celery.versioned_apps.primary import app as primary_app
|
||||||
from danswer.db.connector_credential_pair import add_credential_to_connector
|
from danswer.db.connector_credential_pair import add_credential_to_connector
|
||||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||||
from danswer.db.connector_credential_pair import remove_credential_from_connector
|
from danswer.db.connector_credential_pair import remove_credential_from_connector
|
||||||
@ -49,6 +50,7 @@ from ee.danswer.background.task_name_builders import (
|
|||||||
)
|
)
|
||||||
from ee.danswer.db.user_group import validate_user_creation_permissions
|
from ee.danswer.db.user_group import validate_user_creation_permissions
|
||||||
|
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
router = APIRouter(prefix="/manage")
|
router = APIRouter(prefix="/manage")
|
||||||
|
|
||||||
@ -261,7 +263,7 @@ def prune_cc_pair(
|
|||||||
f"{cc_pair.connector.name} connector."
|
f"{cc_pair.connector.name} connector."
|
||||||
)
|
)
|
||||||
tasks_created = try_creating_prune_generator_task(
|
tasks_created = try_creating_prune_generator_task(
|
||||||
cc_pair, db_session, r, current_tenant_id.get()
|
primary_app, cc_pair, db_session, r, current_tenant_id.get()
|
||||||
)
|
)
|
||||||
if not tasks_created:
|
if not tasks_created:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@ -318,7 +320,7 @@ def sync_cc_pair(
|
|||||||
db_session: Session = Depends(get_session),
|
db_session: Session = Depends(get_session),
|
||||||
) -> StatusResponse[list[int]]:
|
) -> StatusResponse[list[int]]:
|
||||||
# avoiding circular refs
|
# avoiding circular refs
|
||||||
from ee.danswer.background.celery.celery_app import (
|
from ee.danswer.background.celery.apps.primary import (
|
||||||
sync_external_doc_permissions_task,
|
sync_external_doc_permissions_task,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ from danswer.auth.users import current_user
|
|||||||
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
from danswer.background.celery.celery_redis import RedisConnectorIndexing
|
||||||
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
|
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
|
||||||
from danswer.background.celery.tasks.indexing.tasks import try_creating_indexing_task
|
from danswer.background.celery.tasks.indexing.tasks import try_creating_indexing_task
|
||||||
|
from danswer.background.celery.versioned_apps.primary import app as primary_app
|
||||||
from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES
|
from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES
|
||||||
from danswer.configs.constants import DocumentSource
|
from danswer.configs.constants import DocumentSource
|
||||||
from danswer.configs.constants import FileOrigin
|
from danswer.configs.constants import FileOrigin
|
||||||
@ -834,6 +835,7 @@ def connector_run_once(
|
|||||||
for cc_pair in connector_credential_pairs:
|
for cc_pair in connector_credential_pairs:
|
||||||
if cc_pair is not None:
|
if cc_pair is not None:
|
||||||
attempt_id = try_creating_indexing_task(
|
attempt_id = try_creating_indexing_task(
|
||||||
|
primary_app,
|
||||||
cc_pair,
|
cc_pair,
|
||||||
search_settings,
|
search_settings,
|
||||||
run_info.from_beginning,
|
run_info.from_beginning,
|
||||||
|
@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from danswer.auth.users import current_admin_user
|
from danswer.auth.users import current_admin_user
|
||||||
from danswer.auth.users import current_curator_or_admin_user
|
from danswer.auth.users import current_curator_or_admin_user
|
||||||
from danswer.background.celery.celery_app import celery_app
|
from danswer.background.celery.versioned_apps.primary import app as primary_app
|
||||||
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
|
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
|
||||||
from danswer.configs.constants import DanswerCeleryPriority
|
from danswer.configs.constants import DanswerCeleryPriority
|
||||||
from danswer.configs.constants import DocumentSource
|
from danswer.configs.constants import DocumentSource
|
||||||
@ -195,7 +195,7 @@ def create_deletion_attempt_for_connector_id(
|
|||||||
db_session.commit()
|
db_session.commit()
|
||||||
|
|
||||||
# run the beat task to pick up this deletion from the db immediately
|
# run the beat task to pick up this deletion from the db immediately
|
||||||
celery_app.send_task(
|
primary_app.send_task(
|
||||||
"check_for_connector_deletion_task",
|
"check_for_connector_deletion_task",
|
||||||
priority=DanswerCeleryPriority.HIGH,
|
priority=DanswerCeleryPriority.HIGH,
|
||||||
kwargs={"tenant_id": tenant_id},
|
kwargs={"tenant_id": tenant_id},
|
||||||
|
52
backend/ee/danswer/background/celery/apps/beat.py
Normal file
52
backend/ee/danswer/background/celery/apps/beat.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
#####
|
||||||
|
# Celery Beat (Periodic Tasks) Settings
|
||||||
|
#####
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from danswer.background.celery.apps.beat import celery_app
|
||||||
|
from danswer.db.engine import get_all_tenant_ids
|
||||||
|
|
||||||
|
|
||||||
|
tenant_ids = get_all_tenant_ids()
|
||||||
|
|
||||||
|
tasks_to_schedule = [
|
||||||
|
{
|
||||||
|
"name": "sync-external-doc-permissions",
|
||||||
|
"task": "check_sync_external_doc_permissions_task",
|
||||||
|
"schedule": timedelta(seconds=5), # TODO: optimize this
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "sync-external-group-permissions",
|
||||||
|
"task": "check_sync_external_group_permissions_task",
|
||||||
|
"schedule": timedelta(seconds=5), # TODO: optimize this
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "autogenerate_usage_report",
|
||||||
|
"task": "autogenerate_usage_report_task",
|
||||||
|
"schedule": timedelta(days=30), # TODO: change this to config flag
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "check-ttl-management",
|
||||||
|
"task": "check_ttl_management_task",
|
||||||
|
"schedule": timedelta(hours=1),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Build the celery beat schedule dynamically
|
||||||
|
beat_schedule = {}
|
||||||
|
|
||||||
|
for tenant_id in tenant_ids:
|
||||||
|
for task in tasks_to_schedule:
|
||||||
|
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
|
||||||
|
beat_schedule[task_name] = {
|
||||||
|
"task": task["task"],
|
||||||
|
"schedule": task["schedule"],
|
||||||
|
"args": (tenant_id,), # Must pass tenant_id as an argument
|
||||||
|
}
|
||||||
|
|
||||||
|
# Include any existing beat schedules
|
||||||
|
existing_beat_schedule = celery_app.conf.beat_schedule or {}
|
||||||
|
beat_schedule.update(existing_beat_schedule)
|
||||||
|
|
||||||
|
# Update the Celery app configuration
|
||||||
|
celery_app.conf.beat_schedule = beat_schedule
|
@ -1,11 +1,8 @@
|
|||||||
from datetime import timedelta
|
from danswer.background.celery.apps.primary import celery_app
|
||||||
|
|
||||||
from danswer.background.celery.celery_app import celery_app
|
|
||||||
from danswer.background.task_utils import build_celery_task_wrapper
|
from danswer.background.task_utils import build_celery_task_wrapper
|
||||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||||
from danswer.configs.app_configs import MULTI_TENANT
|
from danswer.configs.app_configs import MULTI_TENANT
|
||||||
from danswer.db.chat import delete_chat_sessions_older_than
|
from danswer.db.chat import delete_chat_sessions_older_than
|
||||||
from danswer.db.engine import get_all_tenant_ids
|
|
||||||
from danswer.db.engine import get_session_with_tenant
|
from danswer.db.engine import get_session_with_tenant
|
||||||
from danswer.server.settings.store import load_settings
|
from danswer.server.settings.store import load_settings
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
@ -138,53 +135,3 @@ def autogenerate_usage_report_task(tenant_id: str | None) -> None:
|
|||||||
user_id=None,
|
user_id=None,
|
||||||
period=None,
|
period=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
#####
|
|
||||||
# Celery Beat (Periodic Tasks) Settings
|
|
||||||
#####
|
|
||||||
|
|
||||||
|
|
||||||
tenant_ids = get_all_tenant_ids()
|
|
||||||
|
|
||||||
tasks_to_schedule = [
|
|
||||||
{
|
|
||||||
"name": "sync-external-doc-permissions",
|
|
||||||
"task": "check_sync_external_doc_permissions_task",
|
|
||||||
"schedule": timedelta(seconds=5), # TODO: optimize this
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "sync-external-group-permissions",
|
|
||||||
"task": "check_sync_external_group_permissions_task",
|
|
||||||
"schedule": timedelta(seconds=5), # TODO: optimize this
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "autogenerate_usage_report",
|
|
||||||
"task": "autogenerate_usage_report_task",
|
|
||||||
"schedule": timedelta(days=30), # TODO: change this to config flag
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "check-ttl-management",
|
|
||||||
"task": "check_ttl_management_task",
|
|
||||||
"schedule": timedelta(hours=1),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
# Build the celery beat schedule dynamically
|
|
||||||
beat_schedule = {}
|
|
||||||
|
|
||||||
for tenant_id in tenant_ids:
|
|
||||||
for task in tasks_to_schedule:
|
|
||||||
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
|
|
||||||
beat_schedule[task_name] = {
|
|
||||||
"task": task["task"],
|
|
||||||
"schedule": task["schedule"],
|
|
||||||
"args": (tenant_id,), # Must pass tenant_id as an argument
|
|
||||||
}
|
|
||||||
|
|
||||||
# Include any existing beat schedules
|
|
||||||
existing_beat_schedule = celery_app.conf.beat_schedule or {}
|
|
||||||
beat_schedule.update(existing_beat_schedule)
|
|
||||||
|
|
||||||
# Update the Celery app configuration
|
|
||||||
celery_app.conf.beat_schedule = beat_schedule
|
|
@ -3,7 +3,7 @@ from typing import cast
|
|||||||
from redis import Redis
|
from redis import Redis
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from danswer.background.celery.celery_app import task_logger
|
from danswer.background.celery.apps.app_base import task_logger
|
||||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
from ee.danswer.db.user_group import delete_user_group
|
from ee.danswer.db.user_group import delete_user_group
|
||||||
|
@ -20,14 +20,13 @@ def run_jobs() -> None:
|
|||||||
cmd_worker_primary = [
|
cmd_worker_primary = [
|
||||||
"celery",
|
"celery",
|
||||||
"-A",
|
"-A",
|
||||||
"ee.danswer.background.celery.celery_app",
|
"danswer.background.celery.versioned_apps.primary",
|
||||||
"worker",
|
"worker",
|
||||||
"--pool=threads",
|
"--pool=threads",
|
||||||
"--concurrency=6",
|
"--concurrency=6",
|
||||||
"--prefetch-multiplier=1",
|
"--prefetch-multiplier=1",
|
||||||
"--loglevel=INFO",
|
"--loglevel=INFO",
|
||||||
"-n",
|
"--hostname=primary@%n",
|
||||||
"primary@%n",
|
|
||||||
"-Q",
|
"-Q",
|
||||||
"celery",
|
"celery",
|
||||||
]
|
]
|
||||||
@ -35,14 +34,13 @@ def run_jobs() -> None:
|
|||||||
cmd_worker_light = [
|
cmd_worker_light = [
|
||||||
"celery",
|
"celery",
|
||||||
"-A",
|
"-A",
|
||||||
"ee.danswer.background.celery.celery_app",
|
"danswer.background.celery.versioned_apps.light",
|
||||||
"worker",
|
"worker",
|
||||||
"--pool=threads",
|
"--pool=threads",
|
||||||
"--concurrency=16",
|
"--concurrency=16",
|
||||||
"--prefetch-multiplier=8",
|
"--prefetch-multiplier=8",
|
||||||
"--loglevel=INFO",
|
"--loglevel=INFO",
|
||||||
"-n",
|
"--hostname=light@%n",
|
||||||
"light@%n",
|
|
||||||
"-Q",
|
"-Q",
|
||||||
"vespa_metadata_sync,connector_deletion",
|
"vespa_metadata_sync,connector_deletion",
|
||||||
]
|
]
|
||||||
@ -50,14 +48,13 @@ def run_jobs() -> None:
|
|||||||
cmd_worker_heavy = [
|
cmd_worker_heavy = [
|
||||||
"celery",
|
"celery",
|
||||||
"-A",
|
"-A",
|
||||||
"ee.danswer.background.celery.celery_app",
|
"danswer.background.celery.versioned_apps.heavy",
|
||||||
"worker",
|
"worker",
|
||||||
"--pool=threads",
|
"--pool=threads",
|
||||||
"--concurrency=6",
|
"--concurrency=6",
|
||||||
"--prefetch-multiplier=1",
|
"--prefetch-multiplier=1",
|
||||||
"--loglevel=INFO",
|
"--loglevel=INFO",
|
||||||
"-n",
|
"--hostname=heavy@%n",
|
||||||
"heavy@%n",
|
|
||||||
"-Q",
|
"-Q",
|
||||||
"connector_pruning",
|
"connector_pruning",
|
||||||
]
|
]
|
||||||
@ -65,21 +62,20 @@ def run_jobs() -> None:
|
|||||||
cmd_worker_indexing = [
|
cmd_worker_indexing = [
|
||||||
"celery",
|
"celery",
|
||||||
"-A",
|
"-A",
|
||||||
"ee.danswer.background.celery.celery_app",
|
"danswer.background.celery.versioned_apps.indexing",
|
||||||
"worker",
|
"worker",
|
||||||
"--pool=threads",
|
"--pool=threads",
|
||||||
"--concurrency=1",
|
"--concurrency=1",
|
||||||
"--prefetch-multiplier=1",
|
"--prefetch-multiplier=1",
|
||||||
"--loglevel=INFO",
|
"--loglevel=INFO",
|
||||||
"-n",
|
"--hostname=indexing@%n",
|
||||||
"indexing@%n",
|
|
||||||
"--queues=connector_indexing",
|
"--queues=connector_indexing",
|
||||||
]
|
]
|
||||||
|
|
||||||
cmd_beat = [
|
cmd_beat = [
|
||||||
"celery",
|
"celery",
|
||||||
"-A",
|
"-A",
|
||||||
"ee.danswer.background.celery.celery_app",
|
"danswer.background.celery.versioned_apps.beat",
|
||||||
"beat",
|
"beat",
|
||||||
"--loglevel=INFO",
|
"--loglevel=INFO",
|
||||||
]
|
]
|
||||||
|
@ -15,10 +15,7 @@ logfile=/var/log/supervisord.log
|
|||||||
# relatively compute-light (e.g. they tend to just make a bunch of requests to
|
# relatively compute-light (e.g. they tend to just make a bunch of requests to
|
||||||
# Vespa / Postgres)
|
# Vespa / Postgres)
|
||||||
[program:celery_worker_primary]
|
[program:celery_worker_primary]
|
||||||
command=celery -A danswer.background.celery.celery_run:celery_app worker
|
command=celery -A danswer.background.celery.versioned_apps.primary worker
|
||||||
--pool=threads
|
|
||||||
--concurrency=4
|
|
||||||
--prefetch-multiplier=1
|
|
||||||
--loglevel=INFO
|
--loglevel=INFO
|
||||||
--hostname=primary@%%n
|
--hostname=primary@%%n
|
||||||
-Q celery
|
-Q celery
|
||||||
@ -33,13 +30,10 @@ stopasgroup=true
|
|||||||
# since this is often the bottleneck for "sync" jobs (e.g. document set syncing,
|
# since this is often the bottleneck for "sync" jobs (e.g. document set syncing,
|
||||||
# user group syncing, deletion, etc.)
|
# user group syncing, deletion, etc.)
|
||||||
[program:celery_worker_light]
|
[program:celery_worker_light]
|
||||||
command=bash -c "celery -A danswer.background.celery.celery_run:celery_app worker \
|
command=celery -A danswer.background.celery.versioned_apps.light worker
|
||||||
--pool=threads \
|
--loglevel=INFO
|
||||||
--concurrency=${CELERY_WORKER_LIGHT_CONCURRENCY:-24} \
|
--hostname=light@%%n
|
||||||
--prefetch-multiplier=${CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER:-8} \
|
-Q vespa_metadata_sync,connector_deletion
|
||||||
--loglevel=INFO \
|
|
||||||
--hostname=light@%%n \
|
|
||||||
-Q vespa_metadata_sync,connector_deletion"
|
|
||||||
stdout_logfile=/var/log/celery_worker_light.log
|
stdout_logfile=/var/log/celery_worker_light.log
|
||||||
stdout_logfile_maxbytes=16MB
|
stdout_logfile_maxbytes=16MB
|
||||||
redirect_stderr=true
|
redirect_stderr=true
|
||||||
@ -48,10 +42,7 @@ startsecs=10
|
|||||||
stopasgroup=true
|
stopasgroup=true
|
||||||
|
|
||||||
[program:celery_worker_heavy]
|
[program:celery_worker_heavy]
|
||||||
command=celery -A danswer.background.celery.celery_run:celery_app worker
|
command=celery -A danswer.background.celery.versioned_apps.heavy worker
|
||||||
--pool=threads
|
|
||||||
--concurrency=4
|
|
||||||
--prefetch-multiplier=1
|
|
||||||
--loglevel=INFO
|
--loglevel=INFO
|
||||||
--hostname=heavy@%%n
|
--hostname=heavy@%%n
|
||||||
-Q connector_pruning
|
-Q connector_pruning
|
||||||
@ -63,13 +54,10 @@ startsecs=10
|
|||||||
stopasgroup=true
|
stopasgroup=true
|
||||||
|
|
||||||
[program:celery_worker_indexing]
|
[program:celery_worker_indexing]
|
||||||
command=bash -c "celery -A danswer.background.celery.celery_run:celery_app worker \
|
command=celery -A danswer.background.celery.versioned_apps.indexing worker
|
||||||
--pool=threads \
|
--loglevel=INFO
|
||||||
--concurrency=${CELERY_WORKER_INDEXING_CONCURRENCY:-${NUM_INDEXING_WORKERS:-1}} \
|
--hostname=indexing@%%n
|
||||||
--prefetch-multiplier=1 \
|
-Q connector_indexing
|
||||||
--loglevel=INFO \
|
|
||||||
--hostname=indexing@%%n \
|
|
||||||
-Q connector_indexing"
|
|
||||||
stdout_logfile=/var/log/celery_worker_indexing.log
|
stdout_logfile=/var/log/celery_worker_indexing.log
|
||||||
stdout_logfile_maxbytes=16MB
|
stdout_logfile_maxbytes=16MB
|
||||||
redirect_stderr=true
|
redirect_stderr=true
|
||||||
@ -79,7 +67,7 @@ stopasgroup=true
|
|||||||
|
|
||||||
# Job scheduler for periodic tasks
|
# Job scheduler for periodic tasks
|
||||||
[program:celery_beat]
|
[program:celery_beat]
|
||||||
command=celery -A danswer.background.celery.celery_run:celery_app beat
|
command=celery -A danswer.background.celery.versioned_apps.beat beat
|
||||||
stdout_logfile=/var/log/celery_beat.log
|
stdout_logfile=/var/log/celery_beat.log
|
||||||
stdout_logfile_maxbytes=16MB
|
stdout_logfile_maxbytes=16MB
|
||||||
redirect_stderr=true
|
redirect_stderr=true
|
||||||
|
Loading…
x
Reference in New Issue
Block a user