From ce666f3320d9468ef26393c0bc7fcaf718737c76 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sun, 11 Aug 2024 12:17:54 -0700 Subject: [PATCH] Propagate Embedding Enum (#2108) --- .../7477a5f5d728_added_model_defaults_for_users.py | 4 ++-- backend/danswer/db/models.py | 9 +++++++-- backend/danswer/indexing/chunker.py | 8 ++++++-- backend/danswer/indexing/embedder.py | 5 +++-- .../natural_language_processing/search_nlp_models.py | 3 ++- backend/danswer/natural_language_processing/utils.py | 9 ++++++++- backend/danswer/server/manage/embedding/models.py | 4 +++- backend/model_server/constants.py | 10 +--------- backend/model_server/encoders.py | 11 ++++------- backend/shared_configs/enums.py | 7 +++++++ backend/shared_configs/model_server_models.py | 3 ++- .../tests/integration/embedding/test_embeddings.py | 5 +++-- .../models/embedding/modals/SelectModelModal.tsx | 2 +- 13 files changed, 49 insertions(+), 31 deletions(-) diff --git a/backend/alembic/versions/7477a5f5d728_added_model_defaults_for_users.py b/backend/alembic/versions/7477a5f5d728_added_model_defaults_for_users.py index b15637020..6efb98405 100644 --- a/backend/alembic/versions/7477a5f5d728_added_model_defaults_for_users.py +++ b/backend/alembic/versions/7477a5f5d728_added_model_defaults_for_users.py @@ -12,8 +12,8 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "7477a5f5d728" down_revision = "213fd978c6d8" -branch_labels = None -depends_on = None +branch_labels: None = None +depends_on: None = None def upgrade() -> None: diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index dc0810e76..927610b5e 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -54,6 +54,7 @@ from danswer.llm.override_models import PromptOverride from danswer.search.enums import RecencyBiasSetting from danswer.utils.encryption import decrypt_bytes_to_string from danswer.utils.encryption import encrypt_string_to_bytes +from shared_configs.enums import EmbeddingProvider class Base(DeclarativeBase): @@ -582,8 +583,12 @@ class EmbeddingModel(Base): cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>" @property - def provider_type(self) -> str | None: - return self.cloud_provider.name if self.cloud_provider is not None else None + def provider_type(self) -> EmbeddingProvider | None: + return ( + EmbeddingProvider(self.cloud_provider.name.lower()) + if self.cloud_provider is not None + else None + ) @property def api_key(self) -> str | None: diff --git a/backend/danswer/indexing/chunker.py b/backend/danswer/indexing/chunker.py index 87cd3a83c..6bf35008f 100644 --- a/backend/danswer/indexing/chunker.py +++ b/backend/danswer/indexing/chunker.py @@ -18,6 +18,7 @@ from danswer.indexing.models import DocAwareChunk from danswer.natural_language_processing.utils import get_tokenizer from danswer.utils.logger import setup_logger from danswer.utils.text_processing import shared_precompare_cleanup +from shared_configs.enums import EmbeddingProvider if TYPE_CHECKING: from llama_index.text_splitter import SentenceSplitter # type:ignore @@ -123,7 +124,7 @@ def _get_metadata_suffix_for_document_index( def chunk_document( document: Document, model_name: str, - provider_type: str | None, + provider_type: EmbeddingProvider | None, enable_multipass: bool, chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE, subsection_overlap: int = CHUNK_OVERLAP, @@ -301,7 +302,10 @@ class Chunker: class DefaultChunker(Chunker): def __init__( - self, model_name: str, provider_type: str | None, enable_multipass: bool + self, + model_name: str, + provider_type: EmbeddingProvider | None, + enable_multipass: bool, ): self.model_name = model_name self.provider_type = provider_type diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py index 89fcb1131..478c0c991 100644 --- a/backend/danswer/indexing/embedder.py +++ b/backend/danswer/indexing/embedder.py @@ -15,6 +15,7 @@ from danswer.natural_language_processing.search_nlp_models import EmbeddingModel from danswer.utils.logger import setup_logger from shared_configs.configs import INDEXING_MODEL_SERVER_HOST from shared_configs.configs import INDEXING_MODEL_SERVER_PORT +from shared_configs.enums import EmbeddingProvider from shared_configs.enums import EmbedTextType from shared_configs.model_server_models import Embedding @@ -29,7 +30,7 @@ class IndexingEmbedder(ABC): normalize: bool, query_prefix: str | None, passage_prefix: str | None, - provider_type: str | None, + provider_type: EmbeddingProvider | None, api_key: str | None, ): self.model_name = model_name @@ -54,7 +55,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder): normalize: bool, query_prefix: str | None, passage_prefix: str | None, - provider_type: str | None = None, + provider_type: EmbeddingProvider | None = None, api_key: str | None = None, ): super().__init__( diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py index db2f1181e..78736d18e 100644 --- a/backend/danswer/natural_language_processing/search_nlp_models.py +++ b/backend/danswer/natural_language_processing/search_nlp_models.py @@ -16,6 +16,7 @@ from danswer.natural_language_processing.utils import tokenizer_trim_content from danswer.utils.logger import setup_logger from shared_configs.configs import MODEL_SERVER_HOST from shared_configs.configs import MODEL_SERVER_PORT +from shared_configs.enums import EmbeddingProvider from shared_configs.enums import EmbedTextType from shared_configs.model_server_models import Embedding from shared_configs.model_server_models import EmbedRequest @@ -76,7 +77,7 @@ class EmbeddingModel: query_prefix: str | None, passage_prefix: str | None, api_key: str | None, - provider_type: str | None, + provider_type: EmbeddingProvider | None, # The following are globals are currently not configurable max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE, retrim_content: bool = False, diff --git a/backend/danswer/natural_language_processing/utils.py b/backend/danswer/natural_language_processing/utils.py index 02d599ffc..d2b9a7d7f 100644 --- a/backend/danswer/natural_language_processing/utils.py +++ b/backend/danswer/natural_language_processing/utils.py @@ -9,6 +9,7 @@ from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL from danswer.search.models import InferenceChunk from danswer.utils.logger import setup_logger +from shared_configs.enums import EmbeddingProvider logger = setup_logger() transformer_logging.set_verbosity_error() @@ -114,7 +115,13 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer: _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL) -def get_tokenizer(model_name: str | None, provider_type: str | None) -> BaseTokenizer: +def get_tokenizer( + model_name: str | None, provider_type: EmbeddingProvider | str | None +) -> BaseTokenizer: + # Currently all of the viable models use the same sentencepiece tokenizer + # OpenAI uses a different one but currently it's not supported due to quality issues + # the inconsistent chunking makes using the sentencepiece tokenizer default better for now + # LLM tokenizers are specified by strings global _DEFAULT_TOKENIZER return _DEFAULT_TOKENIZER diff --git a/backend/danswer/server/manage/embedding/models.py b/backend/danswer/server/manage/embedding/models.py index 4f1e72319..cf6cbd3f2 100644 --- a/backend/danswer/server/manage/embedding/models.py +++ b/backend/danswer/server/manage/embedding/models.py @@ -2,12 +2,14 @@ from typing import TYPE_CHECKING from pydantic import BaseModel +from shared_configs.enums import EmbeddingProvider + if TYPE_CHECKING: from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel class TestEmbeddingRequest(BaseModel): - provider: str + provider: EmbeddingProvider api_key: str | None = None diff --git a/backend/model_server/constants.py b/backend/model_server/constants.py index 1b9f1a068..d6991b402 100644 --- a/backend/model_server/constants.py +++ b/backend/model_server/constants.py @@ -1,5 +1,4 @@ -from enum import Enum - +from shared_configs.enums import EmbeddingProvider from shared_configs.enums import EmbedTextType @@ -10,13 +9,6 @@ DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct" DEFAULT_VERTEX_MODEL = "text-embedding-004" -class EmbeddingProvider(Enum): - OPENAI = "openai" - COHERE = "cohere" - VOYAGE = "voyage" - GOOGLE = "google" - - class EmbeddingModelTextType: PROVIDER_TEXT_TYPE_MAP = { EmbeddingProvider.COHERE: { diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index cda6cb139..f546ac705 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -76,14 +76,11 @@ class CloudEmbedding: def __init__( self, api_key: str, - provider: str, + provider: EmbeddingProvider, # Only for Google as is needed on client setup model: str | None = None, ) -> None: - try: - self.provider = EmbeddingProvider(provider.lower()) - except ValueError: - raise ValueError(f"Unsupported provider: {provider}") + self.provider = provider self.client = _initialize_client(api_key, self.provider, model) def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]: @@ -193,7 +190,7 @@ class CloudEmbedding: @staticmethod def create( - api_key: str, provider: str, model: str | None = None + api_key: str, provider: EmbeddingProvider, model: str | None = None ) -> "CloudEmbedding": logger.debug(f"Creating Embedding instance for provider: {provider}") return CloudEmbedding(api_key, provider, model) @@ -254,7 +251,7 @@ def embed_text( max_context_length: int, normalize_embeddings: bool, api_key: str | None, - provider_type: str | None, + provider_type: EmbeddingProvider | None, prefix: str | None, ) -> list[Embedding]: if not all(texts): diff --git a/backend/shared_configs/enums.py b/backend/shared_configs/enums.py index af84c0feb..e7cff3754 100644 --- a/backend/shared_configs/enums.py +++ b/backend/shared_configs/enums.py @@ -1,6 +1,13 @@ from enum import Enum +class EmbeddingProvider(str, Enum): + OPENAI = "openai" + COHERE = "cohere" + VOYAGE = "voyage" + GOOGLE = "google" + + class EmbedTextType(str, Enum): QUERY = "query" PASSAGE = "passage" diff --git a/backend/shared_configs/model_server_models.py b/backend/shared_configs/model_server_models.py index aa024d7e8..effafc621 100644 --- a/backend/shared_configs/model_server_models.py +++ b/backend/shared_configs/model_server_models.py @@ -1,5 +1,6 @@ from pydantic import BaseModel +from shared_configs.enums import EmbeddingProvider from shared_configs.enums import EmbedTextType Embedding = list[float] @@ -12,7 +13,7 @@ class EmbedRequest(BaseModel): max_context_length: int normalize_embeddings: bool api_key: str | None - provider_type: str | None + provider_type: EmbeddingProvider | None text_type: EmbedTextType manual_query_prefix: str | None manual_passage_prefix: str | None diff --git a/backend/tests/integration/embedding/test_embeddings.py b/backend/tests/integration/embedding/test_embeddings.py index ce056477d..a9c12b236 100644 --- a/backend/tests/integration/embedding/test_embeddings.py +++ b/backend/tests/integration/embedding/test_embeddings.py @@ -4,6 +4,7 @@ import pytest from danswer.natural_language_processing.search_nlp_models import EmbeddingModel from shared_configs.enums import EmbedTextType +from shared_configs.model_server_models import EmbeddingProvider VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"] # openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't @@ -30,7 +31,7 @@ def openai_embedding_model() -> EmbeddingModel: query_prefix=None, passage_prefix=None, api_key=os.getenv("OPENAI_API_KEY"), - provider_type="openai", + provider_type=EmbeddingProvider.OPENAI, ) @@ -49,7 +50,7 @@ def cohere_embedding_model() -> EmbeddingModel: query_prefix=None, passage_prefix=None, api_key=os.getenv("COHERE_API_KEY"), - provider_type="cohere", + provider_type=EmbeddingProvider.COHERE, ) diff --git a/web/src/app/admin/models/embedding/modals/SelectModelModal.tsx b/web/src/app/admin/models/embedding/modals/SelectModelModal.tsx index 731337f7b..721034442 100644 --- a/web/src/app/admin/models/embedding/modals/SelectModelModal.tsx +++ b/web/src/app/admin/models/embedding/modals/SelectModelModal.tsx @@ -15,7 +15,7 @@ export function SelectModelModal({ return (