Use torch.multiprocessing + enable SimpleJobClient by default (#765)

This commit is contained in:
Chris Weaver
2023-11-24 18:29:28 -08:00
committed by GitHub
parent 63b051b342
commit 47f947b045
3 changed files with 15 additions and 8 deletions

View File

@@ -4,12 +4,13 @@ not follow the expected behavior, etc.
NOTE: cannot use Celery directly due to NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367""" https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
import multiprocessing
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
from typing import Literal from typing import Literal
from torch import multiprocessing
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
@@ -94,7 +95,7 @@ class SimpleJobClient:
job_id = self.job_id_counter job_id = self.job_id_counter
self.job_id_counter += 1 self.job_id_counter += 1
process = multiprocessing.Process(target=func, args=args) process = multiprocessing.Process(target=func, args=args, daemon=True)
job = SimpleJob(id=job_id, process=process) job = SimpleJob(id=job_id, process=process)
process.start() process.start()

View File

@@ -13,7 +13,7 @@ from danswer.background.indexing.dask_utils import ResourceLogger
from danswer.background.indexing.job_client import SimpleJob from danswer.background.indexing.job_client import SimpleJob
from danswer.background.indexing.job_client import SimpleJobClient from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import LOG_LEVEL from danswer.configs.app_configs import LOG_LEVEL
from danswer.configs.app_configs import MODEL_SERVER_HOST from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import NUM_INDEXING_WORKERS from danswer.configs.app_configs import NUM_INDEXING_WORKERS
@@ -273,9 +273,7 @@ def kickoff_indexing_jobs(
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None: def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
client: Client | SimpleJobClient client: Client | SimpleJobClient
if EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED: if DASK_JOB_CLIENT_ENABLED:
client = SimpleJobClient(n_workers=num_workers)
else:
cluster = LocalCluster( cluster = LocalCluster(
n_workers=num_workers, n_workers=num_workers,
threads_per_worker=1, threads_per_worker=1,
@@ -288,6 +286,8 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
client = Client(cluster) client = Client(cluster)
if LOG_LEVEL.lower() == "debug": if LOG_LEVEL.lower() == "debug":
client.register_worker_plugin(ResourceLogger()) client.register_worker_plugin(ResourceLogger())
else:
client = SimpleJobClient(n_workers=num_workers)
existing_jobs: dict[int, Future | SimpleJob] = {} existing_jobs: dict[int, Future | SimpleJob] = {}
engine = get_sqlalchemy_engine() engine = get_sqlalchemy_engine()
@@ -322,6 +322,12 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
if __name__ == "__main__": if __name__ == "__main__":
# needed for CUDA to work with multiprocessing
# NOTE: needs to be done on application startup
# before any other torch code has been run
if not DASK_JOB_CLIENT_ENABLED:
torch.multiprocessing.set_start_method("spawn")
if not MODEL_SERVER_HOST: if not MODEL_SERVER_HOST:
logger.info("Warming up Embedding Model(s)") logger.info("Warming up Embedding Model(s)")
warm_up_models(indexer_only=True) warm_up_models(indexer_only=True)

View File

@@ -141,8 +141,8 @@ CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME") GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED = ( DASK_JOB_CLIENT_ENABLED = (
os.environ.get("EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED", "").lower() == "true" os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
) )
EXPERIMENTAL_CHECKPOINTING_ENABLED = ( EXPERIMENTAL_CHECKPOINTING_ENABLED = (
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true" os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"