From 47f947b0456328be3f2e60292338f6a3104a4a6c Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Fri, 24 Nov 2023 18:29:28 -0800 Subject: [PATCH] Use torch.multiprocessing + enable SimpleJobClient by default (#765) --- backend/danswer/background/indexing/job_client.py | 5 +++-- backend/danswer/background/update.py | 14 ++++++++++---- backend/danswer/configs/app_configs.py | 4 ++-- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/backend/danswer/background/indexing/job_client.py b/backend/danswer/background/indexing/job_client.py index c3734d7cf..8e22f8e45 100644 --- a/backend/danswer/background/indexing/job_client.py +++ b/backend/danswer/background/indexing/job_client.py @@ -4,12 +4,13 @@ not follow the expected behavior, etc. NOTE: cannot use Celery directly due to https://github.com/celery/celery/issues/7007#issuecomment-1740139367""" -import multiprocessing from collections.abc import Callable from dataclasses import dataclass from typing import Any from typing import Literal +from torch import multiprocessing + from danswer.utils.logger import setup_logger logger = setup_logger() @@ -94,7 +95,7 @@ class SimpleJobClient: job_id = self.job_id_counter self.job_id_counter += 1 - process = multiprocessing.Process(target=func, args=args) + process = multiprocessing.Process(target=func, args=args, daemon=True) job = SimpleJob(id=job_id, process=process) process.start() diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 941fe4f83..e76901cd2 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -13,7 +13,7 @@ from danswer.background.indexing.dask_utils import ResourceLogger from danswer.background.indexing.job_client import SimpleJob from danswer.background.indexing.job_client import SimpleJobClient from danswer.background.indexing.run_indexing import run_indexing_entrypoint -from danswer.configs.app_configs import EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED +from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED from danswer.configs.app_configs import LOG_LEVEL from danswer.configs.app_configs import MODEL_SERVER_HOST from danswer.configs.app_configs import NUM_INDEXING_WORKERS @@ -273,9 +273,7 @@ def kickoff_indexing_jobs( def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None: client: Client | SimpleJobClient - if EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED: - client = SimpleJobClient(n_workers=num_workers) - else: + if DASK_JOB_CLIENT_ENABLED: cluster = LocalCluster( n_workers=num_workers, threads_per_worker=1, @@ -288,6 +286,8 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non client = Client(cluster) if LOG_LEVEL.lower() == "debug": client.register_worker_plugin(ResourceLogger()) + else: + client = SimpleJobClient(n_workers=num_workers) existing_jobs: dict[int, Future | SimpleJob] = {} engine = get_sqlalchemy_engine() @@ -322,6 +322,12 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non if __name__ == "__main__": + # needed for CUDA to work with multiprocessing + # NOTE: needs to be done on application startup + # before any other torch code has been run + if not DASK_JOB_CLIENT_ENABLED: + torch.multiprocessing.set_start_method("spawn") + if not MODEL_SERVER_HOST: logger.info("Warming up Embedding Model(s)") warm_up_models(indexer_only=True) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index d4234ab17..0d952eafc 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -141,8 +141,8 @@ CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [ GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME") -EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED = ( - os.environ.get("EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED", "").lower() == "true" +DASK_JOB_CLIENT_ENABLED = ( + os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true" ) EXPERIMENTAL_CHECKPOINTING_ENABLED = ( os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"