mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-08-08 22:12:30 +02:00
Propagate Embedding Enum (#2108)
This commit is contained in:
@@ -76,14 +76,11 @@ class CloudEmbedding:
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
provider: str,
|
||||
provider: EmbeddingProvider,
|
||||
# Only for Google as is needed on client setup
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
self.provider = EmbeddingProvider(provider.lower())
|
||||
except ValueError:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
self.provider = provider
|
||||
self.client = _initialize_client(api_key, self.provider, model)
|
||||
|
||||
def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
|
||||
@@ -193,7 +190,7 @@ class CloudEmbedding:
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
api_key: str, provider: str, model: str | None = None
|
||||
api_key: str, provider: EmbeddingProvider, model: str | None = None
|
||||
) -> "CloudEmbedding":
|
||||
logger.debug(f"Creating Embedding instance for provider: {provider}")
|
||||
return CloudEmbedding(api_key, provider, model)
|
||||
@@ -254,7 +251,7 @@ def embed_text(
|
||||
max_context_length: int,
|
||||
normalize_embeddings: bool,
|
||||
api_key: str | None,
|
||||
provider_type: str | None,
|
||||
provider_type: EmbeddingProvider | None,
|
||||
prefix: str | None,
|
||||
) -> list[Embedding]:
|
||||
if not all(texts):
|
||||
|
Reference in New Issue
Block a user