mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-28 05:43:33 +02:00
k
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import concurrent.futures
|
||||
import gc
|
||||
import json
|
||||
from typing import Any
|
||||
@@ -81,69 +80,70 @@ class CloudEmbedding:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
self.client = _initialize_client(api_key, self.provider, model)
|
||||
|
||||
def _embed_openai(self, text: str, model: str | None) -> list[float]:
|
||||
def _embed_openai(self, texts: list[str], model: str | None) -> list[list[float]]:
|
||||
if model is None:
|
||||
model = DEFAULT_OPENAI_MODEL
|
||||
|
||||
response = self.client.embeddings.create(input=text, model=model)
|
||||
return response.data[0].embedding
|
||||
response = self.client.embeddings.create(input=texts, model=model)
|
||||
return [embedding.embedding for embedding in response.data]
|
||||
|
||||
def _embed_cohere(
|
||||
self, text: str, model: str | None, embedding_type: str
|
||||
) -> list[float]:
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[list[float]]:
|
||||
if model is None:
|
||||
model = DEFAULT_COHERE_MODEL
|
||||
|
||||
response = self.client.embed(
|
||||
texts=[text],
|
||||
texts=texts,
|
||||
model=model,
|
||||
input_type=embedding_type,
|
||||
)
|
||||
return response.embeddings[0]
|
||||
return response.embeddings
|
||||
|
||||
def _embed_voyage(
|
||||
self, text: str, model: str | None, embedding_type: str
|
||||
) -> list[float]:
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[list[float]]:
|
||||
if model is None:
|
||||
model = DEFAULT_VOYAGE_MODEL
|
||||
|
||||
response = self.client.embed(text, model=model, input_type=embedding_type)
|
||||
return response.embeddings[0]
|
||||
response = self.client.embed(texts, model=model, input_type=embedding_type)
|
||||
return response.embeddings
|
||||
|
||||
def _embed_vertex(
|
||||
self, text: str, model: str | None, embedding_type: str
|
||||
) -> list[float]:
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[list[float]]:
|
||||
if model is None:
|
||||
model = DEFAULT_VERTEX_MODEL
|
||||
|
||||
embedding = self.client.get_embeddings(
|
||||
embeddings = self.client.get_embeddings(
|
||||
[
|
||||
TextEmbeddingInput(
|
||||
text,
|
||||
embedding_type,
|
||||
)
|
||||
for text in texts
|
||||
]
|
||||
)
|
||||
return embedding[0].values
|
||||
return [embedding.values for embedding in embeddings]
|
||||
|
||||
def _embed(
|
||||
def embed(
|
||||
self,
|
||||
*,
|
||||
text: str,
|
||||
texts: list[str],
|
||||
text_type: EmbedTextType,
|
||||
model: str | None = None,
|
||||
) -> list[float]:
|
||||
model_name: str | None = None,
|
||||
) -> list[list[float]]:
|
||||
try:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return self._embed_openai(text, model)
|
||||
return self._embed_openai(texts, model_name)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return self._embed_cohere(text, model, embedding_type)
|
||||
return self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return self._embed_voyage(text, model, embedding_type)
|
||||
return self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return self._embed_vertex(text, model, embedding_type)
|
||||
return self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
except Exception as e:
|
||||
@@ -152,36 +152,6 @@ class CloudEmbedding:
|
||||
detail=f"Error embedding text with {self.provider}: {str(e)}",
|
||||
)
|
||||
|
||||
def encode(
|
||||
self,
|
||||
texts: list[str],
|
||||
model_name: str | None,
|
||||
text_type: EmbedTextType,
|
||||
) -> list[list[float]]:
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(
|
||||
self._embed,
|
||||
text=text,
|
||||
text_type=text_type,
|
||||
model=model_name,
|
||||
)
|
||||
for text in texts
|
||||
]
|
||||
|
||||
results = []
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
try:
|
||||
results.append(future.result())
|
||||
except Exception as e:
|
||||
# Cancel all pending futures
|
||||
for f in futures:
|
||||
f.cancel()
|
||||
# Raise the exception immediately
|
||||
raise e
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
api_key: str, provider: str, model: str | None = None
|
||||
@@ -252,6 +222,7 @@ def embed_text(
|
||||
provider_type: str | None,
|
||||
prefix: str | None,
|
||||
) -> list[list[float]]:
|
||||
# Third party API based embedding model
|
||||
if provider_type is not None:
|
||||
logger.debug(f"Embedding text with provider: {provider_type}")
|
||||
if api_key is None:
|
||||
@@ -268,18 +239,19 @@ def embed_text(
|
||||
cloud_model = CloudEmbedding(
|
||||
api_key=api_key, provider=provider_type, model=model_name
|
||||
)
|
||||
embeddings = cloud_model.encode(
|
||||
embeddings = cloud_model.embed(
|
||||
texts=texts,
|
||||
model_name=model_name,
|
||||
text_type=text_type,
|
||||
)
|
||||
|
||||
# Locally running model
|
||||
elif model_name is not None:
|
||||
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
|
||||
hosted_model = get_embedding_model(
|
||||
local_model = get_embedding_model(
|
||||
model_name=model_name, max_context_length=max_context_length
|
||||
)
|
||||
embeddings = hosted_model.encode(
|
||||
embeddings = local_model.encode(
|
||||
prefixed_texts, normalize_embeddings=normalize_embeddings
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user