Add a simple job client to try to get rid of some of the flakiness / weirdness that we are seeing with Dask

This commit is contained in:
Weves 2023-11-01 17:15:15 -07:00 committed by Chris Weaver
parent 73b653d324
commit d9adee168b
3 changed files with 141 additions and 23 deletions


@@ -0,0 +1,103 @@
"""Custom client that works similarly to Dask, but simpler and more lightweight.
Dask jobs behaved very strangely - they would die all the time, retries would
not follow the expected behavior, etc.

NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
import multiprocessing
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from typing import Literal

from danswer.utils.logger import setup_logger

logger = setup_logger()


JobStatusType = (
    Literal["error"]
    | Literal["finished"]
    | Literal["pending"]
    | Literal["running"]
    | Literal["cancelled"]
)


@dataclass
class SimpleJob:
    """Drop-in replacement for `dask.distributed.Future`"""

    id: int
    process: multiprocessing.Process | None = None

    def cancel(self) -> bool:
        return self.release()

    def release(self) -> bool:
        if self.process is not None and self.process.is_alive():
            self.process.terminate()
            return True
        return False

    @property
    def status(self) -> JobStatusType:
        if not self.process:
            return "pending"
        elif self.process.is_alive():
            return "running"
        elif self.process.exitcode is None:
            return "cancelled"
        elif self.process.exitcode > 0:
            return "error"
        else:
            return "finished"

    def done(self) -> bool:
        return (
            self.status == "finished"
            or self.status == "cancelled"
            or self.status == "error"
        )

    def exception(self) -> str:
        """Needed to match the Dask API, but not implemented since we don't currently
        have a way to get back the exception information from the child process."""
        return (
            f"Job with ID '{self.id}' was killed or encountered an unhandled exception."
        )


class SimpleJobClient:
    """Drop-in replacement for `dask.distributed.Client`"""

    def __init__(self, n_workers: int = 1) -> None:
        self.n_workers = n_workers
        self.job_id_counter = 0
        self.jobs: dict[int, SimpleJob] = {}

    def _cleanup_completed_jobs(self) -> None:
        current_job_ids = list(self.jobs.keys())
        for job_id in current_job_ids:
            job = self.jobs.get(job_id)
            if job and job.done():
                logger.debug(f"Cleaning up job with id: '{job.id}'")
                del self.jobs[job.id]

    def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None:
        """NOTE: `pure` arg is needed so this can be a drop-in replacement for Dask"""
        self._cleanup_completed_jobs()
        if len(self.jobs) >= self.n_workers:
            logger.debug("No available workers to run job")
            return None

        job_id = self.job_id_counter
        self.job_id_counter += 1

        process = multiprocessing.Process(target=func, args=args)
        job = SimpleJob(id=job_id, process=process)
        process.start()

        self.jobs[job_id] = job

        return job
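
To make the intended usage concrete, here is a minimal sketch that drives the client above. The `example_work` function and the polling loop are illustrative only and are not part of this commit:

# Illustrative usage of SimpleJobClient; `example_work` and the loop below
# are hypothetical and exist only to exercise the API defined above.
import time

from danswer.background.indexing.job_client import SimpleJobClient


def example_work(attempt_id: int) -> None:
    print(f"running indexing attempt {attempt_id}")


if __name__ == "__main__":  # guard required for multiprocessing on spawn platforms
    client = SimpleJobClient(n_workers=1)

    # submit() returns None when all workers are busy; callers must retry later
    job = client.submit(example_work, 1, pure=False)
    assert job is not None

    while not job.done():
        time.sleep(0.1)
    print(job.status)  # "finished" on exit code 0, "error" on a non-zero exit code

Unlike a Dask `Future`, a `SimpleJob` cannot return the child's result or exception; `exception()` only produces a generic message.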


@@ -10,6 +10,9 @@ from dask.distributed import Future
 from distributed import LocalCluster
 from sqlalchemy.orm import Session

+from danswer.background.indexing.job_client import SimpleJob
+from danswer.background.indexing.job_client import SimpleJobClient
+from danswer.configs.app_configs import EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED
 from danswer.configs.app_configs import NUM_INDEXING_WORKERS
 from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
 from danswer.connectors.factory import instantiate_connector
@@ -99,7 +102,9 @@ def mark_run_failed(
     )


-def create_indexing_jobs(db_session: Session, existing_jobs: dict[int, Future]) -> None:
+def create_indexing_jobs(
+    db_session: Session, existing_jobs: dict[int, Future | SimpleJob]
+) -> None:
     """Creates new indexing jobs for each connector / credential pair which is:
     1. Enabled
     2. `refresh_frequency` time has passed since the last indexing run for this pair
@@ -139,8 +144,8 @@ def create_indexing_jobs(db_session: Session, existing_jobs: dict[int, Future])


 def cleanup_indexing_jobs(
-    db_session: Session, existing_jobs: dict[int, Future]
-) -> dict[int, Future]:
+    db_session: Session, existing_jobs: dict[int, Future | SimpleJob]
+) -> dict[int, Future | SimpleJob]:
     existing_jobs_copy = existing_jobs.copy()

     # clean up completed jobs
@@ -421,9 +426,9 @@ def _run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:

 def kickoff_indexing_jobs(
     db_session: Session,
-    existing_jobs: dict[int, Future],
-    client: Client,
-) -> dict[int, Future]:
+    existing_jobs: dict[int, Future | SimpleJob],
+    client: Client | SimpleJobClient,
+) -> dict[int, Future | SimpleJob]:
     existing_jobs_copy = existing_jobs.copy()

     # Don't include jobs waiting in the Dask queue that just haven't started running
@@ -455,31 +460,37 @@ def kickoff_indexing_jobs(
             )
             continue

-        logger.info(
-            f"Kicking off indexing attempt for connector: '{attempt.connector.name}', "
-            f"with config: '{attempt.connector.connector_specific_config}', and "
-            f"with credentials: '{attempt.credential_id}'"
-        )
         run = client.submit(
             _run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
         )
-        existing_jobs_copy[attempt.id] = run
+        if run:
+            logger.info(
+                f"Kicked off indexing attempt for connector: '{attempt.connector.name}', "
+                f"with config: '{attempt.connector.connector_specific_config}', and "
+                f"with credentials: '{attempt.credential_id}'"
+            )
+            existing_jobs_copy[attempt.id] = run

     return existing_jobs_copy


 def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
-    cluster = LocalCluster(
-        n_workers=num_workers,
-        threads_per_worker=1,
-        # there are warnings about high memory usage + "Event loop unresponsive"
-        # which are not relevant to us since our workers are expected to use a
-        # lot of memory + involve CPU intensive tasks that will not relinquish
-        # the event loop
-        silence_logs=logging.ERROR,
-    )
-    client = Client(cluster)
-    existing_jobs: dict[int, Future] = {}
+    client: Client | SimpleJobClient
+    if EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED:
+        client = SimpleJobClient(n_workers=num_workers)
+    else:
+        cluster = LocalCluster(
+            n_workers=num_workers,
+            threads_per_worker=1,
+            # there are warnings about high memory usage + "Event loop unresponsive"
+            # which are not relevant to us since our workers are expected to use a
+            # lot of memory + involve CPU intensive tasks that will not relinquish
+            # the event loop
+            silence_logs=logging.ERROR,
+        )
+        client = Client(cluster)
+
+    existing_jobs: dict[int, Future | SimpleJob] = {}

     engine = get_sqlalchemy_engine()
     with Session(engine) as db_session:
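
The `Client | SimpleJobClient` unions above work because both clients expose the one method this module calls. A structural-typing sketch of that shared surface (this `JobClient` Protocol is illustrative and is not added by the commit):

# Illustrative Protocol capturing the minimal interface shared by
# dask.distributed.Client and SimpleJobClient; not part of the commit.
from collections.abc import Callable
from typing import Any
from typing import Protocol


class JobClient(Protocol):
    def submit(self, func: Callable, *args: Any, pure: bool = True) -> Any:
        ...

The new `if run:` guard in `kickoff_indexing_jobs` is needed because `SimpleJobClient.submit` returns `None` when no worker is free, whereas Dask's `submit` always returns a `Future`.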


@@ -140,6 +140,10 @@ CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
 GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")

+EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED = (
+    os.environ.get("EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED", "").lower() == "true"
+)
+
 #####
 # Query Configs
 #####
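
A quick sketch of how the new flag parses (only the string "true", case-insensitively, enables the simple client; setting the variable inline here is just for demonstration):

# Demonstration of the parsing rule added above; in a real deployment the
# variable is set in the process environment, not in code.
import os

for value in ("true", "TRUE", "false", ""):
    os.environ["EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED"] = value
    enabled = (
        os.environ.get("EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED", "").lower() == "true"
    )
    print(f"{value!r}: {enabled}")  # only "true" and "TRUE" yield True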