diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py
index 9937d43638..ec5378dd31 100755
--- a/backend/danswer/background/update.py
+++ b/backend/danswer/background/update.py
@@ -49,6 +49,13 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
 )


+def _get_num_threads() -> int:
+    """Get # of "threads" to use for ML models in an indexing job. By default uses
+    the torch implementation, which returns the # of physical cores on the machine.
+    """
+    return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
+
+
 def should_create_new_indexing(
     connector: Connector, last_index: IndexAttempt | None, db_session: Session
 ) -> bool:
@@ -356,21 +363,18 @@ def _run_indexing(
     _index(db_session, index_attempt, doc_batch_generator, run_time)


-def _run_indexing_entrypoint(index_attempt_id: int) -> None:
+def _run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
     """Entrypoint for indexing run when using dask distributed.
     Wraps the actual logic in a `try` block so that we can catch
     any exceptions and mark the attempt as failed."""
-
-    cpu_cores_to_use = max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
-
-    logger.info(f"Setting task to use {cpu_cores_to_use} threads")
-    torch.set_num_threads(cpu_cores_to_use)
-
     try:
         # set the indexing attempt ID so that all log messages from this process
         # will have it added as a prefix
         IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)

+        logger.info(f"Setting task to use {num_threads} threads")
+        torch.set_num_threads(num_threads)
+
         with Session(get_sqlalchemy_engine()) as db_session:
             attempt = get_index_attempt(
                 db_session=db_session, index_attempt_id=index_attempt_id
@@ -444,7 +448,9 @@ def kickoff_indexing_jobs(
             f"with config: '{attempt.connector.connector_specific_config}', and "
             f"with credentials: '{attempt.credential_id}'"
         )
-        run = client.submit(_run_indexing_entrypoint, attempt.id, pure=False)
+        run = client.submit(
+            _run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
+        )
         existing_jobs_copy[attempt.id] = run

     return existing_jobs_copy
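
A minimal, self-contained sketch of the pattern this diff introduces: the parent process computes the torch thread budget once and passes it as an argument to each dask task, and the worker process configures torch itself before doing any work. The names MIN_ML_THREADS, get_num_threads, do_indexing, and the LocalCluster setup below are illustrative placeholders, not code from this repository.

    # Illustrative sketch only -- MIN_ML_THREADS and do_indexing are placeholder
    # names, standing in for MIN_THREADS_ML_MODELS and _run_indexing_entrypoint.
    import torch
    from dask.distributed import Client, LocalCluster

    MIN_ML_THREADS = 1  # assumed floor, analogous to MIN_THREADS_ML_MODELS


    def get_num_threads() -> int:
        # torch.get_num_threads() defaults to the number of physical cores;
        # evaluated in the parent process, like _get_num_threads() in the diff
        return max(MIN_ML_THREADS, torch.get_num_threads())


    def do_indexing(attempt_id: int, num_threads: int) -> None:
        # Runs inside the dask worker process: configure torch before any model
        # work, mirroring _run_indexing_entrypoint above.
        torch.set_num_threads(num_threads)
        # ... actual indexing work would go here ...


    if __name__ == "__main__":
        client = Client(LocalCluster(n_workers=1, threads_per_worker=1))
        # pure=False: the task has side effects, so dask should not cache it
        future = client.submit(do_indexing, 1, get_num_threads(), pure=False)
        future.result()
        client.close()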