add configurable support for memory tracing during indexing (#2140)

This commit is contained in:
rkuo-danswer 2024-08-15 13:40:17 -07:00 committed by GitHub
parent 4d194bc86a
commit 9fa4280f96
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 119 additions and 0 deletions

View File

@ -7,7 +7,9 @@ from datetime import timezone
from sqlalchemy.orm import Session
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
from danswer.background.indexing.tracer import DanswerTracer
from danswer.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from danswer.configs.app_configs import INDEXING_TRACER_INTERVAL
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import GenerateDocumentsOutput
@ -36,6 +38,8 @@ from danswer.utils.variable_functionality import global_version
logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
def _get_document_generator(
db_session: Session,
@ -109,6 +113,7 @@ def _run_indexing(
3. Updates Postgres to record the indexed documents + the outcome of this run
"""
start_time = time.time()
db_embedding_model = index_attempt.embedding_model
index_name = db_embedding_model.index_name
@ -152,6 +157,12 @@ def _run_indexing(
)
)
if INDEXING_TRACER_INTERVAL > 0:
logger.info(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = DanswerTracer()
tracer.start()
tracer.snap()
net_doc_change = 0
document_count = 0
chunk_count = 0
@ -178,6 +189,10 @@ def _run_indexing(
)
all_connector_doc_ids: set[str] = set()
tracer_counter = 0
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
for doc_batch in doc_batch_generator:
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
@ -241,6 +256,17 @@ def _run_indexing(
docs_removed_from_index=0,
)
tracer_counter += 1
if (
INDEXING_TRACER_INTERVAL > 0
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
):
logger.info(
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
)
tracer.snap()
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
run_end_dt = window_end
if is_primary:
update_connector_credential_pair(
@ -279,12 +305,24 @@ def _run_indexing(
credential_id=db_credential.id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
break
if INDEXING_TRACER_INTERVAL > 0:
logger.info(
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
)
tracer.snap()
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
tracer.stop()
logger.info("Memory tracer stopped.")
mark_attempt_succeeded(index_attempt, db_session)
if is_primary:
update_connector_credential_pair(

View File

@ -0,0 +1,77 @@
import tracemalloc
from danswer.utils.logger import setup_logger
logger = setup_logger()
DANSWER_TRACEMALLOC_FRAMES = 10
class DanswerTracer:
def __init__(self) -> None:
self.snapshot_first: tracemalloc.Snapshot | None = None
self.snapshot_prev: tracemalloc.Snapshot | None = None
self.snapshot: tracemalloc.Snapshot | None = None
def start(self) -> None:
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
def stop(self) -> None:
tracemalloc.stop()
def snap(self) -> None:
snapshot = tracemalloc.take_snapshot()
# Filter out irrelevant frames (e.g., from tracemalloc itself or importlib)
snapshot = snapshot.filter_traces(
(
tracemalloc.Filter(False, tracemalloc.__file__), # Exclude tracemalloc
tracemalloc.Filter(
False, "<frozen importlib._bootstrap>"
), # Exclude importlib
tracemalloc.Filter(
False, "<frozen importlib._bootstrap_external>"
), # Exclude external importlib
)
)
if not self.snapshot_first:
self.snapshot_first = snapshot
if self.snapshot:
self.snapshot_prev = self.snapshot
self.snapshot = snapshot
def log_snapshot(self, numEntries: int) -> None:
if not self.snapshot:
return
stats = self.snapshot.statistics("traceback")
for s in stats[:numEntries]:
logger.info(f"Tracer snap: {s}")
for line in s.traceback:
logger.info(f"* {line}")
@staticmethod
def log_diff(
snap_current: tracemalloc.Snapshot,
snap_previous: tracemalloc.Snapshot,
numEntries: int,
) -> None:
stats = snap_current.compare_to(snap_previous, "traceback")
for s in stats[:numEntries]:
logger.info(f"Tracer diff: {s}")
for line in s.traceback.format():
logger.info(f"* {line}")
def log_previous_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_prev:
return
DanswerTracer.log_diff(self.snapshot, self.snapshot_prev, numEntries)
def log_first_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_first:
return
DanswerTracer.log_diff(self.snapshot, self.snapshot_first, numEntries)

View File

@ -291,6 +291,10 @@ INDEXING_SIZE_WARNING_THRESHOLD = int(
os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD", 100 * 1024 * 1024)
)
# during indexing, will log verbose memory diff stats every x batches and at the end.
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
#####
# Miscellaneous
#####