Add more airtable logging (#3862)

* Add more airtable logging

* Add multithreading

* Remove empty comment
This commit is contained in:
Chris Weaver
2025-01-30 17:33:42 -08:00
committed by GitHub
parent 5e21dc6cb3
commit 288daa4e90
6 changed files with 124 additions and 15 deletions

View File

@@ -1,6 +1,8 @@
import threading
import time
from collections.abc import Callable
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from typing import Any
@@ -11,6 +13,7 @@ from requests import RequestException
from requests import Response
from retry import retry
from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
from onyx.configs.app_configs import SKIP_WARM_UP
from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
@@ -155,6 +158,7 @@ class EmbeddingModel:
text_type: EmbedTextType,
batch_size: int,
max_seq_length: int,
num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
) -> list[Embedding]:
text_batches = batch_list(texts, batch_size)
@@ -163,12 +167,15 @@ class EmbeddingModel:
)
embeddings: list[Embedding] = []
for idx, text_batch in enumerate(text_batches, start=1):
def process_batch(
batch_idx: int, text_batch: list[str]
) -> tuple[int, list[Embedding]]:
if self.callback:
if self.callback.should_stop():
raise RuntimeError("_batch_encode_texts detected stop signal")
logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
logger.debug(f"Encoding batch {batch_idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
@@ -185,10 +192,43 @@ class EmbeddingModel:
)
response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)
return batch_idx, response.embeddings
# only multi thread if:
# 1. num_threads is greater than 1
# 2. we are using an API-based embedding model (provider_type is not None)
# 3. there are more than 1 batch (no point in threading if only 1)
if num_threads >= 1 and self.provider_type and len(text_batches) > 1:
with ThreadPoolExecutor(max_workers=num_threads) as executor:
future_to_batch = {
executor.submit(process_batch, idx, batch): idx
for idx, batch in enumerate(text_batches, start=1)
}
# Collect results in order
batch_results: list[tuple[int, list[Embedding]]] = []
for future in as_completed(future_to_batch):
try:
result = future.result()
batch_results.append(result)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
except Exception as e:
logger.exception("Embedding model failed to process batch")
raise e
# Sort by batch index and extend embeddings
batch_results.sort(key=lambda x: x[0])
for _, batch_embeddings in batch_results:
embeddings.extend(batch_embeddings)
else:
# Original sequential processing
for idx, text_batch in enumerate(text_batches, start=1):
_, batch_embeddings = process_batch(idx, text_batch)
embeddings.extend(batch_embeddings)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
return embeddings
def encode(