Make indexing jobs use more cores again

This commit is contained in:
Weves 2023-10-30 18:45:51 -07:00 committed by Chris Weaver
parent a1da4dfac6
commit 517a539d7e

View File

@ -49,6 +49,13 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
)
def _get_num_threads() -> int:
"""Get # of "threads" to use for ML models in an indexing job. By default uses
the torch implementation, which returns the # of physical cores on the machine.
"""
return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
def should_create_new_indexing(
connector: Connector, last_index: IndexAttempt | None, db_session: Session
) -> bool:
@ -356,21 +363,18 @@ def _run_indexing(
_index(db_session, index_attempt, doc_batch_generator, run_time)
def _run_indexing_entrypoint(index_attempt_id: int) -> None:
def _run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
cpu_cores_to_use = max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
logger.info(f"Setting task to use {cpu_cores_to_use} threads")
torch.set_num_threads(cpu_cores_to_use)
try:
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
logger.info(f"Setting task to use {num_threads} threads")
torch.set_num_threads(num_threads)
with Session(get_sqlalchemy_engine()) as db_session:
attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
@ -444,7 +448,9 @@ def kickoff_indexing_jobs(
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
run = client.submit(_run_indexing_entrypoint, attempt.id, pure=False)
run = client.submit(
_run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
)
existing_jobs_copy[attempt.id] = run
return existing_jobs_copy