diff --git a/backend/danswer/background/indexing/job_client.py b/backend/danswer/background/indexing/job_client.py new file mode 100644 index 000000000..c3734d7cf --- /dev/null +++ b/backend/danswer/background/indexing/job_client.py @@ -0,0 +1,103 @@ +"""Custom client that works similarly to Dask, but simpler and more lightweight. +Dask jobs behaved very strangely - they would die all the time, retries would +not follow the expected behavior, etc. + +NOTE: cannot use Celery directly due to +https://github.com/celery/celery/issues/7007#issuecomment-1740139367""" +import multiprocessing +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any +from typing import Literal + +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +JobStatusType = ( + Literal["error"] + | Literal["finished"] + | Literal["pending"] + | Literal["running"] + | Literal["cancelled"] +) + + +@dataclass +class SimpleJob: + """Drop in replacement for `dask.distributed.Future`""" + + id: int + process: multiprocessing.Process | None = None + + def cancel(self) -> bool: + return self.release() + + def release(self) -> bool: + if self.process is not None and self.process.is_alive(): + self.process.terminate() + return True + return False + + @property + def status(self) -> JobStatusType: + if not self.process: + return "pending" + elif self.process.is_alive(): + return "running" + elif self.process.exitcode is None: + return "cancelled" + elif self.process.exitcode > 0: + return "error" + else: + return "finished" + + def done(self) -> bool: + return ( + self.status == "finished" + or self.status == "cancelled" + or self.status == "error" + ) + + def exception(self) -> str: + """Needed to match the Dask API, but not implemented since we don't currently + have a way to get back the exception information from the child process.""" + return ( + f"Job with ID '{self.id}' was killed or encountered an unhandled exception." + ) + + +class SimpleJobClient: + """Drop in replacement for `dask.distributed.Client`""" + + def __init__(self, n_workers: int = 1) -> None: + self.n_workers = n_workers + self.job_id_counter = 0 + self.jobs: dict[int, SimpleJob] = {} + + def _cleanup_completed_jobs(self) -> None: + current_job_ids = list(self.jobs.keys()) + for job_id in current_job_ids: + job = self.jobs.get(job_id) + if job and job.done(): + logger.debug(f"Cleaning up job with id: '{job.id}'") + del self.jobs[job.id] + + def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None: + """NOTE: `pure` arg is needed so this can be a drop in replacement for Dask""" + self._cleanup_completed_jobs() + if len(self.jobs) >= self.n_workers: + logger.debug("No available workers to run job") + return None + + job_id = self.job_id_counter + self.job_id_counter += 1 + + process = multiprocessing.Process(target=func, args=args) + job = SimpleJob(id=job_id, process=process) + process.start() + + self.jobs[job_id] = job + + return job diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 7e22dfab9..96523a9de 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -10,6 +10,9 @@ from dask.distributed import Future from distributed import LocalCluster from sqlalchemy.orm import Session +from danswer.background.indexing.job_client import SimpleJob +from danswer.background.indexing.job_client import SimpleJobClient +from danswer.configs.app_configs import EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED from danswer.configs.app_configs import NUM_INDEXING_WORKERS from danswer.configs.model_configs import MIN_THREADS_ML_MODELS from danswer.connectors.factory import instantiate_connector @@ -99,7 +102,9 @@ def mark_run_failed( ) -def create_indexing_jobs(db_session: Session, existing_jobs: dict[int, Future]) -> None: +def create_indexing_jobs( + db_session: Session, existing_jobs: dict[int, Future | SimpleJob] +) -> None: """Creates new indexing jobs for each connector / credential pair which is: 1. Enabled 2. `refresh_frequency` time has passed since the last indexing run for this pair @@ -139,8 +144,8 @@ def create_indexing_jobs(db_session: Session, existing_jobs: dict[int, Future]) def cleanup_indexing_jobs( - db_session: Session, existing_jobs: dict[int, Future] -) -> dict[int, Future]: + db_session: Session, existing_jobs: dict[int, Future | SimpleJob] +) -> dict[int, Future | SimpleJob]: existing_jobs_copy = existing_jobs.copy() # clean up completed jobs @@ -421,9 +426,9 @@ def _run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None: def kickoff_indexing_jobs( db_session: Session, - existing_jobs: dict[int, Future], - client: Client, -) -> dict[int, Future]: + existing_jobs: dict[int, Future | SimpleJob], + client: Client | SimpleJobClient, +) -> dict[int, Future | SimpleJob]: existing_jobs_copy = existing_jobs.copy() # Don't include jobs waiting in the Dask queue that just haven't started running @@ -455,31 +460,37 @@ def kickoff_indexing_jobs( ) continue - logger.info( - f"Kicking off indexing attempt for connector: '{attempt.connector.name}', " - f"with config: '{attempt.connector.connector_specific_config}', and " - f"with credentials: '{attempt.credential_id}'" - ) run = client.submit( _run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False ) - existing_jobs_copy[attempt.id] = run + if run: + logger.info( + f"Kicked off indexing attempt for connector: '{attempt.connector.name}', " + f"with config: '{attempt.connector.connector_specific_config}', and " + f"with credentials: '{attempt.credential_id}'" + ) + existing_jobs_copy[attempt.id] = run return existing_jobs_copy def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None: - cluster = LocalCluster( - n_workers=num_workers, - threads_per_worker=1, - # there are warning about high memory usage + "Event loop unresponsive" - # which are not relevant to us since our workers are expected to use a - # lot of memory + involve CPU intensive tasks that will not relinquish - # the event loop - silence_logs=logging.ERROR, - ) - client = Client(cluster) - existing_jobs: dict[int, Future] = {} + client: Client | SimpleJobClient + if EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED: + client = SimpleJobClient(n_workers=num_workers) + else: + cluster = LocalCluster( + n_workers=num_workers, + threads_per_worker=1, + # there are warning about high memory usage + "Event loop unresponsive" + # which are not relevant to us since our workers are expected to use a + # lot of memory + involve CPU intensive tasks that will not relinquish + # the event loop + silence_logs=logging.ERROR, + ) + client = Client(cluster) + + existing_jobs: dict[int, Future | SimpleJob] = {} engine = get_sqlalchemy_engine() with Session(engine) as db_session: diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index c8cee5b7b..538a4766b 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -140,6 +140,10 @@ CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [ GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME") +EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED = ( + os.environ.get("EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED", "").lower() == "true" +) + ##### # Query Configs #####