Add support for multiple indexing workers (#322)

Chris Weaver 2023-08-22 18:11:31 -07:00 committed by GitHub
parent 3ea205279f
commit e307275774
8 changed files with 295 additions and 96 deletions


@@ -1,10 +1,16 @@
import logging
import time
from datetime import datetime
from datetime import timezone
from dask.distributed import Client
from dask.distributed import Future
from distributed import LocalCluster
from sqlalchemy.orm import Session
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import IndexAttemptMetadata
@@ -18,6 +24,7 @@ from danswer.db.credentials import backend_update_credential_json
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
from danswer.db.index_attempt import get_last_attempt
from danswer.db.index_attempt import get_not_started_index_attempts
@@ -28,6 +35,7 @@ from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import Connector
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -45,7 +53,7 @@ def should_create_new_indexing(
return time_since_index.total_seconds() >= connector.refresh_freq
- def create_indexing_jobs(db_session: Session) -> None:
+ def create_indexing_jobs(db_session: Session, existing_jobs: dict[int, Future]) -> None:
connectors = fetch_connectors(db_session)
# clean up in-progress jobs that were never completed
@@ -53,11 +61,11 @@ def create_indexing_jobs(db_session: Session) -> None:
in_progress_indexing_attempts = get_inprogress_index_attempts(
connector.id, db_session
)
- if in_progress_indexing_attempts:
- logger.error("Found incomplete indexing attempts")
- # Currently single threaded so any still in-progress must have errored
for attempt in in_progress_indexing_attempts:
+ # if a job is still going, don't touch it
+ if attempt.id in existing_jobs:
+ continue
logger.warning(
f"Marking in-progress attempt 'connector: {attempt.connector_id}, "
f"credential: {attempt.credential_id}' as failed"
@@ -69,12 +77,10 @@ def create_indexing_jobs(db_session: Session) -> None:
)
if attempt.connector_id and attempt.credential_id:
update_connector_credential_pair(
+ db_session=db_session,
connector_id=attempt.connector_id,
credential_id=attempt.credential_id,
attempt_status=IndexingStatus.FAILED,
- net_docs=None,
- run_dt=None,
- db_session=db_session,
)
# potentially kick off new runs
@@ -91,41 +97,69 @@ def create_indexing_jobs(db_session: Session) -> None:
create_index_attempt(connector.id, credential.id, db_session)
update_connector_credential_pair(
+ db_session=db_session,
connector_id=connector.id,
credential_id=credential.id,
attempt_status=IndexingStatus.NOT_STARTED,
- net_docs=None,
- run_dt=None,
- db_session=db_session,
)
- def run_indexing_jobs(db_session: Session) -> None:
- indexing_pipeline = build_indexing_pipeline()
- new_indexing_attempts = get_not_started_index_attempts(db_session)
- logger.info(f"Found {len(new_indexing_attempts)} new indexing tasks.")
- for attempt in new_indexing_attempts:
- if attempt.connector is None:
- logger.warning(
- f"Skipping index attempt as Connector has been deleted: {attempt}"
- )
- mark_attempt_failed(attempt, db_session, failure_reason="Connector is null")
- continue
- if attempt.credential is None:
- logger.warning(
- f"Skipping index attempt as Credential has been deleted: {attempt}"
- )
- mark_attempt_failed(
- attempt, db_session, failure_reason="Credential is null"
- )
- continue
- logger.info(
- f"Starting new indexing attempt for connector: '{attempt.connector.name}', "
- f"with config: '{attempt.connector.connector_specific_config}', and "
- f"with credentials: '{attempt.credential_id}'"
- )
+ def cleanup_indexing_jobs(
+ db_session: Session, existing_jobs: dict[int, Future]
+ ) -> dict[int, Future]:
+ existing_jobs_copy = existing_jobs.copy()
+ for attempt_id, job in existing_jobs.items():
+ if not job.done():
+ continue
+ # cleanup completed job
+ job.release()
+ del existing_jobs_copy[attempt_id]
+ index_attempt = get_index_attempt(
+ db_session=db_session, index_attempt_id=attempt_id
+ )
+ if not index_attempt:
+ logger.error(
+ f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
+ "up indexing jobs"
+ )
+ continue
+ if index_attempt.status == IndexingStatus.IN_PROGRESS:
+ logger.warning(
+ f"Marking in-progress attempt 'connector: {index_attempt.connector_id}, "
+ f"credential: {index_attempt.credential_id}' as failed"
+ )
+ mark_attempt_failed(
+ index_attempt=index_attempt,
+ db_session=db_session,
+ failure_reason="Stopped mid run, likely due to the background process being killed",
+ )
+ if index_attempt.connector_id and index_attempt.credential_id:
+ update_connector_credential_pair(
+ db_session=db_session,
+ connector_id=index_attempt.connector_id,
+ credential_id=index_attempt.credential_id,
+ attempt_status=IndexingStatus.FAILED,
+ )
+ return existing_jobs_copy
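In isolation, the future-pruning pattern above amounts to the following runnable sketch. Only dask.distributed is assumed; slow_task and the job IDs are placeholders for real index attempts, and the DB bookkeeping is omitted.

import time
from dask.distributed import Client, Future, LocalCluster

def slow_task(n: int) -> int:
    time.sleep(n)
    return n

if __name__ == "__main__":
    cluster = LocalCluster(n_workers=2, threads_per_worker=1)
    client = Client(cluster)
    # map of attempt ID -> Future, mirroring `existing_jobs`
    jobs: dict[int, Future] = {
        i: client.submit(slow_task, i, pure=False) for i in (1, 2, 3)
    }
    while jobs:
        # iterate over a copy so entries can be deleted mid-loop,
        # just as cleanup_indexing_jobs copies the dict
        for job_id, future in list(jobs.items()):
            if not future.done():
                continue
            future.release()  # drop the cluster-side reference
            del jobs[job_id]
        time.sleep(0.5)
    client.close()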
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
) -> None:
"""
1. Get documents which are either new or updated from specified application
2. Embed and index these documents into the chosen datastores (e.g. Qdrant / Typesense or Vespa)
3. Updates Postgres to record the indexed documents + the outcome of this run
"""
def _get_document_generator(
db_session: Session, attempt: IndexAttempt
) -> tuple[GenerateDocumentsOutput, float]:
# "official" timestamp for this run
# used for setting time bounds when fetching updates from apps and
# is stored in the DB as the last successful run time if this run succeeds
@@ -133,67 +167,70 @@ def run_indexing_jobs(db_session: Session) -> None:
run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)
run_time_str = run_dt.strftime("%Y-%m-%d %H:%M:%S")
- mark_attempt_in_progress(attempt, db_session)
- db_connector = attempt.connector
- db_credential = attempt.credential
- task = db_connector.input_type
- update_connector_credential_pair(
- connector_id=db_connector.id,
- credential_id=db_credential.id,
- attempt_status=IndexingStatus.IN_PROGRESS,
- net_docs=None,
- run_dt=None,
- db_session=db_session,
- )
+ task = attempt.connector.input_type
try:
runnable_connector, new_credential_json = instantiate_connector(
- db_connector.source,
+ attempt.connector.source,
task,
- db_connector.connector_specific_config,
- db_credential.credential_json,
+ attempt.connector.connector_specific_config,
+ attempt.credential.credential_json,
)
if new_credential_json is not None:
backend_update_credential_json(
- db_credential, new_credential_json, db_session
+ attempt.credential, new_credential_json, db_session
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
- disable_connector(db_connector.id, db_session)
- continue
+ disable_connector(attempt.connector.id, db_session)
+ raise e
if task == InputType.LOAD_STATE:
assert isinstance(runnable_connector, LoadConnector)
doc_batch_generator = runnable_connector.load_from_state()
elif task == InputType.POLL:
assert isinstance(runnable_connector, PollConnector)
if attempt.connector_id is None or attempt.credential_id is None:
raise ValueError(
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
f"can't fetch time range."
)
last_run_time = get_last_successful_attempt_time(
attempt.connector_id, attempt.credential_id, db_session
)
last_run_time_str = datetime.fromtimestamp(
last_run_time, tz=timezone.utc
).strftime("%Y-%m-%d %H:%M:%S")
logger.info(
f"Polling for updates between {last_run_time_str} and {run_time_str}"
)
doc_batch_generator = runnable_connector.poll_source(
start=last_run_time, end=run_time
)
else:
# Event types cannot be handled by a background type
raise RuntimeError(f"Invalid task type: {task}")
return doc_batch_generator, run_time
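As an aside, the LoadConnector/PollConnector contract consumed here is just a generator of document batches. A minimal toy version, where Document and ToyPollConnector are illustrative stand-ins rather than Danswer's real classes:

import time
from collections.abc import Generator
from dataclasses import dataclass

@dataclass
class Document:
    id: str
    updated_at: float  # epoch seconds

class ToyPollConnector:
    def __init__(self, docs: list[Document], batch_size: int = 2) -> None:
        self.docs = docs
        self.batch_size = batch_size

    def poll_source(
        self, start: float, end: float
    ) -> Generator[list[Document], None, None]:
        # only documents updated inside the [start, end) window
        updated = [d for d in self.docs if start <= d.updated_at < end]
        for i in range(0, len(updated), self.batch_size):
            yield updated[i : i + self.batch_size]

now = time.time()
connector = ToyPollConnector([Document("a", now - 30), Document("b", now - 10)])
for batch in connector.poll_source(start=now - 60, end=now):
    print([d.id for d in batch])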
doc_batch_generator, run_time = _get_document_generator(db_session, index_attempt)
def _index(
db_session: Session,
attempt: IndexAttempt,
doc_batch_generator: GenerateDocumentsOutput,
run_time: float,
) -> None:
indexing_pipeline = build_indexing_pipeline()
run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)
db_connector = attempt.connector
db_credential = attempt.credential
net_doc_change = 0
try:
- if task == InputType.LOAD_STATE:
- assert isinstance(runnable_connector, LoadConnector)
- doc_batch_generator = runnable_connector.load_from_state()
- elif task == InputType.POLL:
- assert isinstance(runnable_connector, PollConnector)
- if attempt.connector_id is None or attempt.credential_id is None:
- raise ValueError(
- f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
- f"can't fetch time range."
- )
- last_run_time = get_last_successful_attempt_time(
- attempt.connector_id, attempt.credential_id, db_session
- )
- last_run_time_str = datetime.fromtimestamp(
- last_run_time, tz=timezone.utc
- ).strftime("%Y-%m-%d %H:%M:%S")
- logger.info(
- f"Polling for updates between {last_run_time_str} and {run_time_str}"
- )
- doc_batch_generator = runnable_connector.poll_source(
- start=last_run_time, end=run_time
- )
- else:
- # Event types cannot be handled by a background type, leave these untouched
- continue
- net_doc_change = 0
document_count = 0
chunk_count = 0
for doc_batch in doc_batch_generator:
@@ -229,12 +266,12 @@ def run_indexing_jobs(db_session: Session) -> None:
mark_attempt_succeeded(attempt, db_session)
update_connector_credential_pair(
+ db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.SUCCESS,
net_docs=net_doc_change,
run_dt=run_dt,
- db_session=db_session,
)
logger.info(
@@ -243,24 +280,121 @@ def run_indexing_jobs(db_session: Session) -> None:
logger.info(
f"Connector successfully finished, elapsed time: {time.time() - run_time} seconds"
)
except Exception as e:
logger.exception(f"Indexing job with id {attempt.id} failed due to {e}")
logger.info(
f"Failed connector elapsed time: {time.time() - run_time} seconds"
)
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
update_connector_credential_pair(
- connector_id=db_connector.id,
- credential_id=db_credential.id,
+ db_session=db_session,
+ connector_id=attempt.connector.id,
+ credential_id=attempt.credential.id,
attempt_status=IndexingStatus.FAILED,
net_docs=net_doc_change,
run_dt=run_dt,
)
+ raise e
+ _index(db_session, index_attempt, doc_batch_generator, run_time)
def _run_indexing_entrypoint(index_attempt_id: int) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
try:
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
with Session(get_sqlalchemy_engine()) as db_session:
attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if attempt is None:
raise RuntimeError(
f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
)
logger.info(
f"Running indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector.id,
credential_id=attempt.credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
)
_run_indexing(
db_session=db_session,
index_attempt=attempt,
)
logger.info(
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
except Exception as e:
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
def kickoff_indexing_jobs(
db_session: Session,
existing_jobs: dict[int, Future],
client: Client,
) -> dict[int, Future]:
existing_jobs_copy = existing_jobs.copy()
new_indexing_attempts = get_not_started_index_attempts(db_session)
logger.info(f"Found {len(new_indexing_attempts)} new indexing tasks.")
if not new_indexing_attempts:
return existing_jobs
for attempt in new_indexing_attempts:
if attempt.connector is None:
logger.warning(
f"Skipping index attempt as Connector has been deleted: {attempt}"
)
mark_attempt_failed(attempt, db_session, failure_reason="Connector is null")
continue
if attempt.credential is None:
logger.warning(
f"Skipping index attempt as Credential has been deleted: {attempt}"
)
mark_attempt_failed(
attempt, db_session, failure_reason="Credential is null"
)
continue
logger.info(
f"Kicking off indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
mark_attempt_in_progress(attempt, db_session)
run = client.submit(_run_indexing_entrypoint, attempt.id, pure=False)
existing_jobs_copy[attempt.id] = run
return existing_jobs_copy
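The pure=False flag on client.submit matters here: dask treats functions as pure by default and keys submissions on the function plus its arguments, so a second attempt submitted with the same ID would silently get back the old Future instead of re-running. A quick check (roll is a placeholder):

import random
from dask.distributed import Client

def roll(attempt_id: int) -> int:
    return random.randint(1, 6)

if __name__ == "__main__":
    client = Client(processes=False)
    a = client.submit(roll, 42)              # pure by default: keyed on args
    b = client.submit(roll, 42)              # same key -> same Future back
    c = client.submit(roll, 42, pure=False)  # unique key -> actually re-runs
    print(a.key == b.key, a.key == c.key)    # True False
    client.close()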
- def update_loop(delay: int = 10) -> None:
+ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
cluster = LocalCluster(
n_workers=num_workers,
threads_per_worker=1,
# there are warnings about high memory usage + "Event loop unresponsive"
# which are not relevant to us since our workers are expected to use a
# lot of memory + involve CPU intensive tasks that will not relinquish
# the event loop
silence_logs=logging.ERROR,
)
client = Client(cluster)
existing_jobs: dict[int, Future] = {}
engine = get_sqlalchemy_engine()
while True:
start = time.time()
@@ -268,8 +402,13 @@ def update_loop(delay: int = 10) -> None:
logger.info(f"Running update, current UTC time: {start_time_utc}")
try:
with Session(engine, expire_on_commit=False) as db_session:
- create_indexing_jobs(db_session)
- run_indexing_jobs(db_session)
+ existing_jobs = cleanup_indexing_jobs(
+ db_session=db_session, existing_jobs=existing_jobs
+ )
+ create_indexing_jobs(db_session=db_session, existing_jobs=existing_jobs)
+ existing_jobs = kickoff_indexing_jobs(
+ db_session=db_session, existing_jobs=existing_jobs, client=client
+ )
except Exception as e:
logger.exception(f"Failed to run update due to {e}")
sleep_time = delay - (time.time() - start)
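Condensed to its essentials, the cluster setup above is the following runnable snippet; the worker count and the submitted task are placeholders.

import logging
from dask.distributed import Client, LocalCluster

if __name__ == "__main__":
    cluster = LocalCluster(
        n_workers=2,
        threads_per_worker=1,        # CPU-bound jobs gain nothing from threads
        silence_logs=logging.ERROR,  # hide the memory / event-loop warnings
    )
    client = Client(cluster)
    print(client.submit(sum, [1, 2, 3], pure=False).result())  # 6
    client.close()
    cluster.close()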


@@ -163,6 +163,11 @@ LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")
CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get(
"CONTINUE_ON_CONNECTOR_FAILURE", ""
).lower() not in ["false", ""]
# Controls how many worker processes we spin up to index documents in the
# background. Increasing this speeds up indexing, but raising it substantially
# requires a fairly large amount of memory, since each worker loads the
# embedding models into its own memory.
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
#####
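Note the `or 1` idiom: unlike a plain default argument to os.environ.get, it also falls back when the variable is set but empty, which is exactly what docker-compose's ${NUM_INDEXING_WORKERS:-} produces when the variable is undefined on the host:

import os

os.environ["NUM_INDEXING_WORKERS"] = ""  # set-but-empty, compose style
print(int(os.environ.get("NUM_INDEXING_WORKERS") or 1))  # 1

os.environ["NUM_INDEXING_WORKERS"] = "4"
print(int(os.environ.get("NUM_INDEXING_WORKERS") or 1))  # 4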


@@ -57,12 +57,12 @@ def get_last_successful_attempt_time(
def update_connector_credential_pair(
+ db_session: Session,
connector_id: int,
credential_id: int,
attempt_status: IndexingStatus,
- net_docs: int | None,
- run_dt: datetime | None,
- db_session: Session,
+ net_docs: int | None = None,
+ run_dt: datetime | None = None,
) -> None:
cc_pair = get_connector_credential_pair(connector_id, credential_id, db_session)
if not cc_pair:
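With db_session moved first and the two optional parameters defaulted, status-only call sites drop their explicit Nones. A self-contained before/after sketch, with the enum reduced to one member and the DB write stubbed out:

from datetime import datetime
from enum import Enum

class IndexingStatus(str, Enum):
    FAILED = "failed"

def update_connector_credential_pair(
    db_session: object,  # stand-in for sqlalchemy.orm.Session
    connector_id: int,
    credential_id: int,
    attempt_status: IndexingStatus,
    net_docs: int | None = None,
    run_dt: datetime | None = None,
) -> None:
    # the real version looks up the pair and persists the new status
    print(connector_id, credential_id, attempt_status.value, net_docs, run_dt)

# old call shape: explicit Nones, db_session last
# update_connector_credential_pair(1, 2, IndexingStatus.FAILED, None, None, session)
# new call shape: status-only updates stay short
update_connector_credential_pair(
    db_session=None,
    connector_id=1,
    credential_id=2,
    attempt_status=IndexingStatus.FAILED,
)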


@@ -18,6 +18,13 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_index_attempt(
db_session: Session, index_attempt_id: int
) -> IndexAttempt | None:
stmt = select(IndexAttempt).where(IndexAttempt.id == index_attempt_id)
return db_session.scalars(stmt).first()
def create_index_attempt(
connector_id: int,
credential_id: int,
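The new get_index_attempt helper above is the standard SQLAlchemy 2.0 select-by-primary-key shape. A self-contained equivalent against an in-memory SQLite database (the slim IndexAttempt model here is illustrative, not Danswer's real one):

from sqlalchemy import Integer, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class IndexAttempt(Base):
    __tablename__ = "index_attempt"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as db_session:
    db_session.add(IndexAttempt(id=7))
    db_session.commit()
    stmt = select(IndexAttempt).where(IndexAttempt.id == 7)
    attempt = db_session.scalars(stmt).first()  # None if the row is missing
    print(attempt is not None)  # True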


@@ -1,3 +1,4 @@
import logging
import os
from collections.abc import Callable
from functools import wraps
@@ -183,7 +184,12 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> Non
# TODO: message should be enqueued and processed elsewhere,
# but doing it here for now for simplicity
- @retry(tries=DANSWER_BOT_NUM_RETRIES, delay=0.25, backoff=2, logger=logger)
+ @retry(
+ tries=DANSWER_BOT_NUM_RETRIES,
+ delay=0.25,
+ backoff=2,
+ logger=cast(logging.Logger, logger),
+ )
def _get_answer(question: QuestionRequest) -> QAResponse:
answer = answer_question(
question=question,
@@ -227,7 +233,12 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> Non
else:
text = f"{answer.answer}\n\n*Warning*: no sources were quoted for this answer, so it may be unreliable 😔\n\n{top_documents_str_with_header}"
- @retry(tries=DANSWER_BOT_NUM_RETRIES, delay=0.25, backoff=2, logger=logger)
+ @retry(
+ tries=DANSWER_BOT_NUM_RETRIES,
+ delay=0.25,
+ backoff=2,
+ logger=cast(logging.Logger, logger),
+ )
def _respond_in_thread(
channel: str,
text: str,
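The cast above is needed because setup_logger now returns a LoggerAdapter while the retry decorator is typed against logging.Logger; at runtime the adapter already has the .warning method the decorator calls, so cast() only satisfies mypy. A minimal reproduction with the retry package (the flaky function and retry timings are placeholders):

import logging
from typing import cast
from retry import retry

logging.basicConfig(level=logging.WARNING)
logger = logging.LoggerAdapter(logging.getLogger(__name__), {})
attempts = {"n": 0}

@retry(tries=3, delay=0.01, backoff=2, logger=cast(logging.Logger, logger))
def flaky() -> int:
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError("transient failure")
    return attempts["n"]

print(flaky())  # 3, after two logged-and-retried failures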


@@ -1,9 +1,26 @@
import logging
- from logging import Logger
+ from collections.abc import MutableMapping
+ from typing import Any
from danswer.configs.app_configs import LOG_LEVEL
class IndexAttemptSingleton:
"""Used to tell if this process is an indexing job, and if so what is the
unique identifier for this indexing attempt. For things like the API server,
main background job (scheduler), etc. this will not be used."""
_INDEX_ATTEMPT_ID: None | int = None
@classmethod
def get_index_attempt_id(cls) -> None | int:
return cls._INDEX_ATTEMPT_ID
@classmethod
def set_index_attempt_id(cls, index_attempt_id: int) -> None:
cls._INDEX_ATTEMPT_ID = index_attempt_id
def get_log_level_from_str(log_level_str: str = LOG_LEVEL) -> int:
log_level_dict = {
"CRITICAL": logging.CRITICAL,
@@ -17,14 +34,31 @@ def get_log_level_from_str(log_level_str: str = LOG_LEVEL) -> int:
return log_level_dict.get(log_level_str.upper(), logging.INFO)
class _IndexAttemptLoggingAdapter(logging.LoggerAdapter):
"""This is used to globally add the index attempt id to all log messages
during indexing by workers. This is done so that the logs can be filtered
by index attempt ID to get a better idea of what happened during a specific
indexing attempt. If the index attempt ID is not set, then this adapter
is a no-op."""
def process(
self, msg: str, kwargs: MutableMapping[str, Any]
) -> tuple[str, MutableMapping[str, Any]]:
attempt_id = IndexAttemptSingleton.get_index_attempt_id()
if attempt_id is None:
return msg, kwargs
return f"[Attempt ID: {attempt_id}] {msg}", kwargs
def setup_logger(
name: str = __name__, log_level: int = get_log_level_from_str()
- ) -> Logger:
+ ) -> logging.LoggerAdapter:
logger = logging.getLogger(name)
# If the logger already has handlers, assume it was already configured and return it.
if logger.handlers:
- return logger
+ return _IndexAttemptLoggingAdapter(logger)
logger.setLevel(log_level)
@@ -39,4 +73,4 @@ def setup_logger(
logger.addHandler(handler)
- return logger
+ return _IndexAttemptLoggingAdapter(logger)
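Putting the singleton and the adapter together, this is a condensed, runnable re-creation of the behavior (class names copied from the diff, handler wiring simplified to basicConfig): once a worker process sets its attempt ID, every subsequent log line from that process carries the prefix.

import logging
from collections.abc import MutableMapping
from typing import Any

class IndexAttemptSingleton:
    _INDEX_ATTEMPT_ID: None | int = None

    @classmethod
    def get_index_attempt_id(cls) -> None | int:
        return cls._INDEX_ATTEMPT_ID

    @classmethod
    def set_index_attempt_id(cls, index_attempt_id: int) -> None:
        cls._INDEX_ATTEMPT_ID = index_attempt_id

class _IndexAttemptLoggingAdapter(logging.LoggerAdapter):
    def process(
        self, msg: str, kwargs: MutableMapping[str, Any]
    ) -> tuple[str, MutableMapping[str, Any]]:
        attempt_id = IndexAttemptSingleton.get_index_attempt_id()
        if attempt_id is None:
            return msg, kwargs
        return f"[Attempt ID: {attempt_id}] {msg}", kwargs

logging.basicConfig(level=logging.INFO)
logger = _IndexAttemptLoggingAdapter(logging.getLogger("indexing"), {})

logger.info("scheduler message")  # no prefix
IndexAttemptSingleton.set_index_attempt_id(42)
logger.info("worker message")     # [Attempt ID: 42] worker message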


@@ -2,6 +2,8 @@ alembic==1.10.4
asyncpg==0.27.0
atlassian-python-api==3.37.0
beautifulsoup4==4.12.0
dask==2023.8.1
distributed==2023.8.1
python-dateutil==2.8.2
fastapi==0.95.0
fastapi-users==11.0.0


@@ -69,6 +69,7 @@ services:
- API_VERSION_OPENAI=${API_VERSION_OPENAI:-}
- AZURE_DEPLOYMENT_ID=${AZURE_DEPLOYMENT_ID:-}
- CONTINUE_ON_CONNECTOR_FAILURE=${CONTINUE_ON_CONNECTOR_FAILURE:-}
- NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-}
- DANSWER_BOT_SLACK_APP_TOKEN=${DANSWER_BOT_SLACK_APP_TOKEN:-}
- DANSWER_BOT_SLACK_BOT_TOKEN=${DANSWER_BOT_SLACK_BOT_TOKEN:-}
- LOG_LEVEL=${LOG_LEVEL:-info}