pass through various IDs and log them in the model server for better tracking (#4485)

* pass through various IDs and log them in the model server for better tracking

* fix test

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
rkuo-danswer authored 2025-04-09 17:40:57 -07:00, committed by GitHub
parent caa9b106e4
commit 3fc8027e73
11 changed files with 127 additions and 16 deletions
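
At a high level, the change threads a tenant ID and a per-batch request ID from the indexing pipeline through the embedding client and into the model server, where middleware copies them from HTTP headers into context variables for logging. Below is a minimal sketch of the two halves; the header names (X-Onyx-Tenant-ID, X-Onyx-Request-ID) come from this commit, while the context-variable and function names are illustrative only.

# Sketch only, not the Onyx implementation: the client attaches correlation
# headers, and server-side middleware copies them into context variables.
from contextvars import ContextVar

import requests
from fastapi import FastAPI, Request

TENANT_ID_VAR: ContextVar[str | None] = ContextVar("tenant_id", default=None)
REQUEST_ID_VAR: ContextVar[str | None] = ContextVar("request_id", default=None)

app = FastAPI()

@app.middleware("http")
async def capture_ids(request: Request, call_next):
    # Copy correlation headers into context vars so any log call made while
    # handling this request can include them.
    tenant_id = request.headers.get("X-Onyx-Tenant-ID")
    request_id = request.headers.get("X-Onyx-Request-ID")
    if tenant_id:
        TENANT_ID_VAR.set(tenant_id)
    if request_id:
        REQUEST_ID_VAR.set(request_id)
    return await call_next(request)

def post_with_ids(
    url: str, payload: dict, tenant_id: str | None, request_id: str | None
) -> requests.Response:
    # Client side: attach the same headers so the server can correlate the call.
    headers = {}
    if tenant_id:
        headers["X-Onyx-Tenant-ID"] = tenant_id
    if request_id:
        headers["X-Onyx-Request-ID"] = request_id
    return requests.post(url, json=payload, headers=headers)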

View File

@ -17,7 +17,9 @@ from ee.onyx.server.enterprise_settings.api import (
basic_router as enterprise_settings_router,
)
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.onyx.server.middleware.tenant_tracking import (
add_api_server_tenant_id_middleware,
)
from ee.onyx.server.oauth.api import router as ee_oauth_router
from ee.onyx.server.query_and_chat.chat_backend import (
router as chat_router,
@ -79,7 +81,7 @@ def get_application() -> FastAPI:
application = get_application_base(lifespan_override=lifespan)
if MULTI_TENANT:
add_tenant_id_middleware(application, logger)
add_api_server_tenant_id_middleware(application, logger)
if AUTH_TYPE == AuthType.CLOUD:
# For Google OAuth, refresh tokens are requested by:

View File

@ -18,11 +18,18 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
def add_api_server_tenant_id_middleware(
app: FastAPI, logger: logging.LoggerAdapter
) -> None:
@app.middleware("http")
async def set_tenant_id(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Extracts the tenant id from multiple locations and sets the context var.
This is very specific to the api server and probably not something you'd want
to use elsewhere.
"""
try:
if MULTI_TENANT:
tenant_id = await _get_tenant_id_from_request(request, logger)
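
The helper _get_tenant_id_from_request is not shown in this diff. Purely as a hypothetical sketch of what "multiple locations" could mean (an explicit header, then a cookie, then the default schema), with names and fallback order that are assumptions rather than the actual implementation:

# Hypothetical sketch; the real _get_tenant_id_from_request is outside this commit.
import logging

from fastapi import Request

POSTGRES_DEFAULT_SCHEMA = "public"  # stand-in for the shared_configs value

async def _get_tenant_id_from_request_sketch(
    request: Request, logger: logging.LoggerAdapter
) -> str:
    # 1) explicit header set by internal callers
    tenant_id = request.headers.get("X-Onyx-Tenant-ID")
    if tenant_id:
        return tenant_id
    # 2) value carried by a session cookie, if any (illustrative)
    tenant_id = request.cookies.get("tenant_id")
    if tenant_id:
        return tenant_id
    # 3) fall back to the default schema
    logger.debug("No tenant id found on request; using default schema")
    return POSTGRES_DEFAULT_SCHEMA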

View File

@ -24,6 +24,7 @@ from onyx import __version__
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_onyx_request_id_middleware
from onyx.utils.middleware import add_onyx_tenant_id_middleware
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
@ -126,6 +127,7 @@ def get_model_app() -> FastAPI:
if INDEXING_ONLY:
request_id_prefix = "IDX"
add_onyx_tenant_id_middleware(application, logger)
add_onyx_request_id_middleware(application, request_id_prefix, logger)
# Initialize and instrument the app
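
The reason for registering these middlewares on the model server is that its log lines can then carry the tenant and request IDs for better tracking. One common way to surface context variables in logs, shown here only as an illustrative sketch (not the Onyx logger setup), is a logging.Filter that stamps each record:

# Illustrative sketch of tagging model-server logs with the propagated IDs.
import logging
from contextvars import ContextVar

TENANT_ID_VAR: ContextVar[str] = ContextVar("tenant_id", default="-")
REQUEST_ID_VAR: ContextVar[str] = ContextVar("request_id", default="-")

class CorrelationFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        # Attach the current context values to every record passing through.
        record.tenant_id = TENANT_ID_VAR.get()
        record.request_id = REQUEST_ID_VAR.get()
        return True

handler = logging.StreamHandler()
handler.addFilter(CorrelationFilter())
handler.setFormatter(
    logging.Formatter("%(asctime)s [%(tenant_id)s] [%(request_id)s] %(levelname)s %(message)s")
)
logger = logging.getLogger("model_server_sketch")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.info("embedding request received")  # prints with tenant/request id columns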

View File

@ -58,6 +58,7 @@ from onyx.natural_language_processing.search_nlp_models import (
)
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.middleware import make_randomized_onyx_request_id
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
@ -379,6 +380,7 @@ def _run_indexing(
memory_tracer.start()
index_attempt_md = IndexAttemptMetadata(
attempt_id=index_attempt_id,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
)
@ -481,6 +483,8 @@ def _run_indexing(
batch_description = []
# Generate an ID that can be used to correlate activity between here
# and the embedding model server
doc_batch_cleaned = strip_null_characters(document_batch)
for doc in doc_batch_cleaned:
batch_description.append(doc.to_short_descriptor())
@ -502,6 +506,10 @@ def _run_indexing(
logger.debug(f"Indexing batch of documents: {batch_description}")
index_attempt_md.request_id = make_randomized_onyx_request_id("CIX")
index_attempt_md.structured_id = (
f"{tenant_id}:{ctx.cc_pair_id}:{index_attempt_id}:{batch_num}"
)
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
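
make_randomized_onyx_request_id is imported above but not defined in this diff; the "CIX" prefix and the tenant:cc_pair:attempt:batch layout of structured_id come from the code, while the random-suffix format below is only an assumption:

# Sketch; the real make_randomized_onyx_request_id may format IDs differently.
import secrets

def make_randomized_onyx_request_id_sketch(prefix: str) -> str:
    # e.g. "CIX:3f9a1c2b4d5e6f70"
    return f"{prefix}:{secrets.token_hex(8)}"

request_id = make_randomized_onyx_request_id_sketch("CIX")
# Structured ID tying the batch to its tenant / cc pair / attempt, as in the hunk above
# (values illustrative):
tenant_id, cc_pair_id, index_attempt_id, batch_num = "tenant_abc", 7, 42, 0
structured_id = f"{tenant_id}:{cc_pair_id}:{index_attempt_id}:{batch_num}"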

View File

@ -272,9 +272,14 @@ class SlimDocument(BaseModel):
class IndexAttemptMetadata(BaseModel):
batch_num: int | None = None
connector_id: int
credential_id: int
batch_num: int | None = None
attempt_id: int | None = None
request_id: str | None = None
# Work in progress: will likely contain metadata about cc pair / index attempt
structured_id: str | None = None
class ConnectorCheckpoint(BaseModel):
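
Put together, a fully populated metadata object for one indexing batch might look like the following self-contained example, which mirrors the fields in the hunk above with illustrative values:

from pydantic import BaseModel

class IndexAttemptMetadata(BaseModel):
    connector_id: int
    credential_id: int
    batch_num: int | None = None
    attempt_id: int | None = None
    request_id: str | None = None
    structured_id: str | None = None

metadata = IndexAttemptMetadata(
    connector_id=3,
    credential_id=5,
    batch_num=1,                        # 1-indexed, as set in _run_indexing
    attempt_id=42,
    request_id="CIX:3f9a1c2b4d5e6f70",  # randomized per batch
    structured_id="tenant_abc:7:42:0",  # tenant:cc_pair:attempt:batch
)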

View File

@ -135,6 +135,7 @@ class Chunker:
mini_chunk_size: int = MINI_CHUNK_SIZE,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
# from llama_index.core.node_parser import SentenceSplitter
from llama_index.text_splitter import SentenceSplitter
self.include_metadata = include_metadata

View File

@ -73,6 +73,8 @@ class IndexingEmbedder(ABC):
def embed_chunks(
self,
chunks: list[DocAwareChunk],
tenant_id: str | None = None,
request_id: str | None = None,
) -> list[IndexChunk]:
raise NotImplementedError
@ -110,6 +112,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
def embed_chunks(
self,
chunks: list[DocAwareChunk],
tenant_id: str | None = None,
request_id: str | None = None,
) -> list[IndexChunk]:
"""Adds embeddings to the chunks, the title and metadata suffixes are added to the chunk as well
if they exist. If there is no space for it, it would have been thrown out at the chunking step.
@ -143,6 +147,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
texts=flat_chunk_texts,
text_type=EmbedTextType.PASSAGE,
large_chunks_present=large_chunks_present,
tenant_id=tenant_id,
request_id=request_id,
)
chunk_titles = {
@ -158,7 +164,10 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
title_embed_dict: dict[str, Embedding] = {}
if chunk_titles_list:
title_embeddings = self.embedding_model.encode(
chunk_titles_list, text_type=EmbedTextType.PASSAGE
chunk_titles_list,
text_type=EmbedTextType.PASSAGE,
tenant_id=tenant_id,
request_id=request_id,
)
title_embed_dict.update(
{
@ -190,7 +199,10 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
"Title had to be embedded separately, this should not happen!"
)
title_embedding = self.embedding_model.encode(
[title], text_type=EmbedTextType.PASSAGE
[title],
text_type=EmbedTextType.PASSAGE,
tenant_id=tenant_id,
request_id=request_id,
)[0]
title_embed_dict[title] = title_embedding
@ -231,14 +243,24 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
def embed_chunks_with_failure_handling(
chunks: list[DocAwareChunk],
embedder: IndexingEmbedder,
tenant_id: str | None = None,
request_id: str | None = None,
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
"""Tries to embed all chunks in one large batch. If that batch fails for any reason,
goes document by document to isolate the failure(s).
"""
# TODO(rkuo): this doesn't disambiguate calls to the model server on retries.
# Improve this if needed.
# First try to embed all chunks in one batch
try:
return embedder.embed_chunks(chunks=chunks), []
return (
embedder.embed_chunks(
chunks=chunks, tenant_id=tenant_id, request_id=request_id
),
[],
)
except Exception:
logger.exception("Failed to embed chunk batch. Trying individual docs.")
# wait a couple seconds to let any rate limits or temporary issues resolve
@ -254,7 +276,9 @@ def embed_chunks_with_failure_handling(
for doc_id, chunks_for_doc in chunks_by_doc.items():
try:
doc_embedded_chunks = embedder.embed_chunks(chunks=chunks_for_doc)
doc_embedded_chunks = embedder.embed_chunks(
chunks=chunks_for_doc, tenant_id=tenant_id, request_id=request_id
)
embedded_chunks.extend(doc_embedded_chunks)
except Exception as e:
logger.exception(f"Failed to embed chunks for document '{doc_id}'")

View File

@ -791,6 +791,8 @@ def index_doc_batch(
embed_chunks_with_failure_handling(
chunks=chunks,
embedder=embedder,
tenant_id=tenant_id,
request_id=index_attempt_metadata.request_id,
)
if chunks
else ([], [])

View File

@ -3,6 +3,7 @@ import time
from collections.abc import Callable
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from functools import wraps
from typing import Any
@ -114,10 +115,24 @@ class EmbeddingModel:
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
def _make_model_server_request(
self,
embed_request: EmbedRequest,
tenant_id: str | None = None,
request_id: str | None = None,
) -> EmbedResponse:
def _make_request() -> Response:
headers = {}
if tenant_id:
headers["X-Onyx-Tenant-ID"] = tenant_id
if request_id:
headers["X-Onyx-Request-ID"] = request_id
response = requests.post(
self.embed_server_endpoint, json=embed_request.model_dump()
self.embed_server_endpoint,
headers=headers,
json=embed_request.model_dump(),
)
# signify that this is a rate limit error
if response.status_code == 429:
@ -165,6 +180,8 @@ class EmbeddingModel:
batch_size: int,
max_seq_length: int,
num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
tenant_id: str | None = None,
request_id: str | None = None,
) -> list[Embedding]:
text_batches = batch_list(texts, batch_size)
@ -175,7 +192,11 @@ class EmbeddingModel:
embeddings: list[Embedding] = []
def process_batch(
batch_idx: int, batch_len: int, text_batch: list[str]
batch_idx: int,
batch_len: int,
text_batch: list[str],
tenant_id: str | None = None,
request_id: str | None = None,
) -> tuple[int, list[Embedding]]:
if self.callback:
if self.callback.should_stop():
@ -198,7 +219,9 @@ class EmbeddingModel:
)
start_time = time.time()
response = self._make_model_server_request(embed_request)
response = self._make_model_server_request(
embed_request, tenant_id=tenant_id, request_id=request_id
)
end_time = time.time()
processing_time = end_time - start_time
@ -215,7 +238,16 @@ class EmbeddingModel:
if num_threads >= 1 and self.provider_type and len(text_batches) > 1:
with ThreadPoolExecutor(max_workers=num_threads) as executor:
future_to_batch = {
executor.submit(process_batch, idx, len(text_batches), batch): idx
executor.submit(
partial(
process_batch,
idx,
len(text_batches),
batch,
tenant_id=tenant_id,
request_id=request_id,
)
): idx
for idx, batch in enumerate(text_batches, start=1)
}
@ -238,7 +270,13 @@ class EmbeddingModel:
else:
# Original sequential processing
for idx, text_batch in enumerate(text_batches, start=1):
_, batch_embeddings = process_batch(idx, len(text_batches), text_batch)
_, batch_embeddings = process_batch(
idx,
len(text_batches),
text_batch,
tenant_id=tenant_id,
request_id=request_id,
)
embeddings.extend(batch_embeddings)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
@ -253,6 +291,8 @@ class EmbeddingModel:
local_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
api_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
tenant_id: str | None = None,
request_id: str | None = None,
) -> list[Embedding]:
if not texts or not all(texts):
raise ValueError(f"Empty or missing text for embedding: {texts}")
@ -284,6 +324,8 @@ class EmbeddingModel:
text_type=text_type,
batch_size=batch_size,
max_seq_length=max_seq_length,
tenant_id=tenant_id,
request_id=request_id,
)
@classmethod
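
A small detail in the threaded path above: ThreadPoolExecutor.submit would also accept keyword arguments directly, but binding everything with functools.partial keeps the future-to-batch dictionary comprehension readable. A self-contained illustration of the same pattern with toy data:

# Self-contained illustration of binding per-call arguments with partial
# before handing work to a ThreadPoolExecutor.
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial

def process_batch(
    batch_idx: int,
    batch_len: int,
    batch: list[str],
    tenant_id: str | None = None,
    request_id: str | None = None,
) -> tuple[int, int]:
    # stand-in for an embedding call: just count characters
    return batch_idx, sum(len(t) for t in batch)

batches = [["a", "bb"], ["ccc"], ["dddd", "e"]]
with ThreadPoolExecutor(max_workers=2) as executor:
    future_to_idx = {
        executor.submit(
            partial(
                process_batch,
                idx,
                len(batches),
                batch,
                tenant_id="tenant_abc",
                request_id="CIX:deadbeef",
            )
        ): idx
        for idx, batch in enumerate(batches, start=1)
    }
    for future in as_completed(future_to_idx):
        idx, size = future.result()
        print(f"batch {idx}: {size} characters embedded")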

View File

@ -11,9 +11,23 @@ from fastapi import FastAPI
from fastapi import Request
from fastapi import Response
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
def add_onyx_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
@app.middleware("http")
async def set_tenant_id(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Captures and sets the context var for the tenant."""
onyx_tenant_id = request.headers.get("X-Onyx-Tenant-ID")
if onyx_tenant_id:
CURRENT_TENANT_ID_CONTEXTVAR.set(onyx_tenant_id)
return await call_next(request)
def add_onyx_request_id_middleware(
app: FastAPI, prefix: str, logger: logging.LoggerAdapter
) -> None:
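
The body of add_onyx_request_id_middleware is outside this diff; given its signature and the ONYX_REQUEST_ID_CONTEXTVAR import above, a plausible (but unconfirmed) shape is: honor an incoming X-Onyx-Request-ID header when present, otherwise mint a prefixed random ID, and set the context var either way. A hedged sketch:

# Plausible sketch only; the real middleware body is not part of this commit.
import logging
import secrets
from collections.abc import Awaitable, Callable
from contextvars import ContextVar

from fastapi import FastAPI, Request, Response

ONYX_REQUEST_ID_CONTEXTVAR: ContextVar[str | None] = ContextVar("request_id", default=None)

def add_onyx_request_id_middleware_sketch(
    app: FastAPI, prefix: str, logger: logging.LoggerAdapter
) -> None:
    @app.middleware("http")
    async def set_request_id(
        request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        request_id = request.headers.get("X-Onyx-Request-ID")
        if not request_id:
            request_id = f"{prefix}:{secrets.token_hex(8)}"
        ONYX_REQUEST_ID_CONTEXTVAR.set(request_id)
        return await call_next(request)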

View File

@ -88,14 +88,18 @@ def test_default_indexing_embedder_embed_chunks(
)
assert result[0].title_embedding == [7.0, 8.0, 9.0]
# Verify the embedding model was called correctly
# Verify the embedding model was called exactly as follows
mock_embedding_model.return_value.encode.assert_any_call(
texts=[f"Title: {doc_summary}Test chunk{chunk_context}"],
text_type=EmbedTextType.PASSAGE,
large_chunks_present=False,
tenant_id=None,
request_id=None,
)
# title only embedding call
# Same for title only embedding call
mock_embedding_model.return_value.encode.assert_any_call(
["Test Document"],
text_type=EmbedTextType.PASSAGE,
tenant_id=None,
request_id=None,
)