Mirror of https://github.com/danswer-ai/danswer.git
pass through various id's and log them in the model server for better tracking (#4485)

* pass through various id's and log them in the model server for better tracking
* fix test

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
parent caa9b106e4
commit 3fc8027e73
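At a high level, this change threads a tenant id and a per-batch request id from the indexing worker through the embedding client and into the model server, which copies them from HTTP headers into context vars so log lines on both sides can be correlated. Below is a minimal, self-contained sketch of that pattern; the header names (X-Onyx-Tenant-ID, X-Onyx-Request-ID) come from the diff, but the app, endpoint, and variable names are illustrative stand-ins rather than the actual Onyx wiring.

import contextvars
from collections.abc import Awaitable, Callable

import requests
from fastapi import FastAPI, Request, Response

# Stand-ins for Onyx's context vars (the real ones live in shared_configs.contextvars).
TENANT_ID: contextvars.ContextVar[str | None] = contextvars.ContextVar("tenant_id", default=None)
REQUEST_ID: contextvars.ContextVar[str | None] = contextvars.ContextVar("request_id", default=None)

app = FastAPI()


@app.middleware("http")
async def capture_correlation_ids(
    request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
    # Model-server side: copy the correlation headers into context vars so any
    # log record emitted while handling this request can include them.
    tenant_id = request.headers.get("X-Onyx-Tenant-ID")
    request_id = request.headers.get("X-Onyx-Request-ID")
    if tenant_id:
        TENANT_ID.set(tenant_id)
    if request_id:
        REQUEST_ID.set(request_id)
    return await call_next(request)


def call_model_server(url: str, payload: dict, tenant_id: str, request_id: str) -> dict:
    # Client side: attach the ids as headers on the embedding request.
    headers = {"X-Onyx-Tenant-ID": tenant_id, "X-Onyx-Request-ID": request_id}
    response = requests.post(url, json=payload, headers=headers)
    response.raise_for_status()
    return response.json()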
@@ -17,7 +17,9 @@ from ee.onyx.server.enterprise_settings.api import (
     basic_router as enterprise_settings_router,
 )
 from ee.onyx.server.manage.standard_answer import router as standard_answer_router
-from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
+from ee.onyx.server.middleware.tenant_tracking import (
+    add_api_server_tenant_id_middleware,
+)
 from ee.onyx.server.oauth.api import router as ee_oauth_router
 from ee.onyx.server.query_and_chat.chat_backend import (
     router as chat_router,
@@ -79,7 +81,7 @@ def get_application() -> FastAPI:
     application = get_application_base(lifespan_override=lifespan)

     if MULTI_TENANT:
-        add_tenant_id_middleware(application, logger)
+        add_api_server_tenant_id_middleware(application, logger)

     if AUTH_TYPE == AuthType.CLOUD:
         # For Google OAuth, refresh tokens are requested by:
@@ -18,11 +18,18 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
 from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


-def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
+def add_api_server_tenant_id_middleware(
+    app: FastAPI, logger: logging.LoggerAdapter
+) -> None:
     @app.middleware("http")
     async def set_tenant_id(
         request: Request, call_next: Callable[[Request], Awaitable[Response]]
     ) -> Response:
+        """Extracts the tenant id from multiple locations and sets the context var.
+
+        This is very specific to the api server and probably not something you'd want
+        to use elsewhere.
+        """
         try:
             if MULTI_TENANT:
                 tenant_id = await _get_tenant_id_from_request(request, logger)
@@ -24,6 +24,7 @@ from onyx import __version__
 from onyx.utils.logger import setup_logger
 from onyx.utils.logger import setup_uvicorn_logger
 from onyx.utils.middleware import add_onyx_request_id_middleware
+from onyx.utils.middleware import add_onyx_tenant_id_middleware
 from shared_configs.configs import INDEXING_ONLY
 from shared_configs.configs import MIN_THREADS_ML_MODELS
 from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
@@ -126,6 +127,7 @@ def get_model_app() -> FastAPI:
     if INDEXING_ONLY:
         request_id_prefix = "IDX"

+    add_onyx_tenant_id_middleware(application, logger)
     add_onyx_request_id_middleware(application, request_id_prefix, logger)

     # Initialize and instrument the app
@@ -58,6 +58,7 @@ from onyx.natural_language_processing.search_nlp_models import (
 )
 from onyx.utils.logger import setup_logger
 from onyx.utils.logger import TaskAttemptSingleton
+from onyx.utils.middleware import make_randomized_onyx_request_id
 from onyx.utils.telemetry import create_milestone_and_report
 from onyx.utils.telemetry import optional_telemetry
 from onyx.utils.telemetry import RecordType
@@ -379,6 +380,7 @@ def _run_indexing(
     memory_tracer.start()

     index_attempt_md = IndexAttemptMetadata(
+        attempt_id=index_attempt_id,
         connector_id=ctx.connector_id,
         credential_id=ctx.credential_id,
     )
@@ -481,6 +483,8 @@ def _run_indexing(

         batch_description = []

+        # Generate an ID that can be used to correlate activity between here
+        # and the embedding model server
         doc_batch_cleaned = strip_null_characters(document_batch)
         for doc in doc_batch_cleaned:
             batch_description.append(doc.to_short_descriptor())
@@ -502,6 +506,10 @@ def _run_indexing(

         logger.debug(f"Indexing batch of documents: {batch_description}")

+        index_attempt_md.request_id = make_randomized_onyx_request_id("CIX")
+        index_attempt_md.structured_id = (
+            f"{tenant_id}:{ctx.cc_pair_id}:{index_attempt_id}:{batch_num}"
+        )
         index_attempt_md.batch_num = batch_num + 1  # use 1-index for this

         # real work happens here!
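make_randomized_onyx_request_id itself is not shown in this diff; purely as an illustration of the shape of the "CIX"-prefixed ids logged above, a hypothetical stand-in could be as simple as a prefix plus a short random suffix (the real helper lives in onyx.utils.middleware and may differ).

import uuid


def make_request_id(prefix: str) -> str:
    # Hypothetical stand-in for make_randomized_onyx_request_id; illustration only.
    return f"{prefix}:{uuid.uuid4().hex[:8]}"


print(make_request_id("CIX"))  # e.g. CIX:3f9a1c2e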
@@ -272,9 +272,14 @@ class SlimDocument(BaseModel):


 class IndexAttemptMetadata(BaseModel):
-    batch_num: int | None = None
     connector_id: int
     credential_id: int
+    batch_num: int | None = None
+    attempt_id: int | None = None
+    request_id: str | None = None
+
+    # Work in progress: will likely contain metadata about cc pair / index attempt
+    structured_id: str | None = None


 class ConnectorCheckpoint(BaseModel):
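For reference, the structured_id added to IndexAttemptMetadata is just a colon-delimited composite of ids the indexing loop already has on hand; with made-up values it looks like this (a sketch, not output from the real pipeline).

# Hypothetical values, for illustration only.
tenant_id = "tenant_1234"
cc_pair_id = 7
index_attempt_id = 42
batch_num = 3

# Mirrors the f-string assigned in _run_indexing above.
structured_id = f"{tenant_id}:{cc_pair_id}:{index_attempt_id}:{batch_num}"
print(structured_id)  # tenant_1234:7:42:3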
@@ -135,6 +135,7 @@ class Chunker:
         mini_chunk_size: int = MINI_CHUNK_SIZE,
         callback: IndexingHeartbeatInterface | None = None,
     ) -> None:
+        # from llama_index.core.node_parser import SentenceSplitter
         from llama_index.text_splitter import SentenceSplitter

         self.include_metadata = include_metadata
@@ -73,6 +73,8 @@ class IndexingEmbedder(ABC):
     def embed_chunks(
         self,
         chunks: list[DocAwareChunk],
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[IndexChunk]:
         raise NotImplementedError

@@ -110,6 +112,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
     def embed_chunks(
         self,
         chunks: list[DocAwareChunk],
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[IndexChunk]:
         """Adds embeddings to the chunks, the title and metadata suffixes are added to the chunk as well
         if they exist. If there is no space for it, it would have been thrown out at the chunking step.
@@ -143,6 +147,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
             texts=flat_chunk_texts,
             text_type=EmbedTextType.PASSAGE,
             large_chunks_present=large_chunks_present,
+            tenant_id=tenant_id,
+            request_id=request_id,
         )

         chunk_titles = {
@@ -158,7 +164,10 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         title_embed_dict: dict[str, Embedding] = {}
         if chunk_titles_list:
             title_embeddings = self.embedding_model.encode(
-                chunk_titles_list, text_type=EmbedTextType.PASSAGE
+                chunk_titles_list,
+                text_type=EmbedTextType.PASSAGE,
+                tenant_id=tenant_id,
+                request_id=request_id,
             )
             title_embed_dict.update(
                 {
@@ -190,7 +199,10 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
                     "Title had to be embedded separately, this should not happen!"
                 )
             title_embedding = self.embedding_model.encode(
-                [title], text_type=EmbedTextType.PASSAGE
+                [title],
+                text_type=EmbedTextType.PASSAGE,
+                tenant_id=tenant_id,
+                request_id=request_id,
             )[0]
             title_embed_dict[title] = title_embedding

@@ -231,14 +243,24 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
 def embed_chunks_with_failure_handling(
     chunks: list[DocAwareChunk],
     embedder: IndexingEmbedder,
+    tenant_id: str | None = None,
+    request_id: str | None = None,
 ) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
     """Tries to embed all chunks in one large batch. If that batch fails for any reason,
     goes document by document to isolate the failure(s).
     """

+    # TODO(rkuo): this doesn't disambiguate calls to the model server on retries.
+    # Improve this if needed.
+
     # First try to embed all chunks in one batch
     try:
-        return embedder.embed_chunks(chunks=chunks), []
+        return (
+            embedder.embed_chunks(
+                chunks=chunks, tenant_id=tenant_id, request_id=request_id
+            ),
+            [],
+        )
     except Exception:
         logger.exception("Failed to embed chunk batch. Trying individual docs.")
         # wait a couple seconds to let any rate limits or temporary issues resolve
@@ -254,7 +276,9 @@ def embed_chunks_with_failure_handling(

     for doc_id, chunks_for_doc in chunks_by_doc.items():
         try:
-            doc_embedded_chunks = embedder.embed_chunks(chunks=chunks_for_doc)
+            doc_embedded_chunks = embedder.embed_chunks(
+                chunks=chunks_for_doc, tenant_id=tenant_id, request_id=request_id
+            )
             embedded_chunks.extend(doc_embedded_chunks)
         except Exception as e:
             logger.exception(f"Failed to embed chunks for document '{doc_id}'")
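The docstring above describes a batch-then-fallback shape: embed everything in one call, and only on an exception re-run document by document to isolate the failures. A stripped-down sketch of that control flow, with generic types rather than the Onyx signatures:

import logging
from collections import defaultdict
from collections.abc import Callable

logger = logging.getLogger(__name__)


def embed_with_fallback(
    chunks: list[dict],
    embed_fn: Callable[[list[dict]], list[list[float]]],
) -> tuple[list[list[float]], list[str]]:
    """Embed all chunks at once; on failure, retry per document and collect failures."""
    try:
        return embed_fn(chunks), []
    except Exception:
        logger.exception("Batch embed failed; retrying per document.")

    by_doc: dict[str, list[dict]] = defaultdict(list)
    for chunk in chunks:
        by_doc[chunk["doc_id"]].append(chunk)

    embedded: list[list[float]] = []
    failed_doc_ids: list[str] = []
    for doc_id, doc_chunks in by_doc.items():
        try:
            embedded.extend(embed_fn(doc_chunks))
        except Exception:
            logger.exception(f"Embedding failed for document '{doc_id}'")
            failed_doc_ids.append(doc_id)
    return embedded, failed_doc_ids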
@@ -791,6 +791,8 @@ def index_doc_batch(
         embed_chunks_with_failure_handling(
             chunks=chunks,
             embedder=embedder,
+            tenant_id=tenant_id,
+            request_id=index_attempt_metadata.request_id,
         )
         if chunks
         else ([], [])
@@ -3,6 +3,7 @@ import time
 from collections.abc import Callable
 from concurrent.futures import as_completed
 from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 from functools import wraps
 from typing import Any

@@ -114,10 +115,24 @@ class EmbeddingModel:
         model_server_url = build_model_server_url(server_host, server_port)
         self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"

-    def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
+    def _make_model_server_request(
+        self,
+        embed_request: EmbedRequest,
+        tenant_id: str | None = None,
+        request_id: str | None = None,
+    ) -> EmbedResponse:
         def _make_request() -> Response:
+            headers = {}
+            if tenant_id:
+                headers["X-Onyx-Tenant-ID"] = tenant_id
+
+            if request_id:
+                headers["X-Onyx-Request-ID"] = request_id
+
             response = requests.post(
-                self.embed_server_endpoint, json=embed_request.model_dump()
+                self.embed_server_endpoint,
+                headers=headers,
+                json=embed_request.model_dump(),
             )
             # signify that this is a rate limit error
             if response.status_code == 429:
@@ -165,6 +180,8 @@ class EmbeddingModel:
         batch_size: int,
         max_seq_length: int,
         num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[Embedding]:
         text_batches = batch_list(texts, batch_size)

@@ -175,7 +192,11 @@ class EmbeddingModel:
         embeddings: list[Embedding] = []

         def process_batch(
-            batch_idx: int, batch_len: int, text_batch: list[str]
+            batch_idx: int,
+            batch_len: int,
+            text_batch: list[str],
+            tenant_id: str | None = None,
+            request_id: str | None = None,
         ) -> tuple[int, list[Embedding]]:
             if self.callback:
                 if self.callback.should_stop():
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
response = self._make_model_server_request(embed_request)
|
||||
response = self._make_model_server_request(
|
||||
embed_request, tenant_id=tenant_id, request_id=request_id
|
||||
)
|
||||
end_time = time.time()
|
||||
|
||||
processing_time = end_time - start_time
|
||||
@@ -215,7 +238,16 @@ class EmbeddingModel:
         if num_threads >= 1 and self.provider_type and len(text_batches) > 1:
             with ThreadPoolExecutor(max_workers=num_threads) as executor:
                 future_to_batch = {
-                    executor.submit(process_batch, idx, len(text_batches), batch): idx
+                    executor.submit(
+                        partial(
+                            process_batch,
+                            idx,
+                            len(text_batches),
+                            batch,
+                            tenant_id=tenant_id,
+                            request_id=request_id,
+                        )
+                    ): idx
                     for idx, batch in enumerate(text_batches, start=1)
                 }

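The change to functools.partial above binds all of process_batch's positional and keyword arguments into a single no-argument callable before handing it to executor.submit; submit can also forward keyword arguments directly, so this is presumably a readability choice for the dict comprehension. A tiny standalone illustration of the same pattern, with generic names:

from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial


def work(idx: int, total: int, payload: str, *, request_id: str | None = None) -> str:
    return f"{idx}/{total} {payload} request_id={request_id}"


with ThreadPoolExecutor(max_workers=2) as executor:
    # partial freezes every argument, so submit receives a zero-argument callable.
    future_to_idx = {
        executor.submit(partial(work, idx, 3, payload, request_id="CIX:abc123")): idx
        for idx, payload in enumerate(["a", "b", "c"], start=1)
    }
    for future in as_completed(future_to_idx):
        print(future_to_idx[future], future.result())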
@@ -238,7 +270,13 @@ class EmbeddingModel:
         else:
             # Original sequential processing
             for idx, text_batch in enumerate(text_batches, start=1):
-                _, batch_embeddings = process_batch(idx, len(text_batches), text_batch)
+                _, batch_embeddings = process_batch(
+                    idx,
+                    len(text_batches),
+                    text_batch,
+                    tenant_id=tenant_id,
+                    request_id=request_id,
+                )
                 embeddings.extend(batch_embeddings)
                 if self.callback:
                     self.callback.progress("_batch_encode_texts", 1)
@@ -253,6 +291,8 @@ class EmbeddingModel:
         local_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
         api_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
         max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
+        tenant_id: str | None = None,
+        request_id: str | None = None,
     ) -> list[Embedding]:
         if not texts or not all(texts):
             raise ValueError(f"Empty or missing text for embedding: {texts}")
@@ -284,6 +324,8 @@ class EmbeddingModel:
             text_type=text_type,
             batch_size=batch_size,
             max_seq_length=max_seq_length,
+            tenant_id=tenant_id,
+            request_id=request_id,
         )

     @classmethod
@@ -11,9 +11,23 @@ from fastapi import FastAPI
 from fastapi import Request
 from fastapi import Response

+from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
 from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR


+def add_onyx_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
+    @app.middleware("http")
+    async def set_tenant_id(
+        request: Request, call_next: Callable[[Request], Awaitable[Response]]
+    ) -> Response:
+        """Captures and sets the context var for the tenant."""
+
+        onyx_tenant_id = request.headers.get("X-Onyx-Tenant-ID")
+        if onyx_tenant_id:
+            CURRENT_TENANT_ID_CONTEXTVAR.set(onyx_tenant_id)
+        return await call_next(request)
+
+
 def add_onyx_request_id_middleware(
     app: FastAPI, prefix: str, logger: logging.LoggerAdapter
 ) -> None:
@@ -88,14 +88,18 @@ def test_default_indexing_embedder_embed_chunks(
     )
     assert result[0].title_embedding == [7.0, 8.0, 9.0]

-    # Verify the embedding model was called correctly
+    # Verify the embedding model was called exactly as follows
     mock_embedding_model.return_value.encode.assert_any_call(
         texts=[f"Title: {doc_summary}Test chunk{chunk_context}"],
         text_type=EmbedTextType.PASSAGE,
         large_chunks_present=False,
+        tenant_id=None,
+        request_id=None,
     )
-    # title only embedding call
+    # Same for title only embedding call
     mock_embedding_model.return_value.encode.assert_any_call(
         ["Test Document"],
         text_type=EmbedTextType.PASSAGE,
+        tenant_id=None,
+        request_id=None,
     )