Mirror of https://github.com/danswer-ai/danswer.git
Use torch.multiprocessing + enable SimpleJobClient by default (#765)
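Summary of the hunks below: the SimpleJobClient now launches its worker processes through torch.multiprocessing instead of the stdlib multiprocessing module and marks them daemon=True; the EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED flag becomes DASK_JOB_CLIENT_ENABLED with its meaning inverted, so SimpleJobClient is the default and the Dask LocalCluster path is opt-in; and the background updater sets the "spawn" start method at startup so CUDA can be used inside the worker processes.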
@@ -4,12 +4,13 @@ not follow the expected behavior, etc.
 
 NOTE: cannot use Celery directly due to
 https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
-import multiprocessing
 from collections.abc import Callable
 from dataclasses import dataclass
 from typing import Any
 from typing import Literal
 
+from torch import multiprocessing
+
 from danswer.utils.logger import setup_logger
 
 logger = setup_logger()
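The hunk above (in the SimpleJob client module, danswer.background.indexing.job_client judging by the import later in this diff) only swaps the import: torch.multiprocessing exposes the same Process/Queue API as the stdlib module, but registers reducers that move tensors sent to child processes through shared memory. A minimal sketch of that drop-in behaviour, not taken from the commit (the worker function and tensor are illustrative):

import torch
from torch import multiprocessing  # same Process/Queue API as the stdlib module


def _consume(t: torch.Tensor) -> None:
    # Hypothetical worker; in job_client.py the target is the indexing function.
    print("sum computed in child:", t.sum().item())


if __name__ == "__main__":
    shared = torch.ones(4)  # handed to the child via torch's shared-memory reducers
    proc = multiprocessing.Process(target=_consume, args=(shared,), daemon=True)
    proc.start()
    proc.join()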
@@ -94,7 +95,7 @@ class SimpleJobClient:
         job_id = self.job_id_counter
         self.job_id_counter += 1
 
-        process = multiprocessing.Process(target=func, args=args)
+        process = multiprocessing.Process(target=func, args=args, daemon=True)
         job = SimpleJob(id=job_id, process=process)
         process.start()
 
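The daemon=True change reuses standard multiprocessing semantics: a daemonic worker is terminated automatically when the parent process exits and cannot spawn children of its own, which keeps a crashed or restarted updater from leaving orphaned indexing processes behind.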
@@ -13,7 +13,7 @@ from danswer.background.indexing.dask_utils import ResourceLogger
 from danswer.background.indexing.job_client import SimpleJob
 from danswer.background.indexing.job_client import SimpleJobClient
 from danswer.background.indexing.run_indexing import run_indexing_entrypoint
-from danswer.configs.app_configs import EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED
+from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
 from danswer.configs.app_configs import LOG_LEVEL
 from danswer.configs.app_configs import MODEL_SERVER_HOST
 from danswer.configs.app_configs import NUM_INDEXING_WORKERS
@@ -273,9 +273,7 @@ def kickoff_indexing_jobs(
 
 def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
     client: Client | SimpleJobClient
-    if EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED:
-        client = SimpleJobClient(n_workers=num_workers)
-    else:
+    if DASK_JOB_CLIENT_ENABLED:
         cluster = LocalCluster(
             n_workers=num_workers,
             threads_per_worker=1,
@@ -288,6 +286,8 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
         client = Client(cluster)
         if LOG_LEVEL.lower() == "debug":
             client.register_worker_plugin(ResourceLogger())
+    else:
+        client = SimpleJobClient(n_workers=num_workers)
 
     existing_jobs: dict[int, Future | SimpleJob] = {}
     engine = get_sqlalchemy_engine()
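Net effect of the two update_loop hunks: with the renamed flag unset, the new else branch builds SimpleJobClient(n_workers=num_workers), while the Dask LocalCluster/Client path (including the debug-only ResourceLogger plugin) only runs when DASK_JOB_CLIENT_ENABLED is true, matching the default flip in the commit title.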
@@ -322,6 +322,12 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
 
 
 if __name__ == "__main__":
+    # needed for CUDA to work with multiprocessing
+    # NOTE: needs to be done on application startup
+    # before any other torch code has been run
+    if not DASK_JOB_CLIENT_ENABLED:
+        torch.multiprocessing.set_start_method("spawn")
+
     if not MODEL_SERVER_HOST:
         logger.info("Warming up Embedding Model(s)")
         warm_up_models(indexer_only=True)
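The "spawn" start method has to be chosen once, at startup, before any worker process is created, because a forked child cannot re-initialize CUDA. A minimal sketch of that ordering (assumptions: the worker body and its CUDA check are illustrative, not copied from update.py):

import torch
import torch.multiprocessing as mp


def _run(job_id: int) -> None:
    # Hypothetical worker; the real target is run_indexing_entrypoint.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"job {job_id} running on {device}")


if __name__ == "__main__":
    # Must be called before any Process is created, mirroring the guard above.
    mp.set_start_method("spawn")
    worker = mp.Process(target=_run, args=(1,), daemon=True)
    worker.start()
    worker.join()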
@@ -141,8 +141,8 @@ CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
 
 GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
 
-EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED = (
-    os.environ.get("EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED", "").lower() == "true"
+DASK_JOB_CLIENT_ENABLED = (
+    os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
 )
 EXPERIMENTAL_CHECKPOINTING_ENABLED = (
     os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
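Deployment note: the old EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED environment variable is no longer read. Leaving the new variable unset selects the torch.multiprocessing-based SimpleJobClient, and setting DASK_JOB_CLIENT_ENABLED=true opts back into the Dask cluster behaviour.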