Alternative solution to up the number of threads for torch (#632)

Yuhong Sun
2023-10-25 22:30:57 -07:00
committed by GitHub
parent 379e71160a
commit 604e511c09
3 changed files with 8 additions and 19 deletions

View File

@@ -3,15 +3,14 @@ import time
 from datetime import datetime
 from datetime import timezone
+import torch
 from dask.distributed import Client
 from dask.distributed import Future
 from distributed import LocalCluster
 from sqlalchemy.orm import Session
 from danswer.configs.app_configs import NUM_INDEXING_WORKERS
-from danswer.configs.model_configs import (
-    BACKGROUND_JOB_EMBEDDING_MODEL_CPU_CORES_LEFT_UNUSED,
-)
+from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
 from danswer.connectors.factory import instantiate_connector
 from danswer.connectors.interfaces import GenerateDocumentsOutput
 from danswer.connectors.interfaces import LoadConnector
@@ -351,15 +350,9 @@ def _run_indexing_entrypoint(index_attempt_id: int) -> None:
     """Entrypoint for indexing run when using dask distributed.
     Wraps the actual logic in a `try` block so that we can catch any exceptions
     and mark the attempt as failed."""
-    import torch
-    import os
-
-    # force torch to use more cores if available. On VMs pytorch only takes
-    # advantage of a single core by default
-    cpu_cores_to_use = max(
-        (os.cpu_count() or 1) - BACKGROUND_JOB_EMBEDDING_MODEL_CPU_CORES_LEFT_UNUSED,
-        torch.get_num_threads(),
-    )
+    cpu_cores_to_use = max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
     logger.info(f"Setting task to use {cpu_cores_to_use} threads")
     torch.set_num_threads(cpu_cores_to_use)
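
The new logic keeps the same `max(...)` structure but swaps the floor from a cores-minus-reserve computation to a configured minimum. A minimal standalone sketch of the before/after behavior (the hardcoded values are illustrative stand-ins for the real config):

import os

import torch

# Old floor: total cores minus a reserved count; cores_left_unused stands in
# for the removed BACKGROUND_JOB_EMBEDDING_MODEL_CPU_CORES_LEFT_UNUSED setting.
cores_left_unused = 1
old_threads = max((os.cpu_count() or 1) - cores_left_unused, torch.get_num_threads())

# New floor: trust torch's own thread detection and only enforce a minimum;
# min_threads stands in for MIN_THREADS_ML_MODELS from model_configs.
min_threads = 1
new_threads = max(min_threads, torch.get_num_threads())

torch.set_num_threads(new_threads)
print(f"old logic would pick {old_threads} threads, new logic picks {new_threads}")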

View File

@@ -30,13 +30,9 @@ ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "")
 ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "")
 # Purely an optimization, memory limitation consideration
 BATCH_SIZE_ENCODE_CHUNKS = 8
-# This controls the number of pytorch "threads" to allocate to the embedding
-# model. Specifically, this is computed as `num_cpu_cores - BACKGROUND_JOB_EMBEDDING_MODEL_CPU_CORES_LEFT_UNUSED`.
-# This is useful for limiting the number of CPU cores that the background job consumes to leave some
-# compute for other processes (most importantly the api_server and web_server).
-BACKGROUND_JOB_EMBEDDING_MODEL_CPU_CORES_LEFT_UNUSED = int(
-    os.environ.get("BACKGROUND_JOB_EMBEDDING_MODEL_CPU_CORES_LEFT_UNUSED") or 1
-)
+# This controls the minimum number of pytorch "threads" to allocate to the embedding
+# model. If torch finds more threads on its own, this value is not used.
+MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)
 # Cross Encoder Settings
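
The `or 1` fallback is doing real work here: it covers both an unset variable and the empty string that the compose substitution below produces when the host leaves the variable blank. A quick standalone check of that parsing pattern (my reading of the semantics, not repo code):

import os

# Unset variable: .get() returns None, so `or 1` supplies the default.
os.environ.pop("MIN_THREADS_ML_MODELS", None)
assert int(os.environ.get("MIN_THREADS_ML_MODELS") or 1) == 1

# Empty string (what ${MIN_THREADS_ML_MODELS:-} yields when the host leaves it
# unset): "" is falsy, so `or 1` still supplies the default rather than
# int("") raising ValueError.
os.environ["MIN_THREADS_ML_MODELS"] = ""
assert int(os.environ.get("MIN_THREADS_ML_MODELS") or 1) == 1

# Explicit value parses normally.
os.environ["MIN_THREADS_ML_MODELS"] = "4"
assert int(os.environ.get("MIN_THREADS_ML_MODELS") or 1) == 4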

View File

@@ -96,7 +96,7 @@ services:
       - ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
       - SKIP_RERANKING=${SKIP_RERANKING:-}
       - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
-      - BACKGROUND_JOB_EMBEDDING_MODEL_CPU_CORES_LEFT_UNUSED=${BACKGROUND_JOB_EMBEDDING_MODEL_CPU_CORES_LEFT_UNUSED:-}
+      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
       # Set to debug to get more fine-grained logs
       - LOG_LEVEL=${LOG_LEVEL:-info}
     volumes:
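
With the compose entry above, the host value flows into the container through the `${MIN_THREADS_ML_MODELS:-}` substitution, so setting it is a one-liner at deploy time (the value `4` and the bare `docker compose up` invocation are illustrative; the actual compose file and flags depend on the deployment):

MIN_THREADS_ML_MODELS=4 docker compose up -d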