From 9fa4280f962d78ee1ab9b54ca5be8bcb238f56b6 Mon Sep 17 00:00:00 2001
From: rkuo-danswer
Date: Thu, 15 Aug 2024 13:40:17 -0700
Subject: [PATCH] add configurable support for memory tracing during indexing (#2140)

---
 .../background/indexing/run_indexing.py       | 38 +++++++++
 backend/danswer/background/indexing/tracer.py | 77 +++++++++++++++++++
 backend/danswer/configs/app_configs.py        |  4 +
 3 files changed, 119 insertions(+)
 create mode 100644 backend/danswer/background/indexing/tracer.py

diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py
index f2acebd93..4b4405956 100644
--- a/backend/danswer/background/indexing/run_indexing.py
+++ b/backend/danswer/background/indexing/run_indexing.py
@@ -7,7 +7,9 @@ from datetime import timezone
 from sqlalchemy.orm import Session
 
 from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
+from danswer.background.indexing.tracer import DanswerTracer
 from danswer.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
+from danswer.configs.app_configs import INDEXING_TRACER_INTERVAL
 from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
 from danswer.connectors.factory import instantiate_connector
 from danswer.connectors.interfaces import GenerateDocumentsOutput
@@ -36,6 +38,8 @@ from danswer.utils.variable_functionality import global_version
 
 logger = setup_logger()
 
+INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
+
 
 def _get_document_generator(
     db_session: Session,
@@ -109,6 +113,7 @@ def _run_indexing(
     3. Updates Postgres to record the indexed documents + the outcome of this run
     """
     start_time = time.time()
+
     db_embedding_model = index_attempt.embedding_model
     index_name = db_embedding_model.index_name
 
@@ -152,6 +157,12 @@ def _run_indexing(
         )
     )
 
+    if INDEXING_TRACER_INTERVAL > 0:
+        logger.info(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
+        tracer = DanswerTracer()
+        tracer.start()
+        tracer.snap()
+
     net_doc_change = 0
     document_count = 0
     chunk_count = 0
@@ -178,6 +189,10 @@ def _run_indexing(
             )
 
             all_connector_doc_ids: set[str] = set()
+
+            tracer_counter = 0
+            if INDEXING_TRACER_INTERVAL > 0:
+                tracer.snap()
             for doc_batch in doc_batch_generator:
                 # Check if connector is disabled mid run and stop if so unless it's the secondary
                 # index being built. We want to populate it even for paused connectors
@@ -241,6 +256,17 @@ def _run_indexing(
                     docs_removed_from_index=0,
                 )
 
+                tracer_counter += 1
+                if (
+                    INDEXING_TRACER_INTERVAL > 0
+                    and tracer_counter % INDEXING_TRACER_INTERVAL == 0
+                ):
+                    logger.info(
+                        f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
+                    )
+                    tracer.snap()
+                    tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
+
             run_end_dt = window_end
             if is_primary:
                 update_connector_credential_pair(
@@ -279,12 +305,24 @@ def _run_indexing(
                         credential_id=db_credential.id,
                         net_docs=net_doc_change,
                     )
+
+                if INDEXING_TRACER_INTERVAL > 0:
+                    tracer.stop()
                 raise e
 
             # break => similar to success case. As mentioned above, if the next run fails for the same
             # reason it will then be marked as a failure
             break
 
+    if INDEXING_TRACER_INTERVAL > 0:
+        logger.info(
+            f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
+        )
+        tracer.snap()
+        tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
+        tracer.stop()
+        logger.info("Memory tracer stopped.")
+
     mark_attempt_succeeded(index_attempt, db_session)
     if is_primary:
         update_connector_credential_pair(
diff --git a/backend/danswer/background/indexing/tracer.py b/backend/danswer/background/indexing/tracer.py
new file mode 100644
index 000000000..d91775567
--- /dev/null
+++ b/backend/danswer/background/indexing/tracer.py
@@ -0,0 +1,77 @@
+import tracemalloc
+
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+DANSWER_TRACEMALLOC_FRAMES = 10
+
+
+class DanswerTracer:
+    def __init__(self) -> None:
+        self.snapshot_first: tracemalloc.Snapshot | None = None
+        self.snapshot_prev: tracemalloc.Snapshot | None = None
+        self.snapshot: tracemalloc.Snapshot | None = None
+
+    def start(self) -> None:
+        tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
+
+    def stop(self) -> None:
+        tracemalloc.stop()
+
+    def snap(self) -> None:
+        snapshot = tracemalloc.take_snapshot()
+        # Filter out irrelevant frames (e.g., from tracemalloc itself or importlib)
+        snapshot = snapshot.filter_traces(
+            (
+                tracemalloc.Filter(False, tracemalloc.__file__),  # Exclude tracemalloc
+                tracemalloc.Filter(
+                    False, "<frozen importlib._bootstrap>"
+                ),  # Exclude importlib
+                tracemalloc.Filter(
+                    False, "<frozen importlib._bootstrap_external>"
+                ),  # Exclude external importlib
+            )
+        )
+
+        if not self.snapshot_first:
+            self.snapshot_first = snapshot
+
+        if self.snapshot:
+            self.snapshot_prev = self.snapshot
+
+        self.snapshot = snapshot
+
+    def log_snapshot(self, numEntries: int) -> None:
+        if not self.snapshot:
+            return
+
+        stats = self.snapshot.statistics("traceback")
+        for s in stats[:numEntries]:
+            logger.info(f"Tracer snap: {s}")
+            for line in s.traceback:
+                logger.info(f"* {line}")
+
+    @staticmethod
+    def log_diff(
+        snap_current: tracemalloc.Snapshot,
+        snap_previous: tracemalloc.Snapshot,
+        numEntries: int,
+    ) -> None:
+        stats = snap_current.compare_to(snap_previous, "traceback")
+        for s in stats[:numEntries]:
+            logger.info(f"Tracer diff: {s}")
+            for line in s.traceback.format():
+                logger.info(f"* {line}")
+
+    def log_previous_diff(self, numEntries: int) -> None:
+        if not self.snapshot or not self.snapshot_prev:
+            return
+
+        DanswerTracer.log_diff(self.snapshot, self.snapshot_prev, numEntries)
+
+    def log_first_diff(self, numEntries: int) -> None:
+        if not self.snapshot or not self.snapshot_first:
+            return
+
+        DanswerTracer.log_diff(self.snapshot, self.snapshot_first, numEntries)
diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py
index eeb8a1a7a..2db71f5e5 100644
--- a/backend/danswer/configs/app_configs.py
+++ b/backend/danswer/configs/app_configs.py
@@ -291,6 +291,10 @@ INDEXING_SIZE_WARNING_THRESHOLD = int(
     os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD", 100 * 1024 * 1024)
 )
 
+# during indexing, will log verbose memory diff stats every x batches and at the end.
+# 0 disables this behavior and is the default.
+INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
+
 #####
 # Miscellaneous
 #####