From 288daa4e901944226bbe76098c934387bab33cf8 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Thu, 30 Jan 2025 17:33:42 -0800 Subject: [PATCH] Add more airtable logging (#3862) * Add more airtable logging * Add multithreading * Remove empty comment --- backend/onyx/configs/app_configs.py | 6 +++ .../connectors/airtable/airtable_connector.py | 51 +++++++++++++++---- backend/onyx/connectors/connector_runner.py | 13 ++++- backend/onyx/connectors/models.py | 10 ++++ backend/onyx/indexing/indexing_pipeline.py | 9 ++++ .../search_nlp_models.py | 50 ++++++++++++++++-- 6 files changed, 124 insertions(+), 15 deletions(-) diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 3235f6127..d121d0517 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -478,6 +478,12 @@ INDEXING_SIZE_WARNING_THRESHOLD = int( # 0 disables this behavior and is the default. INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0) +# Enable multi-threaded embedding model calls for parallel processing +# Note: only applies for API-based embedding models +INDEXING_EMBEDDING_MODEL_NUM_THREADS = int( + os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1 +) + # During an indexing attempt, specifies the number of batches which are allowed to # exception without aborting the attempt. INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0) diff --git a/backend/onyx/connectors/airtable/airtable_connector.py b/backend/onyx/connectors/airtable/airtable_connector.py index 777f2137f..211fe3b44 100644 --- a/backend/onyx/connectors/airtable/airtable_connector.py +++ b/backend/onyx/connectors/airtable/airtable_connector.py @@ -1,3 +1,5 @@ +from concurrent.futures import as_completed +from concurrent.futures import ThreadPoolExecutor from io import BytesIO from typing import Any @@ -274,6 +276,11 @@ class AirtableConnector(LoadConnector): field_val = fields.get(field_name) field_type = field_schema.type + logger.debug( + f"Processing field '{field_name}' of type '{field_type}' " + f"for record '{record_id}'." 
+ ) + field_sections, field_metadata = self._process_field( field_id=field_schema.id, field_name=field_name, @@ -327,19 +334,45 @@ class AirtableConnector(LoadConnector): primary_field_name = field.name break - record_documents: list[Document] = [] - for record in records: - document = self._process_record( - record=record, - table_schema=table_schema, - primary_field_name=primary_field_name, - ) - if document: - record_documents.append(document) + logger.info(f"Starting to process Airtable records for {table.name}.") + # Process records in parallel batches using ThreadPoolExecutor + PARALLEL_BATCH_SIZE = 16 + max_workers = min(PARALLEL_BATCH_SIZE, len(records)) + + # Process records in batches + for i in range(0, len(records), PARALLEL_BATCH_SIZE): + batch_records = records[i : i + PARALLEL_BATCH_SIZE] + record_documents: list[Document] = [] + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit batch tasks + future_to_record = { + executor.submit( + self._process_record, + record=record, + table_schema=table_schema, + primary_field_name=primary_field_name, + ): record + for record in batch_records + } + + # Wait for all tasks in this batch to complete + for future in as_completed(future_to_record): + record = future_to_record[future] + try: + document = future.result() + if document: + record_documents.append(document) + except Exception as e: + logger.exception(f"Failed to process record {record['id']}") + raise e + + # After batch is complete, yield if we've hit the batch size if len(record_documents) >= self.batch_size: yield record_documents record_documents = [] + # Yield any remaining records if record_documents: yield record_documents diff --git a/backend/onyx/connectors/connector_runner.py b/backend/onyx/connectors/connector_runner.py index 650aa76b1..ffb35f4e6 100644 --- a/backend/onyx/connectors/connector_runner.py +++ b/backend/onyx/connectors/connector_runner.py @@ -1,4 +1,5 @@ import sys +import time from datetime import datetime from onyx.connectors.interfaces import BaseConnector @@ -45,7 +46,17 @@ class ConnectorRunner: def run(self) -> GenerateDocumentsOutput: """Adds additional exception logging to the connector.""" try: - yield from self.doc_batch_generator + start = time.monotonic() + for batch in self.doc_batch_generator: + # to know how long connector is taking + logger.debug( + f"Connector took {time.monotonic() - start} seconds to build a batch." 
+ ) + + yield batch + + start = time.monotonic() + except Exception: exc_type, _, exc_traceback = sys.exc_info() diff --git a/backend/onyx/connectors/models.py b/backend/onyx/connectors/models.py index ee66d4b50..41123318a 100644 --- a/backend/onyx/connectors/models.py +++ b/backend/onyx/connectors/models.py @@ -150,6 +150,16 @@ class Document(DocumentBase): id: str # This must be unique or during indexing/reindexing, chunks will be overwritten source: DocumentSource + def get_total_char_length(self) -> int: + """Calculate the total character length of the document including sections, metadata, and identifiers.""" + section_length = sum(len(section.text) for section in self.sections) + identifier_length = len(self.semantic_identifier) + len(self.title or "") + metadata_length = sum( + len(k) + len(v) if isinstance(v, str) else len(k) + sum(len(x) for x in v) + for k, v in self.metadata.items() + ) + return section_length + identifier_length + metadata_length + def to_short_descriptor(self) -> str: """Used when logging the identity of a document""" return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'" diff --git a/backend/onyx/indexing/indexing_pipeline.py b/backend/onyx/indexing/indexing_pipeline.py index ea7228a97..e965a5ca4 100644 --- a/backend/onyx/indexing/indexing_pipeline.py +++ b/backend/onyx/indexing/indexing_pipeline.py @@ -380,6 +380,15 @@ def index_doc_batch( new_docs=0, total_docs=len(filtered_documents), total_chunks=0 ) + doc_descriptors = [ + { + "doc_id": doc.id, + "doc_length": doc.get_total_char_length(), + } + for doc in ctx.updatable_docs + ] + logger.debug(f"Starting indexing process for documents: {doc_descriptors}") + logger.debug("Starting chunking") chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs) diff --git a/backend/onyx/natural_language_processing/search_nlp_models.py b/backend/onyx/natural_language_processing/search_nlp_models.py index b7e54e81a..2f6c6d306 100644 --- a/backend/onyx/natural_language_processing/search_nlp_models.py +++ b/backend/onyx/natural_language_processing/search_nlp_models.py @@ -1,6 +1,8 @@ import threading import time from collections.abc import Callable +from concurrent.futures import as_completed +from concurrent.futures import ThreadPoolExecutor from functools import wraps from typing import Any @@ -11,6 +13,7 @@ from requests import RequestException from requests import Response from retry import retry +from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS from onyx.configs.app_configs import LARGE_CHUNK_RATIO from onyx.configs.app_configs import SKIP_WARM_UP from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS @@ -155,6 +158,7 @@ class EmbeddingModel: text_type: EmbedTextType, batch_size: int, max_seq_length: int, + num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS, ) -> list[Embedding]: text_batches = batch_list(texts, batch_size) @@ -163,12 +167,15 @@ class EmbeddingModel: ) embeddings: list[Embedding] = [] - for idx, text_batch in enumerate(text_batches, start=1): + + def process_batch( + batch_idx: int, text_batch: list[str] + ) -> tuple[int, list[Embedding]]: if self.callback: if self.callback.should_stop(): raise RuntimeError("_batch_encode_texts detected stop signal") - logger.debug(f"Encoding batch {idx} of {len(text_batches)}") + logger.debug(f"Encoding batch {batch_idx} of {len(text_batches)}") embed_request = EmbedRequest( model_name=self.model_name, texts=text_batch, @@ -185,10 +192,43 @@ class EmbeddingModel: ) response = 
self._make_model_server_request(embed_request) - embeddings.extend(response.embeddings) + return batch_idx, response.embeddings + + # only multi thread if: + # 1. num_threads is greater than 1 + # 2. we are using an API-based embedding model (provider_type is not None) + # 3. there are more than 1 batch (no point in threading if only 1) + if num_threads >= 1 and self.provider_type and len(text_batches) > 1: + with ThreadPoolExecutor(max_workers=num_threads) as executor: + future_to_batch = { + executor.submit(process_batch, idx, batch): idx + for idx, batch in enumerate(text_batches, start=1) + } + + # Collect results in order + batch_results: list[tuple[int, list[Embedding]]] = [] + for future in as_completed(future_to_batch): + try: + result = future.result() + batch_results.append(result) + if self.callback: + self.callback.progress("_batch_encode_texts", 1) + except Exception as e: + logger.exception("Embedding model failed to process batch") + raise e + + # Sort by batch index and extend embeddings + batch_results.sort(key=lambda x: x[0]) + for _, batch_embeddings in batch_results: + embeddings.extend(batch_embeddings) + else: + # Original sequential processing + for idx, text_batch in enumerate(text_batches, start=1): + _, batch_embeddings = process_batch(idx, text_batch) + embeddings.extend(batch_embeddings) + if self.callback: + self.callback.progress("_batch_encode_texts", 1) - if self.callback: - self.callback.progress("_batch_encode_texts", 1) return embeddings def encode(
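
A few illustrative sketches follow; they are editorial additions rather than part of the applied diff, and every name and value in them is made up unless it matches an identifier shown above.

Ordering-preserving fan-out. Both the Airtable record processing and the embedding batch encoding added here use the same ThreadPoolExecutor pattern: submit indexed work items, collect them with as_completed (which yields in completion order), and restore submission order afterwards. A minimal standalone sketch of that pattern, assuming nothing beyond the standard library:

    from concurrent.futures import ThreadPoolExecutor, as_completed


    def process_in_order(items: list[str], num_threads: int) -> list[str]:
        def work(idx: int, item: str) -> tuple[int, str]:
            # stand-in for _process_record / process_batch in the diff above
            return idx, item.upper()

        results: list[tuple[int, str]] = []
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            future_to_idx = {
                executor.submit(work, idx, item): idx
                for idx, item in enumerate(items)
            }
            for future in as_completed(future_to_idx):
                # result() re-raises any exception from the worker thread,
                # which is what lets the callers above log and abort
                results.append(future.result())

        # as_completed yields in completion order, so sort by index to get
        # results back in submission order
        results.sort(key=lambda pair: pair[0])
        return [value for _, value in results]


    print(process_in_order(["a", "b", "c"], num_threads=2))  # ['A', 'B', 'C']

The embedding path needs the final sort because chunk order has to line up with the input texts; the Airtable path instead keys futures back to their records and only needs every record to land in some yielded batch, so it skips the sort.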
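When the embedding call actually fans out. The new INDEXING_EMBEDDING_MODEL_NUM_THREADS setting feeds the default num_threads of _batch_encode_texts, but threads are only used when the guard in the diff passes: the check as written is num_threads >= 1 (the inline comment says "greater than 1"), provider_type is set (an API-based model), and there is more than one batch. A small sketch of that guard with made-up values:

    def should_multithread(
        num_threads: int, provider_type: str | None, num_batches: int
    ) -> bool:
        # mirrors the condition in _batch_encode_texts as written in the diff
        return num_threads >= 1 and provider_type is not None and num_batches > 1


    print(should_multithread(4, None, 10))      # False: local models stay sequential
    print(should_multithread(4, "openai", 10))  # True: API model, several batches
    print(should_multithread(4, "openai", 1))   # False: one batch, nothing to overlap

Leaving the environment variable unset or empty parses to 1 via the "or 1" fallback in app_configs.py, so deployments that do not opt in keep encoding with a single worker, one batch at a time.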
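Per-batch timing in ConnectorRunner.run. The clock is restarted only after yield returns, i.e. once the consumer has finished with the previous batch and asks for the next one, so the logged duration covers the connector's own work and excludes downstream indexing time. A self-contained sketch of the same idea with illustrative names:

    import time
    from collections.abc import Iterator


    def timed_batches(batches: Iterator[list[str]]) -> Iterator[list[str]]:
        start = time.monotonic()
        for batch in batches:
            print(f"source took {time.monotonic() - start:.3f}s to build a batch")
            yield batch
            # restart the clock once the consumer comes back for more
            start = time.monotonic()


    def slow_source() -> Iterator[list[str]]:
        for i in range(3):
            time.sleep(0.1)  # pretend to call an external API
            yield [f"doc-{i}"]


    for batch in timed_batches(slow_source()):
        time.sleep(0.05)  # downstream work; not included in the printed timings

Each printed duration here should come out near 0.1s rather than 0.15s, which is exactly what placing the reset below the yield buys.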
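What get_total_char_length counts. The new Document helper sums section text, the semantic identifier plus title, and metadata keys plus values, where a metadata value may be a single string or a list of strings. A standalone mirror of that logic over plain values, for illustration only:

    def total_char_length(
        section_texts: list[str],
        semantic_identifier: str,
        title: str | None,
        metadata: dict[str, str | list[str]],
    ) -> int:
        section_length = sum(len(text) for text in section_texts)
        identifier_length = len(semantic_identifier) + len(title or "")
        metadata_length = sum(
            len(k) + (len(v) if isinstance(v, str) else sum(len(x) for x in v))
            for k, v in metadata.items()
        )
        return section_length + identifier_length + metadata_length


    # sections 5+3, identifiers 3+5, metadata (4+1) + (3+2+2) -> 28
    print(
        total_char_length(
            ["hello", "foo"],
            "doc",
            "title",
            {"team": "a", "tag": ["ab", "cd"]},
        )
    )

In the indexing pipeline this feeds the new per-document debug log, giving a rough size for each document before chunking and embedding begin.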