Use torch.multiprocessing + enable SimpleJobClient by default (#765)

This commit is contained in:
Chris Weaver 2023-11-24 18:29:28 -08:00 committed by GitHub
parent 63b051b342
commit 47f947b045
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 8 deletions

View File

@@ -4,12 +4,13 @@ not follow the expected behavior, etc.
NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
import multiprocessing
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from typing import Literal
from torch import multiprocessing
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -94,7 +95,7 @@ class SimpleJobClient:
job_id = self.job_id_counter
self.job_id_counter += 1
process = multiprocessing.Process(target=func, args=args)
process = multiprocessing.Process(target=func, args=args, daemon=True)
job = SimpleJob(id=job_id, process=process)
process.start()

View File

@@ -13,7 +13,7 @@ from danswer.background.indexing.dask_utils import ResourceLogger
from danswer.background.indexing.job_client import SimpleJob
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import LOG_LEVEL
from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
@@ -273,9 +273,7 @@ def kickoff_indexing_jobs(
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
client: Client | SimpleJobClient
if EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED:
client = SimpleJobClient(n_workers=num_workers)
else:
if DASK_JOB_CLIENT_ENABLED:
cluster = LocalCluster(
n_workers=num_workers,
threads_per_worker=1,
@@ -288,6 +286,8 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
client = Client(cluster)
if LOG_LEVEL.lower() == "debug":
client.register_worker_plugin(ResourceLogger())
else:
client = SimpleJobClient(n_workers=num_workers)
existing_jobs: dict[int, Future | SimpleJob] = {}
engine = get_sqlalchemy_engine()
@@ -322,6 +322,12 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
if __name__ == "__main__":
# needed for CUDA to work with multiprocessing
# NOTE: needs to be done on application startup
# before any other torch code has been run
if not DASK_JOB_CLIENT_ENABLED:
torch.multiprocessing.set_start_method("spawn")
if not MODEL_SERVER_HOST:
logger.info("Warming up Embedding Model(s)")
warm_up_models(indexer_only=True)

View File

@@ -141,8 +141,8 @@ CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED = (
os.environ.get("EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED", "").lower() == "true"
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)
EXPERIMENTAL_CHECKPOINTING_ENABLED = (
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"