Propagate Embedding Enum (#2108)

This commit is contained in:
Yuhong Sun
2024-08-11 12:17:54 -07:00
committed by GitHub
parent d60fb15ad3
commit ce666f3320
13 changed files with 49 additions and 31 deletions

View File

@ -1,5 +1,4 @@
from enum import Enum
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
@ -10,13 +9,6 @@ DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
DEFAULT_VERTEX_MODEL = "text-embedding-004"
class EmbeddingProvider(Enum):
OPENAI = "openai"
COHERE = "cohere"
VOYAGE = "voyage"
GOOGLE = "google"
class EmbeddingModelTextType:
PROVIDER_TEXT_TYPE_MAP = {
EmbeddingProvider.COHERE: {

View File

@ -76,14 +76,11 @@ class CloudEmbedding:
def __init__(
self,
api_key: str,
provider: str,
provider: EmbeddingProvider,
# Only for Google as is needed on client setup
model: str | None = None,
) -> None:
try:
self.provider = EmbeddingProvider(provider.lower())
except ValueError:
raise ValueError(f"Unsupported provider: {provider}")
self.provider = provider
self.client = _initialize_client(api_key, self.provider, model)
def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
@ -193,7 +190,7 @@ class CloudEmbedding:
@staticmethod
def create(
api_key: str, provider: str, model: str | None = None
api_key: str, provider: EmbeddingProvider, model: str | None = None
) -> "CloudEmbedding":
logger.debug(f"Creating Embedding instance for provider: {provider}")
return CloudEmbedding(api_key, provider, model)
@ -254,7 +251,7 @@ def embed_text(
max_context_length: int,
normalize_embeddings: bool,
api_key: str | None,
provider_type: str | None,
provider_type: EmbeddingProvider | None,
prefix: str | None,
) -> list[Embedding]:
if not all(texts):