Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-07-28 05:43:33 +02:00)
Added retries and multithreading for cloud embedding (#1879)
* Added retries and multithreading for cloud embedding
* Refactored a bit
* Cleaned up code
* Got the errors to bubble up to the UI correctly
* Added exception printing
* Added requirements
* Touchups

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
@@ -4,7 +4,6 @@ from abc import abstractmethod
 from sqlalchemy.orm import Session
 
 from danswer.configs.app_configs import ENABLE_MINI_CHUNK
-from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.embedding_model import get_secondary_db_embedding_model
@@ -15,7 +14,6 @@ from danswer.indexing.models import ChunkEmbedding
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
 from danswer.search.search_nlp_models import EmbeddingModel
-from danswer.utils.batching import batch_list
 from danswer.utils.logger import setup_logger
 from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
 from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
@@ -71,7 +69,6 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
     def embed_chunks(
         self,
         chunks: list[DocAwareChunk],
-        batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
         enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
     ) -> list[IndexChunk]:
         # Cache the Title embeddings to only have to do it once
@@ -80,7 +77,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
 
         # Create Mini Chunks for more precise matching of details
         # Off by default with unedited settings
-        chunk_texts = []
+        chunk_texts: list[str] = []
         chunk_mini_chunks_count = {}
         for chunk_ind, chunk in enumerate(chunks):
             chunk_texts.append(chunk.content)
@@ -92,22 +89,9 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
                 chunk_texts.extend(mini_chunk_texts)
                 chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
 
-        # Batching for embedding
-        text_batches = batch_list(chunk_texts, batch_size)
-
-        embeddings: list[list[float]] = []
-        len_text_batches = len(text_batches)
-        for idx, text_batch in enumerate(text_batches, start=1):
-            logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
-            # Normalize embeddings is only configured via model_configs.py, be sure to use right
-            # value for the set loss
-            embeddings.extend(
-                self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
-            )
-
-            # Replace line above with the line below for easy debugging of indexing flow
-            # skipping the actual model
-            # embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])
+        embeddings = self.embedding_model.encode(
+            chunk_texts, text_type=EmbedTextType.PASSAGE
+        )
 
         chunk_titles = {
             chunk.source_document.get_title_for_document_index() for chunk in chunks
@@ -116,17 +100,15 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         # Drop any None or empty strings
         chunk_titles_list = [title for title in chunk_titles if title]
 
-        # Embed Titles in batches
-        title_batches = batch_list(chunk_titles_list, batch_size)
-        len_title_batches = len(title_batches)
-        for ind_batch, title_batch in enumerate(title_batches, start=1):
-            logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
-            title_embeddings = self.embedding_model.encode(
-                title_batch, text_type=EmbedTextType.PASSAGE
-            )
-            title_embed_dict.update(
-                {title: vector for title, vector in zip(title_batch, title_embeddings)}
-            )
+        title_embeddings = self.embedding_model.encode(
+            chunk_titles_list, text_type=EmbedTextType.PASSAGE
+        )
+        title_embed_dict.update(
+            {
+                title: vector
+                for title, vector in zip(chunk_titles_list, title_embeddings)
+            }
+        )
 
         # Mapping embeddings to chunks
         embedding_ind_start = 0
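The hunk above replaces per-batch title embedding with a single call over the deduplicated title list. A minimal standalone sketch of that dedup-then-zip pattern, where embed_texts is a hypothetical stand-in for the model call:

```python
def embed_texts(texts: list[str]) -> list[list[float]]:
    # Hypothetical stand-in for self.embedding_model.encode(...)
    return [[float(len(t))] for t in texts]

titles = ["Doc A", "Doc B", "Doc A", ""]       # one title per chunk, dupes possible
unique_titles = [t for t in set(titles) if t]  # drop empties and duplicates
title_embed_dict = dict(zip(unique_titles, embed_texts(unique_titles)))
print(title_embed_dict)  # each distinct title is embedded exactly once
```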
@@ -5,10 +5,13 @@ from typing import Optional
 from typing import TYPE_CHECKING
 
 import requests
+from httpx import HTTPError
 from transformers import logging as transformer_logging  # type:ignore
 
+from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
+from danswer.utils.batching import batch_list
 from danswer.utils.logger import setup_logger
 from shared_configs.configs import MODEL_SERVER_HOST
 from shared_configs.configs import MODEL_SERVER_PORT
@@ -103,28 +106,69 @@ class EmbeddingModel:
         model_server_url = build_model_server_url(server_host, server_port)
         self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
 
-    def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
-        if text_type == EmbedTextType.QUERY and self.query_prefix:
-            prefixed_texts = [self.query_prefix + text for text in texts]
-        elif text_type == EmbedTextType.PASSAGE and self.passage_prefix:
-            prefixed_texts = [self.passage_prefix + text for text in texts]
-        else:
-            prefixed_texts = texts
-
-        embed_request = EmbedRequest(
-            model_name=self.model_name,
-            texts=prefixed_texts,
-            max_context_length=self.max_seq_length,
-            normalize_embeddings=self.normalize,
-            api_key=self.api_key,
-            provider_type=self.provider_type,
-            text_type=text_type,
-        )
-
-        response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
-        response.raise_for_status()
-
-        return EmbedResponse(**response.json()).embeddings
+    def encode(
+        self,
+        texts: list[str],
+        text_type: EmbedTextType,
+        batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
+    ) -> list[list[float]]:
+        if self.provider_type:
+            embed_request = EmbedRequest(
+                model_name=self.model_name,
+                texts=texts,
+                max_context_length=self.max_seq_length,
+                normalize_embeddings=self.normalize,
+                api_key=self.api_key,
+                provider_type=self.provider_type,
+                text_type=text_type,
+                manual_query_prefix=self.query_prefix,
+                manual_passage_prefix=self.passage_prefix,
+            )
+            response = requests.post(
+                self.embed_server_endpoint, json=embed_request.dict()
+            )
+            try:
+                response.raise_for_status()
+            except requests.HTTPError as e:
+                error_detail = response.json().get("detail", str(e))
+                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
+            except requests.RequestException as e:
+                raise HTTPError(f"Request failed: {str(e)}") from e
+
+            return EmbedResponse(**response.json()).embeddings
+
+        # Batching for local embedding
+        text_batches = batch_list(texts, batch_size)
+        embeddings: list[list[float]] = []
+        for idx, text_batch in enumerate(text_batches, start=1):
+            logger.debug(f"Embedding Content Texts batch {idx} of {len(text_batches)}")
+
+            embed_request = EmbedRequest(
+                model_name=self.model_name,
+                texts=text_batch,
+                max_context_length=self.max_seq_length,
+                normalize_embeddings=self.normalize,
+                api_key=self.api_key,
+                provider_type=self.provider_type,
+                text_type=text_type,
+                manual_query_prefix=self.query_prefix,
+                manual_passage_prefix=self.passage_prefix,
+            )
+
+            response = requests.post(
+                self.embed_server_endpoint, json=embed_request.dict()
+            )
+            try:
+                response.raise_for_status()
+            except requests.HTTPError as e:
+                error_detail = response.json().get("detail", str(e))
+                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
+            except requests.RequestException as e:
+                raise HTTPError(f"Request failed: {str(e)}") from e
+
+            # Normalize embeddings is only configured via model_configs.py, be sure to use right
+            # value for the set loss
+            embeddings.extend(EmbedResponse(**response.json()).embeddings)
+
+        return embeddings
 
 
 class CrossEncoderEnsembleModel:
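batch_list itself is not shown in this commit; assuming it is the usual fixed-size chunker from danswer.utils.batching, the "Batching for local embedding" loop above behaves like this sketch:

```python
from typing import TypeVar

T = TypeVar("T")

def batch_list(lst: list[T], batch_size: int) -> list[list[T]]:
    # Assumed behavior of danswer.utils.batching.batch_list: fixed-size slices
    return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]

print(batch_list(["a", "b", "c", "d", "e"], 2))  # [['a', 'b'], ['c', 'd'], ['e']]
```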
@@ -1,3 +1,4 @@
+import concurrent.futures
 import gc
 import json
 from typing import Any
@@ -10,6 +11,7 @@ from cohere import Client as CohereClient
 from fastapi import APIRouter
 from fastapi import HTTPException
 from google.oauth2 import service_account  # type: ignore
+from retry import retry
 from sentence_transformers import CrossEncoder  # type: ignore
 from sentence_transformers import SentenceTransformer  # type: ignore
 from vertexai.language_models import TextEmbeddingInput  # type: ignore
@@ -40,65 +42,44 @@ router = APIRouter(prefix="/encoder")
 
 _GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
 _RERANK_MODELS: Optional[list["CrossEncoder"]] = None
+
+# If we are not only indexing, dont want retry very long
+_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
+_RETRY_TRIES = 10 if INDEXING_ONLY else 2
+
+
+def _initialize_client(
+    api_key: str, provider: EmbeddingProvider, model: str | None = None
+) -> Any:
+    if provider == EmbeddingProvider.OPENAI:
+        return openai.OpenAI(api_key=api_key)
+    elif provider == EmbeddingProvider.COHERE:
+        return CohereClient(api_key=api_key)
+    elif provider == EmbeddingProvider.VOYAGE:
+        return voyageai.Client(api_key=api_key)
+    elif provider == EmbeddingProvider.GOOGLE:
+        credentials = service_account.Credentials.from_service_account_info(
+            json.loads(api_key)
+        )
+        project_id = json.loads(api_key)["project_id"]
+        vertexai.init(project=project_id, credentials=credentials)
+        return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
+    else:
+        raise ValueError(f"Unsupported provider: {provider}")
 
 
 class CloudEmbedding:
-    def __init__(self, api_key: str, provider: str, model: str | None = None):
-        self.api_key = api_key
+    def __init__(
+        self,
+        api_key: str,
+        provider: str,
         # Only for Google as is needed on client setup
-        self.model = model
+        model: str | None = None,
+    ) -> None:
         try:
             self.provider = EmbeddingProvider(provider.lower())
         except ValueError:
             raise ValueError(f"Unsupported provider: {provider}")
-        self.client = self._initialize_client()
-
-    def _initialize_client(self) -> Any:
-        if self.provider == EmbeddingProvider.OPENAI:
-            return openai.OpenAI(api_key=self.api_key)
-        elif self.provider == EmbeddingProvider.COHERE:
-            return CohereClient(api_key=self.api_key)
-        elif self.provider == EmbeddingProvider.VOYAGE:
-            return voyageai.Client(api_key=self.api_key)
-        elif self.provider == EmbeddingProvider.GOOGLE:
-            credentials = service_account.Credentials.from_service_account_info(
-                json.loads(self.api_key)
-            )
-            project_id = json.loads(self.api_key)["project_id"]
-            vertexai.init(project=project_id, credentials=credentials)
-            return TextEmbeddingModel.from_pretrained(
-                self.model or DEFAULT_VERTEX_MODEL
-            )
-
-        else:
-            raise ValueError(f"Unsupported provider: {self.provider}")
-
-    def encode(
-        self, texts: list[str], model_name: str | None, text_type: EmbedTextType
-    ) -> list[list[float]]:
-        return [
-            self.embed(text=text, text_type=text_type, model=model_name)
-            for text in texts
-        ]
-
-    def embed(
-        self, *, text: str, text_type: EmbedTextType, model: str | None = None
-    ) -> list[float]:
-        logger.debug(f"Embedding text with provider: {self.provider}")
-        if self.provider == EmbeddingProvider.OPENAI:
-            return self._embed_openai(text, model)
-
-        embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
-
-        if self.provider == EmbeddingProvider.COHERE:
-            return self._embed_cohere(text, model, embedding_type)
-        elif self.provider == EmbeddingProvider.VOYAGE:
-            return self._embed_voyage(text, model, embedding_type)
-        elif self.provider == EmbeddingProvider.GOOGLE:
-            return self._embed_vertex(text, model, embedding_type)
-        else:
-            raise ValueError(f"Unsupported provider: {self.provider}")
+        self.client = _initialize_client(api_key, self.provider, model)
 
     def _embed_openai(self, text: str, model: str | None) -> list[float]:
         if model is None:
@@ -145,6 +126,62 @@ class CloudEmbedding:
             )
         return embedding[0].values
 
+    def _embed(
+        self,
+        *,
+        text: str,
+        text_type: EmbedTextType,
+        model: str | None = None,
+    ) -> list[float]:
+        try:
+            if self.provider == EmbeddingProvider.OPENAI:
+                return self._embed_openai(text, model)
+
+            embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
+            if self.provider == EmbeddingProvider.COHERE:
+                return self._embed_cohere(text, model, embedding_type)
+            elif self.provider == EmbeddingProvider.VOYAGE:
+                return self._embed_voyage(text, model, embedding_type)
+            elif self.provider == EmbeddingProvider.GOOGLE:
+                return self._embed_vertex(text, model, embedding_type)
+            else:
+                raise ValueError(f"Unsupported provider: {self.provider}")
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail=f"Error embedding text with {self.provider}: {str(e)}",
+            )
+
+    def encode(
+        self,
+        texts: list[str],
+        model_name: str | None,
+        text_type: EmbedTextType,
+    ) -> list[list[float]]:
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            futures = [
+                executor.submit(
+                    self._embed,
+                    text=text,
+                    text_type=text_type,
+                    model=model_name,
+                )
+                for text in texts
+            ]
+
+            results = []
+            for future in concurrent.futures.as_completed(futures):
+                try:
+                    results.append(future.result())
+                except Exception as e:
+                    # Cancel all pending futures
+                    for f in futures:
+                        f.cancel()
+                    # Raise the exception immediately
+                    raise e
+
+        return results
+
     @staticmethod
     def create(
         api_key: str, provider: str, model: str | None = None
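The new CloudEmbedding.encode fans one request per text out to a thread pool and fails fast on the first error. A self-contained sketch of the same submit/as_completed pattern, with fake_embed standing in for the provider call:

```python
import concurrent.futures
import hashlib

def fake_embed(text: str) -> list[float]:
    # Deterministic stand-in for a network call to an embedding provider
    digest = hashlib.sha256(text.encode()).digest()
    return [b / 255 for b in digest[:4]]

def embed_all(texts: list[str]) -> list[list[float]]:
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(fake_embed, t) for t in texts]
        results = []
        for future in concurrent.futures.as_completed(futures):
            try:
                results.append(future.result())
            except Exception:
                for f in futures:
                    f.cancel()  # cancel any work that has not started yet
                raise
    # Note: as_completed yields futures in completion order, not submission order
    return results

print(embed_all(["hello", "world"]))
```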
@@ -204,6 +241,7 @@ def warm_up_cross_encoders() -> None:
 
 
 @simple_log_function_time()
+@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
 def embed_text(
     texts: list[str],
     text_type: EmbedTextType,
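The @retry decorator comes from the retry package added to the requirements below; it re-invokes the function up to tries times, sleeping delay seconds between attempts. A small sketch of the behavior:

```python
from retry import retry

attempts = 0

@retry(tries=3, delay=0.1)
def flaky_call() -> str:
    global attempts
    attempts += 1
    if attempts < 3:
        raise ConnectionError("transient failure")  # triggers a retry
    return "ok"

print(flaky_call(), "after", attempts, "attempts")  # -> ok after 3 attempts
```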
@@ -212,29 +250,50 @@ def embed_text(
     normalize_embeddings: bool,
     api_key: str | None,
     provider_type: str | None,
+    prefix: str | None,
 ) -> list[list[float]]:
     if provider_type is not None:
+        logger.debug(f"Embedding text with provider: {provider_type}")
         if api_key is None:
             raise RuntimeError("API key not provided for cloud model")
 
+        if prefix:
+            # This may change in the future if some providers require the user
+            # to manually append a prefix but this is not the case currently
+            raise ValueError(
+                "Prefix string is not valid for cloud models. "
+                "Cloud models take an explicit text type instead."
+            )
+
         cloud_model = CloudEmbedding(
             api_key=api_key, provider=provider_type, model=model_name
         )
-        embeddings = cloud_model.encode(texts, model_name, text_type)
+        embeddings = cloud_model.encode(
+            texts=texts,
+            model_name=model_name,
+            text_type=text_type,
+        )
+
     elif model_name is not None:
+        prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
         hosted_model = get_embedding_model(
             model_name=model_name, max_context_length=max_context_length
         )
         embeddings = hosted_model.encode(
-            texts, normalize_embeddings=normalize_embeddings
+            prefixed_texts, normalize_embeddings=normalize_embeddings
+        )
+
+    else:
+        raise ValueError(
+            "Either model name or provider must be provided to run embeddings."
         )
 
     if embeddings is None:
-        raise RuntimeError("Embeddings were not created")
+        raise RuntimeError("Failed to create Embeddings")
 
     if not isinstance(embeddings, list):
         embeddings = embeddings.tolist()
 
     return embeddings
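embed_text now branches on provider_type: cloud providers receive an explicit text_type and reject manual prefixes, while locally hosted models get the prefix prepended. The rule in isolation, as a minimal sketch:

```python
def apply_prefix(texts: list[str], prefix: str | None, is_cloud: bool) -> list[str]:
    if is_cloud:
        if prefix:
            # Cloud providers derive prefixes from the text type themselves
            raise ValueError("Prefix string is not valid for cloud models.")
        return texts
    return [f"{prefix}{text}" for text in texts] if prefix else texts

print(apply_prefix(["hello"], "query: ", is_cloud=False))  # ['query: hello']
```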
@@ -253,6 +312,13 @@ async def process_embed_request(
     embed_request: EmbedRequest,
 ) -> EmbedResponse:
     try:
+        if embed_request.text_type == EmbedTextType.QUERY:
+            prefix = embed_request.manual_query_prefix
+        elif embed_request.text_type == EmbedTextType.PASSAGE:
+            prefix = embed_request.manual_passage_prefix
+        else:
+            prefix = None
+
         embeddings = embed_text(
             texts=embed_request.texts,
             model_name=embed_request.model_name,
@@ -261,13 +327,13 @@ async def process_embed_request(
             api_key=embed_request.api_key,
             provider_type=embed_request.provider_type,
             text_type=embed_request.text_type,
+            prefix=prefix,
         )
         return EmbedResponse(embeddings=embeddings)
     except Exception as e:
-        logger.exception(f"Error during embedding process:\n{str(e)}")
-        raise HTTPException(
-            status_code=500, detail="Failed to run Bi-Encoder embedding"
-        )
+        exception_detail = f"Error during embedding process:\n{str(e)}"
+        logger.exception(exception_detail)
+        raise HTTPException(status_code=500, detail=exception_detail)
 
 
 @router.post("/cross-encoder-scores")
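This is the server half of the error bubbling: the real exception text now travels in the HTTPException detail field, which FastAPI serializes as {"detail": ...}. A sketch of how the client side (see the EmbeddingModel.encode hunk above) recovers it; endpoint and payload here are placeholders:

```python
import requests
from httpx import HTTPError

def post_embed_request(endpoint: str, payload: dict) -> dict:
    response = requests.post(endpoint, json=payload)
    try:
        response.raise_for_status()
    except requests.HTTPError as e:
        # FastAPI serializes HTTPException(detail=...) as {"detail": "..."},
        # so the original server-side error message can be surfaced in the UI
        error_detail = response.json().get("detail", str(e))
        raise HTTPError(f"HTTP error occurred: {error_detail}") from e
    return response.json()
```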
@@ -1,6 +1,7 @@
|
|||||||
fastapi==0.109.2
|
fastapi==0.109.2
|
||||||
h5py==3.9.0
|
h5py==3.9.0
|
||||||
pydantic==1.10.13
|
pydantic==1.10.13
|
||||||
|
retry==0.9.2
|
||||||
safetensors==0.4.2
|
safetensors==0.4.2
|
||||||
sentence-transformers==2.6.1
|
sentence-transformers==2.6.1
|
||||||
tensorflow==2.15.0
|
tensorflow==2.15.0
|
||||||
|
@@ -4,9 +4,7 @@ from shared_configs.enums import EmbedTextType
 
 
 class EmbedRequest(BaseModel):
-    # This already includes any prefixes, the text is just passed directly to the model
     texts: list[str]
-
     # Can be none for cloud embedding model requests, error handling logic exists for other cases
     model_name: str | None
     max_context_length: int
@@ -14,6 +12,8 @@ class EmbedRequest(BaseModel):
     api_key: str | None
     provider_type: str | None
     text_type: EmbedTextType
+    manual_query_prefix: str | None
+    manual_passage_prefix: str | None
 
 
 class EmbedResponse(BaseModel):
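An illustrative construction of the extended request schema, trimmed to the fields visible in this diff and with a stand-in enum for shared_configs.enums.EmbedTextType; all values are made up:

```python
from enum import Enum
from pydantic import BaseModel

class EmbedTextType(str, Enum):  # stand-in for shared_configs.enums.EmbedTextType
    QUERY = "query"
    PASSAGE = "passage"

class EmbedRequest(BaseModel):  # trimmed to the fields shown in the diff
    texts: list[str]
    model_name: str | None
    max_context_length: int
    text_type: EmbedTextType
    manual_query_prefix: str | None
    manual_passage_prefix: str | None

req = EmbedRequest(
    texts=["How do I configure a connector?"],
    model_name=None,
    max_context_length=512,
    text_type=EmbedTextType.QUERY,
    manual_query_prefix="query: ",
    manual_passage_prefix=None,
)
print(req.dict())  # pydantic v1 API, matching pydantic==1.10.13 above
```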