Propagate Embedding Enum (#2108)

This commit is contained in:
Yuhong Sun
2024-08-11 12:17:54 -07:00
committed by GitHub
parent d60fb15ad3
commit ce666f3320
13 changed files with 49 additions and 31 deletions

View File

@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = "7477a5f5d728" revision = "7477a5f5d728"
down_revision = "213fd978c6d8" down_revision = "213fd978c6d8"
branch_labels = None branch_labels: None = None
depends_on = None depends_on: None = None
def upgrade() -> None: def upgrade() -> None:

View File

@ -54,6 +54,7 @@ from danswer.llm.override_models import PromptOverride
from danswer.search.enums import RecencyBiasSetting from danswer.search.enums import RecencyBiasSetting
from danswer.utils.encryption import decrypt_bytes_to_string from danswer.utils.encryption import decrypt_bytes_to_string
from danswer.utils.encryption import encrypt_string_to_bytes from danswer.utils.encryption import encrypt_string_to_bytes
from shared_configs.enums import EmbeddingProvider
class Base(DeclarativeBase): class Base(DeclarativeBase):
@ -582,8 +583,12 @@ class EmbeddingModel(Base):
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>" cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
@property @property
def provider_type(self) -> str | None: def provider_type(self) -> EmbeddingProvider | None:
return self.cloud_provider.name if self.cloud_provider is not None else None return (
EmbeddingProvider(self.cloud_provider.name.lower())
if self.cloud_provider is not None
else None
)
@property @property
def api_key(self) -> str | None: def api_key(self) -> str | None:

View File

@ -18,6 +18,7 @@ from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import get_tokenizer from danswer.natural_language_processing.utils import get_tokenizer
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import shared_precompare_cleanup from danswer.utils.text_processing import shared_precompare_cleanup
from shared_configs.enums import EmbeddingProvider
if TYPE_CHECKING: if TYPE_CHECKING:
from llama_index.text_splitter import SentenceSplitter # type:ignore from llama_index.text_splitter import SentenceSplitter # type:ignore
@ -123,7 +124,7 @@ def _get_metadata_suffix_for_document_index(
def chunk_document( def chunk_document(
document: Document, document: Document,
model_name: str, model_name: str,
provider_type: str | None, provider_type: EmbeddingProvider | None,
enable_multipass: bool, enable_multipass: bool,
chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE, chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
subsection_overlap: int = CHUNK_OVERLAP, subsection_overlap: int = CHUNK_OVERLAP,
@ -301,7 +302,10 @@ class Chunker:
class DefaultChunker(Chunker): class DefaultChunker(Chunker):
def __init__( def __init__(
self, model_name: str, provider_type: str | None, enable_multipass: bool self,
model_name: str,
provider_type: EmbeddingProvider | None,
enable_multipass: bool,
): ):
self.model_name = model_name self.model_name = model_name
self.provider_type = provider_type self.provider_type = provider_type

View File

@ -15,6 +15,7 @@ from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding from shared_configs.model_server_models import Embedding
@ -29,7 +30,7 @@ class IndexingEmbedder(ABC):
normalize: bool, normalize: bool,
query_prefix: str | None, query_prefix: str | None,
passage_prefix: str | None, passage_prefix: str | None,
provider_type: str | None, provider_type: EmbeddingProvider | None,
api_key: str | None, api_key: str | None,
): ):
self.model_name = model_name self.model_name = model_name
@ -54,7 +55,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
normalize: bool, normalize: bool,
query_prefix: str | None, query_prefix: str | None,
passage_prefix: str | None, passage_prefix: str | None,
provider_type: str | None = None, provider_type: EmbeddingProvider | None = None,
api_key: str | None = None, api_key: str | None = None,
): ):
super().__init__( super().__init__(

View File

@ -16,6 +16,7 @@ from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest from shared_configs.model_server_models import EmbedRequest
@ -76,7 +77,7 @@ class EmbeddingModel:
query_prefix: str | None, query_prefix: str | None,
passage_prefix: str | None, passage_prefix: str | None,
api_key: str | None, api_key: str | None,
provider_type: str | None, provider_type: EmbeddingProvider | None,
# The following are globals and are currently not configurable # The following are globals and are currently not configurable
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE, max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
retrim_content: bool = False, retrim_content: bool = False,

View File

@ -9,6 +9,7 @@ from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.search.models import InferenceChunk from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
logger = setup_logger() logger = setup_logger()
transformer_logging.set_verbosity_error() transformer_logging.set_verbosity_error()
@ -114,7 +115,13 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL) _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
def get_tokenizer(model_name: str | None, provider_type: str | None) -> BaseTokenizer: def get_tokenizer(
model_name: str | None, provider_type: EmbeddingProvider | str | None
) -> BaseTokenizer:
# Currently all of the viable models use the same sentencepiece tokenizer
# OpenAI uses a different one but currently it's not supported due to quality issues
# the inconsistent chunking makes using the sentencepiece tokenizer default better for now
# LLM tokenizers are specified by strings
global _DEFAULT_TOKENIZER global _DEFAULT_TOKENIZER
return _DEFAULT_TOKENIZER return _DEFAULT_TOKENIZER

View File

@ -2,12 +2,14 @@ from typing import TYPE_CHECKING
from pydantic import BaseModel from pydantic import BaseModel
from shared_configs.enums import EmbeddingProvider
if TYPE_CHECKING: if TYPE_CHECKING:
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
class TestEmbeddingRequest(BaseModel): class TestEmbeddingRequest(BaseModel):
provider: str provider: EmbeddingProvider
api_key: str | None = None api_key: str | None = None

View File

@ -1,5 +1,4 @@
from enum import Enum from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType from shared_configs.enums import EmbedTextType
@ -10,13 +9,6 @@ DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
DEFAULT_VERTEX_MODEL = "text-embedding-004" DEFAULT_VERTEX_MODEL = "text-embedding-004"
class EmbeddingProvider(Enum):
OPENAI = "openai"
COHERE = "cohere"
VOYAGE = "voyage"
GOOGLE = "google"
class EmbeddingModelTextType: class EmbeddingModelTextType:
PROVIDER_TEXT_TYPE_MAP = { PROVIDER_TEXT_TYPE_MAP = {
EmbeddingProvider.COHERE: { EmbeddingProvider.COHERE: {

View File

@ -76,14 +76,11 @@ class CloudEmbedding:
def __init__( def __init__(
self, self,
api_key: str, api_key: str,
provider: str, provider: EmbeddingProvider,
# Only for Google as is needed on client setup # Only for Google as is needed on client setup
model: str | None = None, model: str | None = None,
) -> None: ) -> None:
try: self.provider = provider
self.provider = EmbeddingProvider(provider.lower())
except ValueError:
raise ValueError(f"Unsupported provider: {provider}")
self.client = _initialize_client(api_key, self.provider, model) self.client = _initialize_client(api_key, self.provider, model)
def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]: def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
@ -193,7 +190,7 @@ class CloudEmbedding:
@staticmethod @staticmethod
def create( def create(
api_key: str, provider: str, model: str | None = None api_key: str, provider: EmbeddingProvider, model: str | None = None
) -> "CloudEmbedding": ) -> "CloudEmbedding":
logger.debug(f"Creating Embedding instance for provider: {provider}") logger.debug(f"Creating Embedding instance for provider: {provider}")
return CloudEmbedding(api_key, provider, model) return CloudEmbedding(api_key, provider, model)
@ -254,7 +251,7 @@ def embed_text(
max_context_length: int, max_context_length: int,
normalize_embeddings: bool, normalize_embeddings: bool,
api_key: str | None, api_key: str | None,
provider_type: str | None, provider_type: EmbeddingProvider | None,
prefix: str | None, prefix: str | None,
) -> list[Embedding]: ) -> list[Embedding]:
if not all(texts): if not all(texts):

View File

@ -1,6 +1,13 @@
from enum import Enum from enum import Enum
class EmbeddingProvider(str, Enum):
OPENAI = "openai"
COHERE = "cohere"
VOYAGE = "voyage"
GOOGLE = "google"
class EmbedTextType(str, Enum): class EmbedTextType(str, Enum):
QUERY = "query" QUERY = "query"
PASSAGE = "passage" PASSAGE = "passage"

View File

@ -1,5 +1,6 @@
from pydantic import BaseModel from pydantic import BaseModel
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType from shared_configs.enums import EmbedTextType
Embedding = list[float] Embedding = list[float]
@ -12,7 +13,7 @@ class EmbedRequest(BaseModel):
max_context_length: int max_context_length: int
normalize_embeddings: bool normalize_embeddings: bool
api_key: str | None api_key: str | None
provider_type: str | None provider_type: EmbeddingProvider | None
text_type: EmbedTextType text_type: EmbedTextType
manual_query_prefix: str | None manual_query_prefix: str | None
manual_passage_prefix: str | None manual_passage_prefix: str | None

View File

@ -4,6 +4,7 @@ import pytest
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from shared_configs.enums import EmbedTextType from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import EmbeddingProvider
VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"] VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"]
# openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't # openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't
@ -30,7 +31,7 @@ def openai_embedding_model() -> EmbeddingModel:
query_prefix=None, query_prefix=None,
passage_prefix=None, passage_prefix=None,
api_key=os.getenv("OPENAI_API_KEY"), api_key=os.getenv("OPENAI_API_KEY"),
provider_type="openai", provider_type=EmbeddingProvider.OPENAI,
) )
@ -49,7 +50,7 @@ def cohere_embedding_model() -> EmbeddingModel:
query_prefix=None, query_prefix=None,
passage_prefix=None, passage_prefix=None,
api_key=os.getenv("COHERE_API_KEY"), api_key=os.getenv("COHERE_API_KEY"),
provider_type="cohere", provider_type=EmbeddingProvider.COHERE,
) )

View File

@ -15,7 +15,7 @@ export function SelectModelModal({
return ( return (
<Modal <Modal
onOutsideClick={onCancel} onOutsideClick={onCancel}
title={`Elevate Your Game with ${model.model_name}`} title={`Update model to ${model.model_name}`}
> >
<div className="mb-4"> <div className="mb-4">
<Text className="text-lg mb-2"> <Text className="text-lg mb-2">