Add experimental checkpointing

Weves 2023-11-01 18:05:48 -07:00 committed by Chris Weaver
parent 2db029672b
commit fe938b6fc6
5 changed files with 336 additions and 231 deletions


@@ -0,0 +1,75 @@
"""Experimental functionality related to splitting up indexing
into a series of checkpoints to better handle intermittent failures
/ jobs being killed by cloud providers."""
import datetime
from danswer.configs.app_configs import EXPERIMENTAL_CHECKPOINTING_ENABLED
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.time_utils import datetime_to_utc
def _2010_dt() -> datetime.datetime:
return datetime.datetime(year=2010, month=1, day=1, tzinfo=datetime.timezone.utc)
def _2020_dt() -> datetime.datetime:
return datetime.datetime(year=2020, month=1, day=1, tzinfo=datetime.timezone.utc)
def _default_end_time(
last_successful_run: datetime.datetime | None,
) -> datetime.datetime:
"""If year is before 2010, go to the beginning of 2010.
If year is 2010-2020, go in 5 year increments.
If year > 2020, then go in 180 day increments.
For connectors that don't support a `filter_by` and instead rely on `sort_by`
for polling, then this will cause a massive duplication of fetches. For these
connectors, you may want to override this function to return a more reasonable
plan (e.g. extending the 2020+ windows to 6 months, 1 year, or higher)."""
last_successful_run = (
datetime_to_utc(last_successful_run) if last_successful_run else None
)
if last_successful_run is None or last_successful_run < _2010_dt():
return _2010_dt()
if last_successful_run < _2020_dt():
return last_successful_run + datetime.timedelta(days=365 * 5)
return last_successful_run + datetime.timedelta(days=180)
def find_end_time_for_indexing_attempt(
last_successful_run: datetime.datetime | None, source_type: DocumentSource
) -> datetime.datetime | None:
# NOTE: source_type can be used to override the default for certain connectors
end_of_window = _default_end_time(last_successful_run)
now = datetime.datetime.now(tz=datetime.timezone.utc)
if end_of_window < now:
return end_of_window
# None signals that we should index up to current time
return None
def get_time_windows_for_index_attempt(
last_successful_run: datetime.datetime, source_type: DocumentSource
) -> list[tuple[datetime.datetime, datetime.datetime]]:
if not EXPERIMENTAL_CHECKPOINTING_ENABLED:
return [(last_successful_run, datetime.datetime.now(tz=datetime.timezone.utc))]
time_windows: list[tuple[datetime.datetime, datetime.datetime]] = []
start_of_window: datetime.datetime | None = last_successful_run
while start_of_window:
end_of_window = find_end_time_for_indexing_attempt(
last_successful_run=start_of_window, source_type=source_type
)
time_windows.append(
(
start_of_window,
end_of_window or datetime.datetime.now(tz=datetime.timezone.utc),
)
)
start_of_window = end_of_window
return time_windows
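
For context, a minimal sketch (not part of this commit) of how the windowing plan above plays out. It assumes EXPERIMENTAL_CHECKPOINTING_ENABLED=true is set in the environment and a danswer checkout on the path; DocumentSource.SLACK is just a stand-in, since the default plan does not branch on the source type:

import datetime

from danswer.background.indexing.checkpointing import (
    get_time_windows_for_index_attempt,
)
from danswer.configs.constants import DocumentSource

# Connector whose last successful run was in mid-2018
last_run = datetime.datetime(2018, 6, 1, tzinfo=datetime.timezone.utc)

windows = get_time_windows_for_index_attempt(
    last_successful_run=last_run,
    source_type=DocumentSource.SLACK,  # stand-in; the default plan ignores it
)
for window_start, window_end in windows:
    print(window_start.isoformat(), "->", window_end.isoformat())

# Expected shape: one ~5 year window starting at 2018-06-01, then 180 day
# windows, with the final window ending at the current time. With the flag
# unset, a single (last_run, now) window is returned instead.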


@@ -0,0 +1,253 @@
import time
from datetime import datetime
from datetime import timezone
import torch
from sqlalchemy.orm import Session
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import IndexAttemptMetadata
from danswer.connectors.models import InputType
from danswer.db.connector import disable_connector
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.credentials import backend_update_credential_json
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _get_document_generator(
db_session: Session,
attempt: IndexAttempt,
start_time: datetime,
end_time: datetime,
) -> GenerateDocumentsOutput:
"""NOTE: `start_time` and `end_time` are only used for poll connectors"""
task = attempt.connector.input_type
try:
runnable_connector, new_credential_json = instantiate_connector(
attempt.connector.source,
task,
attempt.connector.connector_specific_config,
attempt.credential.credential_json,
)
if new_credential_json is not None:
backend_update_credential_json(
attempt.credential, new_credential_json, db_session
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
disable_connector(attempt.connector.id, db_session)
raise e
if task == InputType.LOAD_STATE:
assert isinstance(runnable_connector, LoadConnector)
doc_batch_generator = runnable_connector.load_from_state()
elif task == InputType.POLL:
assert isinstance(runnable_connector, PollConnector)
if attempt.connector_id is None or attempt.credential_id is None:
raise ValueError(
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
f"can't fetch time range."
)
logger.info(f"Polling for updates between {start_time} and {end_time}")
doc_batch_generator = runnable_connector.poll_source(
start=start_time.timestamp(), end=end_time.timestamp()
)
else:
# The EVENT input type cannot be handled by a background job
raise RuntimeError(f"Invalid task type: {task}")
return doc_batch_generator
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
) -> None:
"""
1. Get documents which are either new or updated from specified application
2. Embed and index these documents into the chosen datastore (Vespa)
3. Update Postgres to record the indexed documents and the outcome of this run
"""
start_time = time.time()
# mark as started
mark_attempt_in_progress(index_attempt, db_session)
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector.id,
credential_id=index_attempt.credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
)
indexing_pipeline = build_indexing_pipeline()
db_connector = index_attempt.connector
db_credential = index_attempt.credential
last_successful_index_time = get_last_successful_attempt_time(
connector_id=db_connector.id,
credential_id=db_credential.id,
db_session=db_session,
)
net_doc_change = 0
document_count = 0
chunk_count = 0
for ind, (window_start, window_end) in enumerate(
get_time_windows_for_index_attempt(
last_successful_run=datetime.fromtimestamp(
last_successful_index_time, tz=timezone.utc
),
source_type=db_connector.source,
)
):
doc_batch_generator = _get_document_generator(
db_session=db_session,
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
)
try:
for doc_batch in doc_batch_generator:
logger.debug(
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
)
new_docs, total_batch_chunks = indexing_pipeline(
documents=doc_batch,
index_attempt_metadata=IndexAttemptMetadata(
connector_id=db_connector.id,
credential_id=db_credential.id,
),
)
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch)
# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
of the transaction when computing `NOW()`, so if we have
# a long running transaction, the `time_updated` field will
# be inaccurate
db_session.commit()
# This new value is updated every batch, so the UI can refresh per batch update
update_docs_indexed(
db_session=db_session,
index_attempt=index_attempt,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
)
# check if connector is disabled mid run and stop if so
db_session.refresh(db_connector)
if db_connector.disabled:
# let the `except` block handle this
raise RuntimeError("Connector was disabled mid run")
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
net_docs=net_doc_change,
run_dt=window_end,
)
except Exception as e:
logger.info(
f"Failed connector elapsed time: {time.time() - start_time} seconds"
)
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning,
# so it would not be accurate to mark this attempt as a failure. If the next run then
# fails immediately, it will be marked as a failure at that point.
if ind == 0:
mark_attempt_failed(index_attempt, db_session, failure_reason=str(e))
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector.id,
credential_id=index_attempt.credential.id,
attempt_status=IndexingStatus.FAILED,
net_docs=net_doc_change,
run_dt=window_end,
)
raise e
mark_attempt_succeeded(index_attempt, db_session)
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.SUCCESS,
net_docs=net_doc_change,
run_dt=window_end,
)
logger.info(
f"Indexed or updated {document_count} total documents for a total of {chunk_count} chunks"
)
logger.info(
f"Connector successfully finished, elapsed time: {time.time() - start_time} seconds"
)
def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
try:
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
logger.info(f"Setting task to use {num_threads} threads")
torch.set_num_threads(num_threads)
with Session(get_sqlalchemy_engine()) as db_session:
attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if attempt is None:
raise RuntimeError(
f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
)
logger.info(
f"Running indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
_run_indexing(
db_session=db_session,
index_attempt=attempt,
)
logger.info(
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
except Exception as e:
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
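
One note on the failure handling above, since it is what makes these windows act like checkpoints: each completed window commits its progress and records window_end on the connector/credential pair, so a run that dies in a later window only loses that window's work, and the next attempt picks up from the end of the last completed window. Below is a rough, self-contained sketch of the resume arithmetic (plain Python, not danswer code; in the real flow the resume point comes from get_last_successful_attempt_time):

import datetime

# Post-2020 window size used by the default checkpointing plan
WINDOW = datetime.timedelta(days=180)


def resume_point(
    original_start: datetime.datetime, completed_windows: int
) -> datetime.datetime:
    """Where the next attempt starts after `completed_windows` windows
    finished before the job was killed."""
    return original_start + completed_windows * WINDOW


start = datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc)
# Suppose the job is killed after three 180 day windows complete:
print(resume_point(start, 3))  # 2022-06-25 00:00:00+00:00 -> next run starts here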


@@ -1,7 +1,6 @@
import logging
import time
from datetime import datetime
from datetime import timezone
import dask
import torch
@@ -12,21 +11,13 @@ from sqlalchemy.orm import Session
from danswer.background.indexing.job_client import SimpleJob
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import IndexAttemptMetadata
from danswer.connectors.models import InputType
from danswer.db.connector import disable_connector
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.credentials import backend_update_credential_json
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import create_index_attempt
@@ -35,15 +26,10 @@ from danswer.db.index_attempt import get_inprogress_index_attempts
from danswer.db.index_attempt import get_last_attempt
from danswer.db.index_attempt import get_not_started_index_attempts
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import Connector
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.search.search_nlp_models import warm_up_models
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -209,221 +195,6 @@ def cleanup_indexing_jobs(
return existing_jobs_copy
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
) -> None:
"""
1. Get documents which are either new or updated from specified application
2. Embed and index these documents into the chosen datastore (vespa)
3. Updates Postgres to record the indexed documents + the outcome of this run
"""
def _get_document_generator(
db_session: Session, attempt: IndexAttempt
) -> tuple[GenerateDocumentsOutput, float]:
# "official" timestamp for this run
# used for setting time bounds when fetching updates from apps and
# is stored in the DB as the last successful run time if this run succeeds
run_time = time.time()
run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)
run_time_str = run_dt.strftime("%Y-%m-%d %H:%M:%S")
task = attempt.connector.input_type
try:
runnable_connector, new_credential_json = instantiate_connector(
attempt.connector.source,
task,
attempt.connector.connector_specific_config,
attempt.credential.credential_json,
)
if new_credential_json is not None:
backend_update_credential_json(
attempt.credential, new_credential_json, db_session
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
disable_connector(attempt.connector.id, db_session)
raise e
if task == InputType.LOAD_STATE:
assert isinstance(runnable_connector, LoadConnector)
doc_batch_generator = runnable_connector.load_from_state()
elif task == InputType.POLL:
assert isinstance(runnable_connector, PollConnector)
if attempt.connector_id is None or attempt.credential_id is None:
raise ValueError(
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
f"can't fetch time range."
)
last_run_time = get_last_successful_attempt_time(
attempt.connector_id, attempt.credential_id, db_session
)
last_run_time_str = datetime.fromtimestamp(
last_run_time, tz=timezone.utc
).strftime("%Y-%m-%d %H:%M:%S")
logger.info(
f"Polling for updates between {last_run_time_str} and {run_time_str}"
)
doc_batch_generator = runnable_connector.poll_source(
start=last_run_time, end=run_time
)
else:
# Event types cannot be handled by a background type
raise RuntimeError(f"Invalid task type: {task}")
return doc_batch_generator, run_time
doc_batch_generator, run_time = _get_document_generator(db_session, index_attempt)
def _index(
db_session: Session,
attempt: IndexAttempt,
doc_batch_generator: GenerateDocumentsOutput,
run_time: float,
) -> None:
indexing_pipeline = build_indexing_pipeline()
run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)
db_connector = attempt.connector
db_credential = attempt.credential
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
run_dt=run_dt,
)
net_doc_change = 0
document_count = 0
chunk_count = 0
try:
for doc_batch in doc_batch_generator:
logger.debug(
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
)
new_docs, total_batch_chunks = indexing_pipeline(
documents=doc_batch,
index_attempt_metadata=IndexAttemptMetadata(
connector_id=db_connector.id,
credential_id=db_credential.id,
),
)
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch)
# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
# of the transactions when computing `NOW()`, so if we have
# a long running transaction, the `time_updated` field will
# be inaccurate
db_session.commit()
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
db_session=db_session,
index_attempt=attempt,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
)
# check if connector is disabled mid run and stop if so
db_session.refresh(db_connector)
if db_connector.disabled:
# let the `except` block handle this
raise RuntimeError("Connector was disabled mid run")
mark_attempt_succeeded(attempt, db_session)
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.SUCCESS,
net_docs=net_doc_change,
run_dt=run_dt,
)
logger.info(
f"Indexed or updated {document_count} total documents for a total of {chunk_count} chunks"
)
logger.info(
f"Connector successfully finished, elapsed time: {time.time() - run_time} seconds"
)
except Exception as e:
logger.info(
f"Failed connector elapsed time: {time.time() - run_time} seconds"
)
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
# The last attempt won't be marked failed until the next cycle's check for still in-progress attempts
# The connector_credential_pair is marked failed here though to reflect correctly in UI asap
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector.id,
credential_id=attempt.credential.id,
attempt_status=IndexingStatus.FAILED,
net_docs=net_doc_change,
run_dt=run_dt,
)
raise e
_index(db_session, index_attempt, doc_batch_generator, run_time)
def _run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
try:
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
logger.info(f"Setting task to use {num_threads} threads")
torch.set_num_threads(num_threads)
with Session(get_sqlalchemy_engine()) as db_session:
attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if attempt is None:
raise RuntimeError(
f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
)
logger.info(
f"Running indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
mark_attempt_in_progress(attempt, db_session)
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector.id,
credential_id=attempt.credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
)
_run_indexing(
db_session=db_session,
index_attempt=attempt,
)
logger.info(
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
except Exception as e:
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
def kickoff_indexing_jobs(
db_session: Session,
existing_jobs: dict[int, Future | SimpleJob],
@@ -461,7 +232,7 @@ def kickoff_indexing_jobs(
continue
run = client.submit(
_run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
)
if run:
logger.info(


@@ -143,6 +143,10 @@ GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED = (
os.environ.get("EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED", "").lower() == "true"
)
EXPERIMENTAL_CHECKPOINTING_ENABLED = (
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
)
#####
# Query Configs
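
Like the other experimental toggles in this file, EXPERIMENTAL_CHECKPOINTING_ENABLED (added above) is read once at import time, so it has to be present in the process environment before any danswer module is imported; it cannot be flipped at runtime. A minimal illustration, assuming a danswer checkout on the path (in a deployment it would typically be exported in the container environment instead):

import os

# Must be set before danswer.configs.app_configs is imported
os.environ["EXPERIMENTAL_CHECKPOINTING_ENABLED"] = "true"

from danswer.configs.app_configs import EXPERIMENTAL_CHECKPOINTING_ENABLED

assert EXPERIMENTAL_CHECKPOINTING_ENABLED is True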


@@ -54,6 +54,8 @@ def get_last_successful_attempt_time(
credential_id: int,
db_session: Session,
) -> float:
"""Gets the timestamp of the last successful index run stored in
the CC Pair row in the database"""
connector_credential_pair = get_connector_credential_pair(
connector_id, credential_id, db_session
)