mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-26 17:51:54 +01:00
Use torch.multiprocessing + enable SimpleJobClient by default (#765)
This commit is contained in:
parent
63b051b342
commit
47f947b045
@ -4,12 +4,13 @@ not follow the expected behavior, etc.
|
||||
|
||||
NOTE: cannot use Celery directly due to
|
||||
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
|
||||
import multiprocessing
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from torch import multiprocessing
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@ -94,7 +95,7 @@ class SimpleJobClient:
|
||||
job_id = self.job_id_counter
|
||||
self.job_id_counter += 1
|
||||
|
||||
process = multiprocessing.Process(target=func, args=args)
|
||||
process = multiprocessing.Process(target=func, args=args, daemon=True)
|
||||
job = SimpleJob(id=job_id, process=process)
|
||||
process.start()
|
||||
|
||||
|
@ -13,7 +13,7 @@ from danswer.background.indexing.dask_utils import ResourceLogger
|
||||
from danswer.background.indexing.job_client import SimpleJob
|
||||
from danswer.background.indexing.job_client import SimpleJobClient
|
||||
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from danswer.configs.app_configs import EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED
|
||||
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
|
||||
from danswer.configs.app_configs import LOG_LEVEL
|
||||
from danswer.configs.app_configs import MODEL_SERVER_HOST
|
||||
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
|
||||
@ -273,9 +273,7 @@ def kickoff_indexing_jobs(
|
||||
|
||||
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
|
||||
client: Client | SimpleJobClient
|
||||
if EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED:
|
||||
client = SimpleJobClient(n_workers=num_workers)
|
||||
else:
|
||||
if DASK_JOB_CLIENT_ENABLED:
|
||||
cluster = LocalCluster(
|
||||
n_workers=num_workers,
|
||||
threads_per_worker=1,
|
||||
@ -288,6 +286,8 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
client = Client(cluster)
|
||||
if LOG_LEVEL.lower() == "debug":
|
||||
client.register_worker_plugin(ResourceLogger())
|
||||
else:
|
||||
client = SimpleJobClient(n_workers=num_workers)
|
||||
|
||||
existing_jobs: dict[int, Future | SimpleJob] = {}
|
||||
engine = get_sqlalchemy_engine()
|
||||
@ -322,6 +322,12 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# needed for CUDA to work with multiprocessing
|
||||
# NOTE: needs to be done on application startup
|
||||
# before any other torch code has been run
|
||||
if not DASK_JOB_CLIENT_ENABLED:
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
if not MODEL_SERVER_HOST:
|
||||
logger.info("Warming up Embedding Model(s)")
|
||||
warm_up_models(indexer_only=True)
|
||||
|
@ -141,8 +141,8 @@ CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
|
||||
|
||||
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
|
||||
|
||||
EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
DASK_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
EXPERIMENTAL_CHECKPOINTING_ENABLED = (
|
||||
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
|
||||
|
Loading…
x
Reference in New Issue
Block a user