Mirror of https://github.com/danswer-ai/danswer.git, synced 2025-04-09 20:39:29 +02:00
Added retries and multithreading for cloud embedding (#1879)
* added retries and multithreading for cloud embedding
* refactored a bit
* cleaned up code
* got the errors to bubble up to the UI correctly
* added exception printing
* added requirements
* touchups

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
parent 7fbbb174bb
commit eb3e7610fc
backend/danswer/indexing/embedder.py

@@ -4,7 +4,6 @@ from abc import abstractmethod
 from sqlalchemy.orm import Session
 
 from danswer.configs.app_configs import ENABLE_MINI_CHUNK
-from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.embedding_model import get_secondary_db_embedding_model
@@ -15,7 +14,6 @@ from danswer.indexing.models import ChunkEmbedding
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
 from danswer.search.search_nlp_models import EmbeddingModel
-from danswer.utils.batching import batch_list
 from danswer.utils.logger import setup_logger
 from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
 from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
@@ -71,7 +69,6 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
     def embed_chunks(
         self,
         chunks: list[DocAwareChunk],
-        batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
         enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
     ) -> list[IndexChunk]:
         # Cache the Title embeddings to only have to do it once
@@ -80,7 +77,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
 
         # Create Mini Chunks for more precise matching of details
         # Off by default with unedited settings
-        chunk_texts = []
+        chunk_texts: list[str] = []
         chunk_mini_chunks_count = {}
         for chunk_ind, chunk in enumerate(chunks):
             chunk_texts.append(chunk.content)
@@ -92,22 +89,9 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
             chunk_texts.extend(mini_chunk_texts)
             chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
 
-        # Batching for embedding
-        text_batches = batch_list(chunk_texts, batch_size)
-
-        embeddings: list[list[float]] = []
-        len_text_batches = len(text_batches)
-        for idx, text_batch in enumerate(text_batches, start=1):
-            logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
-            # Normalize embeddings is only configured via model_configs.py, be sure to use right
-            # value for the set loss
-            embeddings.extend(
-                self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
-            )
-
-            # Replace line above with the line below for easy debugging of indexing flow
-            # skipping the actual model
-            # embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])
+        embeddings = self.embedding_model.encode(
+            chunk_texts, text_type=EmbedTextType.PASSAGE
+        )
 
         chunk_titles = {
             chunk.source_document.get_title_for_document_index() for chunk in chunks
@@ -116,17 +100,15 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         # Drop any None or empty strings
         chunk_titles_list = [title for title in chunk_titles if title]
 
-        # Embed Titles in batches
-        title_batches = batch_list(chunk_titles_list, batch_size)
-        len_title_batches = len(title_batches)
-        for ind_batch, title_batch in enumerate(title_batches, start=1):
-            logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
-            title_embeddings = self.embedding_model.encode(
-                title_batch, text_type=EmbedTextType.PASSAGE
-            )
-            title_embed_dict.update(
-                {title: vector for title, vector in zip(title_batch, title_embeddings)}
-            )
+        title_embeddings = self.embedding_model.encode(
+            chunk_titles_list, text_type=EmbedTextType.PASSAGE
+        )
+        title_embed_dict.update(
+            {
+                title: vector
+                for title, vector in zip(chunk_titles_list, title_embeddings)
+            }
+        )
 
         # Mapping embeddings to chunks
         embedding_ind_start = 0
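Note on the removed batching: `batch_list` (from `danswer.utils.batching`) splits a list into fixed-size slices, and this commit moves that batching out of `embed_chunks` and into `EmbeddingModel.encode`. A minimal sketch of the assumed semantics, inferred from the call sites in this diff:

from typing import TypeVar

T = TypeVar("T")


def batch_list(lst: list[T], batch_size: int) -> list[list[T]]:
    # Consecutive slices of at most batch_size items, order preserved
    return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]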
backend/danswer/search/search_nlp_models.py

@@ -5,10 +5,13 @@ from typing import Optional
 from typing import TYPE_CHECKING
 
 import requests
+from httpx import HTTPError
 from transformers import logging as transformer_logging  # type:ignore
 
+from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
+from danswer.utils.batching import batch_list
 from danswer.utils.logger import setup_logger
 from shared_configs.configs import MODEL_SERVER_HOST
 from shared_configs.configs import MODEL_SERVER_PORT
@@ -103,28 +106,69 @@ class EmbeddingModel:
         model_server_url = build_model_server_url(server_host, server_port)
         self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
 
-    def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
-        if text_type == EmbedTextType.QUERY and self.query_prefix:
-            prefixed_texts = [self.query_prefix + text for text in texts]
-        elif text_type == EmbedTextType.PASSAGE and self.passage_prefix:
-            prefixed_texts = [self.passage_prefix + text for text in texts]
-        else:
-            prefixed_texts = texts
-
-        embed_request = EmbedRequest(
-            model_name=self.model_name,
-            texts=prefixed_texts,
-            max_context_length=self.max_seq_length,
-            normalize_embeddings=self.normalize,
-            api_key=self.api_key,
-            provider_type=self.provider_type,
-            text_type=text_type,
-        )
-
-        response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
-        response.raise_for_status()
-
-        return EmbedResponse(**response.json()).embeddings
+    def encode(
+        self,
+        texts: list[str],
+        text_type: EmbedTextType,
+        batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
+    ) -> list[list[float]]:
+        if self.provider_type:
+            embed_request = EmbedRequest(
+                model_name=self.model_name,
+                texts=texts,
+                max_context_length=self.max_seq_length,
+                normalize_embeddings=self.normalize,
+                api_key=self.api_key,
+                provider_type=self.provider_type,
+                text_type=text_type,
+                manual_query_prefix=self.query_prefix,
+                manual_passage_prefix=self.passage_prefix,
+            )
+            response = requests.post(
+                self.embed_server_endpoint, json=embed_request.dict()
+            )
+            try:
+                response.raise_for_status()
+            except requests.HTTPError as e:
+                error_detail = response.json().get("detail", str(e))
+                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
+            except requests.RequestException as e:
+                raise HTTPError(f"Request failed: {str(e)}") from e
+            return EmbedResponse(**response.json()).embeddings
+
+        # Batching for local embedding
+        text_batches = batch_list(texts, batch_size)
+        embeddings: list[list[float]] = []
+        for idx, text_batch in enumerate(text_batches, start=1):
+            logger.debug(f"Embedding Content Texts batch {idx} of {len(text_batches)}")
+
+            embed_request = EmbedRequest(
+                model_name=self.model_name,
+                texts=text_batch,
+                max_context_length=self.max_seq_length,
+                normalize_embeddings=self.normalize,
+                api_key=self.api_key,
+                provider_type=self.provider_type,
+                text_type=text_type,
+                manual_query_prefix=self.query_prefix,
+                manual_passage_prefix=self.passage_prefix,
+            )
+
+            response = requests.post(
+                self.embed_server_endpoint, json=embed_request.dict()
+            )
+            try:
+                response.raise_for_status()
+            except requests.HTTPError as e:
+                error_detail = response.json().get("detail", str(e))
+                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
+            except requests.RequestException as e:
+                raise HTTPError(f"Request failed: {str(e)}") from e
+            # Normalize embeddings is only configured via model_configs.py, be sure to use right
+            # value for the set loss
+            embeddings.extend(EmbedResponse(**response.json()).embeddings)
+
+        return embeddings
 
 
 class CrossEncoderEnsembleModel:
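The try/except around `raise_for_status` above is what lets model-server failures surface in the UI: the server's `detail` field is re-raised client-side as an `httpx.HTTPError` with a readable message. The same pattern in isolation (the function name is illustrative, not part of this commit):

import requests
from httpx import HTTPError


def post_and_bubble_detail(endpoint: str, payload: dict) -> dict:
    response = requests.post(endpoint, json=payload)
    try:
        response.raise_for_status()
    except requests.HTTPError as e:
        # Prefer the server's "detail" message over the bare status line
        error_detail = response.json().get("detail", str(e))
        raise HTTPError(f"HTTP error occurred: {error_detail}") from e
    except requests.RequestException as e:
        raise HTTPError(f"Request failed: {str(e)}") from e
    return response.json()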
backend/model_server/encoders.py

@@ -1,3 +1,4 @@
+import concurrent.futures
 import gc
 import json
 from typing import Any
@@ -10,6 +11,7 @@ from cohere import Client as CohereClient
 from fastapi import APIRouter
 from fastapi import HTTPException
 from google.oauth2 import service_account  # type: ignore
+from retry import retry
 from sentence_transformers import CrossEncoder  # type: ignore
 from sentence_transformers import SentenceTransformer  # type: ignore
 from vertexai.language_models import TextEmbeddingInput  # type: ignore
@@ -40,65 +42,44 @@ router = APIRouter(prefix="/encoder")
 
 _GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
 _RERANK_MODELS: Optional[list["CrossEncoder"]] = None
+# If we are not only indexing, dont want retry very long
+_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
+_RETRY_TRIES = 10 if INDEXING_ONLY else 2
+
+
+def _initialize_client(
+    api_key: str, provider: EmbeddingProvider, model: str | None = None
+) -> Any:
+    if provider == EmbeddingProvider.OPENAI:
+        return openai.OpenAI(api_key=api_key)
+    elif provider == EmbeddingProvider.COHERE:
+        return CohereClient(api_key=api_key)
+    elif provider == EmbeddingProvider.VOYAGE:
+        return voyageai.Client(api_key=api_key)
+    elif provider == EmbeddingProvider.GOOGLE:
+        credentials = service_account.Credentials.from_service_account_info(
+            json.loads(api_key)
+        )
+        project_id = json.loads(api_key)["project_id"]
+        vertexai.init(project=project_id, credentials=credentials)
+        return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
+    else:
+        raise ValueError(f"Unsupported provider: {provider}")
 
 
 class CloudEmbedding:
-    def __init__(self, api_key: str, provider: str, model: str | None = None):
-        self.api_key = api_key
-
-        # Only for Google as is needed on client setup
-        self.model = model
+    def __init__(
+        self,
+        api_key: str,
+        provider: str,
+        # Only for Google as is needed on client setup
+        model: str | None = None,
+    ) -> None:
         try:
             self.provider = EmbeddingProvider(provider.lower())
         except ValueError:
             raise ValueError(f"Unsupported provider: {provider}")
-        self.client = self._initialize_client()
-
-    def _initialize_client(self) -> Any:
-        if self.provider == EmbeddingProvider.OPENAI:
-            return openai.OpenAI(api_key=self.api_key)
-        elif self.provider == EmbeddingProvider.COHERE:
-            return CohereClient(api_key=self.api_key)
-        elif self.provider == EmbeddingProvider.VOYAGE:
-            return voyageai.Client(api_key=self.api_key)
-        elif self.provider == EmbeddingProvider.GOOGLE:
-            credentials = service_account.Credentials.from_service_account_info(
-                json.loads(self.api_key)
-            )
-            project_id = json.loads(self.api_key)["project_id"]
-            vertexai.init(project=project_id, credentials=credentials)
-            return TextEmbeddingModel.from_pretrained(
-                self.model or DEFAULT_VERTEX_MODEL
-            )
-
-        else:
-            raise ValueError(f"Unsupported provider: {self.provider}")
-
-    def encode(
-        self, texts: list[str], model_name: str | None, text_type: EmbedTextType
-    ) -> list[list[float]]:
-        return [
-            self.embed(text=text, text_type=text_type, model=model_name)
-            for text in texts
-        ]
-
-    def embed(
-        self, *, text: str, text_type: EmbedTextType, model: str | None = None
-    ) -> list[float]:
-        logger.debug(f"Embedding text with provider: {self.provider}")
-        if self.provider == EmbeddingProvider.OPENAI:
-            return self._embed_openai(text, model)
-
-        embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
-
-        if self.provider == EmbeddingProvider.COHERE:
-            return self._embed_cohere(text, model, embedding_type)
-        elif self.provider == EmbeddingProvider.VOYAGE:
-            return self._embed_voyage(text, model, embedding_type)
-        elif self.provider == EmbeddingProvider.GOOGLE:
-            return self._embed_vertex(text, model, embedding_type)
-        else:
-            raise ValueError(f"Unsupported provider: {self.provider}")
+        self.client = _initialize_client(api_key, self.provider, model)
 
     def _embed_openai(self, text: str, model: str | None) -> list[float]:
         if model is None:
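With `_initialize_client` hoisted to module scope, `__init__` reduces to provider validation plus a single factory call. Assuming the `EmbeddingProvider` enum values are the lowercase provider names, construction looks like this (the key is a placeholder):

# Provider strings are matched case-insensitively via provider.lower(),
# so "OpenAI" and "openai" both resolve; unknown names raise ValueError.
embedder = CloudEmbedding(api_key="sk-placeholder", provider="openai")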
@@ -145,6 +126,62 @@ class CloudEmbedding:
         )
         return embedding[0].values
 
+    def _embed(
+        self,
+        *,
+        text: str,
+        text_type: EmbedTextType,
+        model: str | None = None,
+    ) -> list[float]:
+        try:
+            if self.provider == EmbeddingProvider.OPENAI:
+                return self._embed_openai(text, model)
+
+            embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
+            if self.provider == EmbeddingProvider.COHERE:
+                return self._embed_cohere(text, model, embedding_type)
+            elif self.provider == EmbeddingProvider.VOYAGE:
+                return self._embed_voyage(text, model, embedding_type)
+            elif self.provider == EmbeddingProvider.GOOGLE:
+                return self._embed_vertex(text, model, embedding_type)
+            else:
+                raise ValueError(f"Unsupported provider: {self.provider}")
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"Error embedding text with {self.provider}: {str(e)}",
+            )
+
+    def encode(
+        self,
+        texts: list[str],
+        model_name: str | None,
+        text_type: EmbedTextType,
+    ) -> list[list[float]]:
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = [
+                executor.submit(
+                    self._embed,
+                    text=text,
+                    text_type=text_type,
+                    model=model_name,
+                )
+                for text in texts
+            ]
+
+            results = []
+            for future in concurrent.futures.as_completed(futures):
+                try:
+                    results.append(future.result())
+                except Exception as e:
+                    # Cancel all pending futures
+                    for f in futures:
+                        f.cancel()
+                    # Raise the exception immediately
+                    raise e
+
+        return results
+
     @staticmethod
     def create(
         api_key: str, provider: str, model: str | None = None
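One subtlety in the `encode` above: `concurrent.futures.as_completed` yields futures in completion order, so `results` is not guaranteed to line up index-for-index with `texts`. An order-preserving variant with the same cancel-on-failure behavior would collect results in submission order (a sketch, not part of this commit; `embed_one` stands in for the bound `self._embed` call):

import concurrent.futures
from typing import Callable


def encode_ordered(
    embed_one: Callable[[str], list[float]], texts: list[str]
) -> list[list[float]]:
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(embed_one, text) for text in texts]
        try:
            # Iterating in submission order keeps embeddings aligned with texts
            return [f.result() for f in futures]
        except Exception:
            for f in futures:
                f.cancel()
            raise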
@@ -204,6 +241,7 @@ def warm_up_cross_encoders() -> None:
 
 
 @simple_log_function_time()
+@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
 def embed_text(
     texts: list[str],
     text_type: EmbedTextType,
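`retry.retry` re-invokes the decorated function whenever it raises, sleeping `delay` seconds between attempts and giving up after `tries` attempts. With the constants above, an indexing-only model server retries `embed_text` up to 10 times at 10-second intervals, while a query-serving one fails fast (2 tries, 0.1s apart). A minimal sketch of the semantics (the function body is hypothetical):

from retry import retry


@retry(tries=3, delay=0.1)
def sometimes_fails() -> str:
    # Any exception triggers another attempt, up to `tries` total calls with
    # `delay` seconds of sleep in between; the last exception propagates if
    # every attempt fails.
    raise RuntimeError("transient failure")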
@@ -212,29 +250,50 @@ def embed_text(
     normalize_embeddings: bool,
     api_key: str | None,
     provider_type: str | None,
+    prefix: str | None,
 ) -> list[list[float]]:
     if provider_type is not None:
+        logger.debug(f"Embedding text with provider: {provider_type}")
         if api_key is None:
             raise RuntimeError("API key not provided for cloud model")
 
+        if prefix:
+            # This may change in the future if some providers require the user
+            # to manually append a prefix but this is not the case currently
+            raise ValueError(
+                "Prefix string is not valid for cloud models. "
+                "Cloud models take an explicit text type instead."
+            )
+
         cloud_model = CloudEmbedding(
             api_key=api_key, provider=provider_type, model=model_name
         )
-        embeddings = cloud_model.encode(texts, model_name, text_type)
+        embeddings = cloud_model.encode(
+            texts=texts,
+            model_name=model_name,
+            text_type=text_type,
+        )
+
     elif model_name is not None:
+        prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
         hosted_model = get_embedding_model(
             model_name=model_name, max_context_length=max_context_length
         )
         embeddings = hosted_model.encode(
-            texts, normalize_embeddings=normalize_embeddings
+            prefixed_texts, normalize_embeddings=normalize_embeddings
         )
+
     else:
         raise ValueError(
             "Either model name or provider must be provided to run embeddings."
         )
 
     if embeddings is None:
-        raise RuntimeError("Embeddings were not created")
+        raise RuntimeError("Failed to create Embeddings")
 
     if not isinstance(embeddings, list):
         embeddings = embeddings.tolist()
 
     return embeddings
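To make the two branches concrete: a manual prefix is only legal on the hosted-model path, where it is prepended verbatim before encoding (values below are made up):

prefix = "query: "
texts = ["hello", "world"]

# Hosted-model branch: the prefix is applied to every text
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
assert prefixed_texts == ["query: hello", "query: world"]

# Cloud branch: the same non-empty prefix raises the ValueError above,
# since cloud providers take an explicit EmbedTextType instead.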
@@ -253,6 +312,13 @@ async def process_embed_request(
     embed_request: EmbedRequest,
 ) -> EmbedResponse:
     try:
+        if embed_request.text_type == EmbedTextType.QUERY:
+            prefix = embed_request.manual_query_prefix
+        elif embed_request.text_type == EmbedTextType.PASSAGE:
+            prefix = embed_request.manual_passage_prefix
+        else:
+            prefix = None
+
         embeddings = embed_text(
             texts=embed_request.texts,
             model_name=embed_request.model_name,
@@ -261,13 +327,13 @@ async def process_embed_request(
             api_key=embed_request.api_key,
             provider_type=embed_request.provider_type,
             text_type=embed_request.text_type,
+            prefix=prefix,
         )
         return EmbedResponse(embeddings=embeddings)
     except Exception as e:
-        logger.exception(f"Error during embedding process:\n{str(e)}")
-        raise HTTPException(
-            status_code=500, detail="Failed to run Bi-Encoder embedding"
-        )
+        exception_detail = f"Error during embedding process:\n{str(e)}"
+        logger.exception(exception_detail)
+        raise HTTPException(status_code=500, detail=exception_detail)
 
 
 @router.post("/cross-encoder-scores")
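End to end, a request to the bi-encoder endpoint now carries both manual prefixes and the server picks one based on `text_type`. A hypothetical payload for the hosted-model path (host, port, model name, prefix values, and the serialized form of `EmbedTextType.QUERY` are all assumptions, not taken from this diff):

import requests

payload = {
    "texts": ["how do retries work?"],
    "model_name": "intfloat/e5-base-v2",  # illustrative local model
    "max_context_length": 512,
    "normalize_embeddings": True,
    "api_key": None,        # no cloud key -> hosted-model branch
    "provider_type": None,
    "text_type": "query",   # assumed wire value of EmbedTextType.QUERY
    "manual_query_prefix": "query: ",
    "manual_passage_prefix": "passage: ",
}

response = requests.post(
    "http://localhost:9000/encoder/bi-encoder-embed", json=payload
)
response.raise_for_status()
embeddings = response.json()["embeddings"]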
backend/requirements/model_server.txt

@@ -1,6 +1,7 @@
 fastapi==0.109.2
 h5py==3.9.0
 pydantic==1.10.13
+retry==0.9.2
 safetensors==0.4.2
 sentence-transformers==2.6.1
 tensorflow==2.15.0
backend/shared_configs/model_server_models.py

@@ -4,9 +4,7 @@ from shared_configs.enums import EmbedTextType
 
 
 class EmbedRequest(BaseModel):
-    # This already includes any prefixes, the text is just passed directly to the model
     texts: list[str]
-
     # Can be none for cloud embedding model requests, error handling logic exists for other cases
     model_name: str | None
     max_context_length: int
@@ -14,6 +12,8 @@ class EmbedRequest(BaseModel):
     api_key: str | None
     provider_type: str | None
     text_type: EmbedTextType
+    manual_query_prefix: str | None
+    manual_passage_prefix: str | None
 
 
 class EmbedResponse(BaseModel):