Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-07-01 18:20:49 +02:00)
Add simple job client to try and get rid of some of the flakiness / weirdness that we are seeing with Dask
backend/danswer/background/indexing/job_client.py (new file, 103 lines)
@@ -0,0 +1,103 @@
"""Custom client that works similarly to Dask, but simpler and more lightweight.

Dask jobs behaved very strangely - they would die all the time, retries would
not follow the expected behavior, etc.

NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
import multiprocessing
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from typing import Literal

from danswer.utils.logger import setup_logger


logger = setup_logger()


JobStatusType = (
    Literal["error"]
    | Literal["finished"]
    | Literal["pending"]
    | Literal["running"]
    | Literal["cancelled"]
)


@dataclass
class SimpleJob:
    """Drop in replacement for `dask.distributed.Future`"""

    id: int
    process: multiprocessing.Process | None = None

    def cancel(self) -> bool:
        return self.release()

    def release(self) -> bool:
        if self.process is not None and self.process.is_alive():
            self.process.terminate()
            return True
        return False

    @property
    def status(self) -> JobStatusType:
        if not self.process:
            return "pending"
        elif self.process.is_alive():
            return "running"
        elif self.process.exitcode is None:
            return "cancelled"
        elif self.process.exitcode > 0:
            return "error"
        else:
            return "finished"

    def done(self) -> bool:
        return (
            self.status == "finished"
            or self.status == "cancelled"
            or self.status == "error"
        )

    def exception(self) -> str:
        """Needed to match the Dask API, but not implemented since we don't currently
        have a way to get back the exception information from the child process."""
        return (
            f"Job with ID '{self.id}' was killed or encountered an unhandled exception."
        )


class SimpleJobClient:
    """Drop in replacement for `dask.distributed.Client`"""

    def __init__(self, n_workers: int = 1) -> None:
        self.n_workers = n_workers
        self.job_id_counter = 0
        self.jobs: dict[int, SimpleJob] = {}

    def _cleanup_completed_jobs(self) -> None:
        current_job_ids = list(self.jobs.keys())
        for job_id in current_job_ids:
            job = self.jobs.get(job_id)
            if job and job.done():
                logger.debug(f"Cleaning up job with id: '{job.id}'")
                del self.jobs[job.id]

    def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None:
        """NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
        self._cleanup_completed_jobs()
        if len(self.jobs) >= self.n_workers:
            logger.debug("No available workers to run job")
            return None

        job_id = self.job_id_counter
        self.job_id_counter += 1

        process = multiprocessing.Process(target=func, args=args)
        job = SimpleJob(id=job_id, process=process)
        process.start()

        self.jobs[job_id] = job

        return job
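For orientation, a minimal usage sketch of the new client follows (illustrative only, not part of the commit). `example_task` is a hypothetical stand-in for the real indexing entrypoint; it must be a module-level function so that `multiprocessing` can execute it in a child process under any start method.

import time

from danswer.background.indexing.job_client import SimpleJobClient


def example_task(attempt_id: int) -> None:
    time.sleep(5)  # stand-in for long-running indexing work


if __name__ == "__main__":
    client = SimpleJobClient(n_workers=1)

    # mirrors the call in `kickoff_indexing_jobs`; `pure=False` is accepted
    # only for Dask API compatibility and is otherwise ignored
    job = client.submit(example_task, 1, pure=False)

    if job is not None:
        print(job.status)  # "running" while the child process is alive
        print(job.done())  # False until the process exits
        job.cancel()       # terminates the child process if it is still alive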
@@ -10,6 +10,9 @@ from dask.distributed import Future
 from distributed import LocalCluster
 from sqlalchemy.orm import Session
 
+from danswer.background.indexing.job_client import SimpleJob
+from danswer.background.indexing.job_client import SimpleJobClient
+from danswer.configs.app_configs import EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED
 from danswer.configs.app_configs import NUM_INDEXING_WORKERS
 from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
 from danswer.connectors.factory import instantiate_connector
@@ -99,7 +102,9 @@ def mark_run_failed(
     )
 
 
-def create_indexing_jobs(db_session: Session, existing_jobs: dict[int, Future]) -> None:
+def create_indexing_jobs(
+    db_session: Session, existing_jobs: dict[int, Future | SimpleJob]
+) -> None:
     """Creates new indexing jobs for each connector / credential pair which is:
     1. Enabled
     2. `refresh_frequency` time has passed since the last indexing run for this pair
@@ -139,8 +144,8 @@ def create_indexing_jobs(db_session: Session, existing_jobs: dict[int, Future])
 
 
 def cleanup_indexing_jobs(
-    db_session: Session, existing_jobs: dict[int, Future]
-) -> dict[int, Future]:
+    db_session: Session, existing_jobs: dict[int, Future | SimpleJob]
+) -> dict[int, Future | SimpleJob]:
     existing_jobs_copy = existing_jobs.copy()
 
     # clean up completed jobs
@@ -421,9 +426,9 @@ def _run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
 
 
 def kickoff_indexing_jobs(
     db_session: Session,
-    existing_jobs: dict[int, Future],
-    client: Client,
-) -> dict[int, Future]:
+    existing_jobs: dict[int, Future | SimpleJob],
+    client: Client | SimpleJobClient,
+) -> dict[int, Future | SimpleJob]:
     existing_jobs_copy = existing_jobs.copy()
 
     # Don't include jobs waiting in the Dask queue that just haven't started running
@@ -455,31 +460,37 @@ def kickoff_indexing_jobs(
             )
             continue
 
-        logger.info(
-            f"Kicking off indexing attempt for connector: '{attempt.connector.name}', "
-            f"with config: '{attempt.connector.connector_specific_config}', and "
-            f"with credentials: '{attempt.credential_id}'"
-        )
         run = client.submit(
             _run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
         )
-        existing_jobs_copy[attempt.id] = run
+        if run:
+            logger.info(
+                f"Kicked off indexing attempt for connector: '{attempt.connector.name}', "
+                f"with config: '{attempt.connector.connector_specific_config}', and "
+                f"with credentials: '{attempt.credential_id}'"
+            )
+            existing_jobs_copy[attempt.id] = run
 
     return existing_jobs_copy
 
 
 def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
-    cluster = LocalCluster(
-        n_workers=num_workers,
-        threads_per_worker=1,
-        # there are warning about high memory usage + "Event loop unresponsive"
-        # which are not relevant to us since our workers are expected to use a
-        # lot of memory + involve CPU intensive tasks that will not relinquish
-        # the event loop
-        silence_logs=logging.ERROR,
-    )
-    client = Client(cluster)
-    existing_jobs: dict[int, Future] = {}
+    client: Client | SimpleJobClient
+    if EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED:
+        client = SimpleJobClient(n_workers=num_workers)
+    else:
+        cluster = LocalCluster(
+            n_workers=num_workers,
+            threads_per_worker=1,
+            # there are warning about high memory usage + "Event loop unresponsive"
+            # which are not relevant to us since our workers are expected to use a
+            # lot of memory + involve CPU intensive tasks that will not relinquish
+            # the event loop
+            silence_logs=logging.ERROR,
+        )
+        client = Client(cluster)
+
+    existing_jobs: dict[int, Future | SimpleJob] = {}
     engine = get_sqlalchemy_engine()
 
     with Session(engine) as db_session:
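One behavioral difference worth noting, sketched below under the assumption that `work` is a hypothetical module-level stand-in for `_run_indexing_entrypoint`: unlike Dask's `Client.submit`, which always hands back a `Future`, `SimpleJobClient.submit` returns `None` when every worker slot is occupied, which is why the log line and bookkeeping in the hunk above now sit behind the `if run:` guard.

import time

from danswer.background.indexing.job_client import SimpleJobClient


def work(attempt_id: int) -> None:
    time.sleep(30)  # stand-in for a long-running indexing attempt


if __name__ == "__main__":
    client = SimpleJobClient(n_workers=1)

    first = client.submit(work, 1, pure=False)   # takes the only worker slot
    second = client.submit(work, 2, pure=False)  # None: no free worker right now

    if second is None:
        # nothing is recorded for this attempt, so a later pass of the update
        # loop can simply try to submit it again once a slot frees up
        pass

    if first is not None:
        first.cancel()  # clean up the child process started above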
@@ -140,6 +140,10 @@ CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
 
 GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
 
+EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED = (
+    os.environ.get("EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED", "").lower() == "true"
+)
+
 #####
 # Query Configs
 #####
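For completeness, a small illustrative sketch (not from the commit) of how the new flag is interpreted: it is a plain string comparison, so only the literal value "true" in any casing enables the simple job client, while an unset or differently-valued variable keeps the Dask-based path. The variable is set from Python here purely for demonstration; in practice it would come from the shell or container environment.

import os

# demonstration only: normally set outside the process, e.g. in the container env
os.environ["EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED"] = "True"

enabled = (
    os.environ.get("EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED", "").lower() == "true"
)
assert enabled is True  # "True" lowercases to "true", so the flag is on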