Propagate Embedding Enum (#2108)

This commit is contained in:
Yuhong Sun 2024-08-11 12:17:54 -07:00 committed by GitHub
parent d60fb15ad3
commit ce666f3320
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 49 additions and 31 deletions

View File

@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7477a5f5d728"
down_revision = "213fd978c6d8"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -54,6 +54,7 @@ from danswer.llm.override_models import PromptOverride
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.encryption import decrypt_bytes_to_string
from danswer.utils.encryption import encrypt_string_to_bytes
from shared_configs.enums import EmbeddingProvider
class Base(DeclarativeBase):
@ -582,8 +583,12 @@ class EmbeddingModel(Base):
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
@property
def provider_type(self) -> str | None:
return self.cloud_provider.name if self.cloud_provider is not None else None
def provider_type(self) -> EmbeddingProvider | None:
return (
EmbeddingProvider(self.cloud_provider.name.lower())
if self.cloud_provider is not None
else None
)
@property
def api_key(self) -> str | None:

View File

@ -18,6 +18,7 @@ from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import shared_precompare_cleanup
from shared_configs.enums import EmbeddingProvider
if TYPE_CHECKING:
from llama_index.text_splitter import SentenceSplitter # type:ignore
@ -123,7 +124,7 @@ def _get_metadata_suffix_for_document_index(
def chunk_document(
document: Document,
model_name: str,
provider_type: str | None,
provider_type: EmbeddingProvider | None,
enable_multipass: bool,
chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
subsection_overlap: int = CHUNK_OVERLAP,
@ -301,7 +302,10 @@ class Chunker:
class DefaultChunker(Chunker):
def __init__(
self, model_name: str, provider_type: str | None, enable_multipass: bool
self,
model_name: str,
provider_type: EmbeddingProvider | None,
enable_multipass: bool,
):
self.model_name = model_name
self.provider_type = provider_type

View File

@ -15,6 +15,7 @@ from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding
@ -29,7 +30,7 @@ class IndexingEmbedder(ABC):
normalize: bool,
query_prefix: str | None,
passage_prefix: str | None,
provider_type: str | None,
provider_type: EmbeddingProvider | None,
api_key: str | None,
):
self.model_name = model_name
@ -54,7 +55,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
normalize: bool,
query_prefix: str | None,
passage_prefix: str | None,
provider_type: str | None = None,
provider_type: EmbeddingProvider | None = None,
api_key: str | None = None,
):
super().__init__(

View File

@ -16,6 +16,7 @@ from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
@ -76,7 +77,7 @@ class EmbeddingModel:
query_prefix: str | None,
passage_prefix: str | None,
api_key: str | None,
provider_type: str | None,
provider_type: EmbeddingProvider | None,
# The following are globals are currently not configurable
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
retrim_content: bool = False,

View File

@ -9,6 +9,7 @@ from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
logger = setup_logger()
transformer_logging.set_verbosity_error()
@ -114,7 +115,13 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
def get_tokenizer(model_name: str | None, provider_type: str | None) -> BaseTokenizer:
def get_tokenizer(
    model_name: str | None, provider_type: EmbeddingProvider | str | None
) -> BaseTokenizer:
    """Return the tokenizer to use for the given embedding model/provider.

    Both arguments are currently ignored: every supported model is served by
    the shared module-level sentencepiece tokenizer (see comments below).
    `provider_type` accepts a plain string as well as an EmbeddingProvider
    because LLM tokenizers are still specified by string names.
    """
    # Currently all of the viable models use the same sentencepiece tokenizer
    # OpenAI uses a different one but currently it's not supported due to quality issues
    # the inconsistent chunking makes using the sentencepiece tokenizer default better for now
    # LLM tokenizers are specified by strings
    global _DEFAULT_TOKENIZER
    return _DEFAULT_TOKENIZER

View File

@ -2,12 +2,14 @@ from typing import TYPE_CHECKING
from pydantic import BaseModel
from shared_configs.enums import EmbeddingProvider
if TYPE_CHECKING:
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
class TestEmbeddingRequest(BaseModel):
    """Request payload for verifying a cloud embedding provider's credentials."""

    # Which cloud embedding provider to test (e.g. OPENAI, COHERE).
    provider: EmbeddingProvider
    # API key for the provider; None if the provider needs no key.
    api_key: str | None = None

View File

@ -1,5 +1,4 @@
from enum import Enum
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
@ -10,13 +9,6 @@ DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
DEFAULT_VERTEX_MODEL = "text-embedding-004"
class EmbeddingProvider(Enum):
OPENAI = "openai"
COHERE = "cohere"
VOYAGE = "voyage"
GOOGLE = "google"
class EmbeddingModelTextType:
PROVIDER_TEXT_TYPE_MAP = {
EmbeddingProvider.COHERE: {

View File

@ -76,14 +76,11 @@ class CloudEmbedding:
def __init__(
self,
api_key: str,
provider: str,
provider: EmbeddingProvider,
# Only for Google as is needed on client setup
model: str | None = None,
) -> None:
try:
self.provider = EmbeddingProvider(provider.lower())
except ValueError:
raise ValueError(f"Unsupported provider: {provider}")
self.provider = provider
self.client = _initialize_client(api_key, self.provider, model)
def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
@ -193,7 +190,7 @@ class CloudEmbedding:
@staticmethod
def create(
api_key: str, provider: str, model: str | None = None
api_key: str, provider: EmbeddingProvider, model: str | None = None
) -> "CloudEmbedding":
logger.debug(f"Creating Embedding instance for provider: {provider}")
return CloudEmbedding(api_key, provider, model)
@ -254,7 +251,7 @@ def embed_text(
max_context_length: int,
normalize_embeddings: bool,
api_key: str | None,
provider_type: str | None,
provider_type: EmbeddingProvider | None,
prefix: str | None,
) -> list[Embedding]:
if not all(texts):

View File

@ -1,6 +1,13 @@
from enum import Enum
class EmbeddingProvider(str, Enum):
    """Supported cloud embedding providers.

    Inherits from str so members compare equal to (and serialize as) their
    lowercase string values, which keeps API payloads and DB values readable.
    """

    OPENAI = "openai"
    COHERE = "cohere"
    VOYAGE = "voyage"
    GOOGLE = "google"
class EmbedTextType(str, Enum):
    """Kind of text being embedded; some providers prefix queries vs. passages differently."""

    QUERY = "query"
    PASSAGE = "passage"

View File

@ -1,5 +1,6 @@
from pydantic import BaseModel
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
Embedding = list[float]
@ -12,7 +13,7 @@ class EmbedRequest(BaseModel):
max_context_length: int
normalize_embeddings: bool
api_key: str | None
provider_type: str | None
provider_type: EmbeddingProvider | None
text_type: EmbedTextType
manual_query_prefix: str | None
manual_passage_prefix: str | None

View File

@ -4,6 +4,7 @@ import pytest
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import EmbeddingProvider
VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"]
# openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't
@ -30,7 +31,7 @@ def openai_embedding_model() -> EmbeddingModel:
query_prefix=None,
passage_prefix=None,
api_key=os.getenv("OPENAI_API_KEY"),
provider_type="openai",
provider_type=EmbeddingProvider.OPENAI,
)
@ -49,7 +50,7 @@ def cohere_embedding_model() -> EmbeddingModel:
query_prefix=None,
passage_prefix=None,
api_key=os.getenv("COHERE_API_KEY"),
provider_type="cohere",
provider_type=EmbeddingProvider.COHERE,
)

View File

@ -15,7 +15,7 @@ export function SelectModelModal({
return (
<Modal
onOutsideClick={onCancel}
title={`Elevate Your Game with ${model.model_name}`}
title={`Update model to ${model.model_name}`}
>
<div className="mb-4">
<Text className="text-lg mb-2">