mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-29 11:12:02 +01:00
Propagate Embedding Enum (#2108)
This commit is contained in:
parent
d60fb15ad3
commit
ce666f3320
@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7477a5f5d728"
|
||||
down_revision = "213fd978c6d8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -54,6 +54,7 @@ from danswer.llm.override_models import PromptOverride
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.utils.encryption import decrypt_bytes_to_string
|
||||
from danswer.utils.encryption import encrypt_string_to_bytes
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
@ -582,8 +583,12 @@ class EmbeddingModel(Base):
|
||||
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
|
||||
|
||||
@property
|
||||
def provider_type(self) -> str | None:
|
||||
return self.cloud_provider.name if self.cloud_provider is not None else None
|
||||
def provider_type(self) -> EmbeddingProvider | None:
|
||||
return (
|
||||
EmbeddingProvider(self.cloud_provider.name.lower())
|
||||
if self.cloud_provider is not None
|
||||
else None
|
||||
)
|
||||
|
||||
@property
|
||||
def api_key(self) -> str | None:
|
||||
|
@ -18,6 +18,7 @@ from danswer.indexing.models import DocAwareChunk
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import shared_precompare_cleanup
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.text_splitter import SentenceSplitter # type:ignore
|
||||
@ -123,7 +124,7 @@ def _get_metadata_suffix_for_document_index(
|
||||
def chunk_document(
|
||||
document: Document,
|
||||
model_name: str,
|
||||
provider_type: str | None,
|
||||
provider_type: EmbeddingProvider | None,
|
||||
enable_multipass: bool,
|
||||
chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
subsection_overlap: int = CHUNK_OVERLAP,
|
||||
@ -301,7 +302,10 @@ class Chunker:
|
||||
|
||||
class DefaultChunker(Chunker):
|
||||
def __init__(
|
||||
self, model_name: str, provider_type: str | None, enable_multipass: bool
|
||||
self,
|
||||
model_name: str,
|
||||
provider_type: EmbeddingProvider | None,
|
||||
enable_multipass: bool,
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.provider_type = provider_type
|
||||
|
@ -15,6 +15,7 @@ from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
@ -29,7 +30,7 @@ class IndexingEmbedder(ABC):
|
||||
normalize: bool,
|
||||
query_prefix: str | None,
|
||||
passage_prefix: str | None,
|
||||
provider_type: str | None,
|
||||
provider_type: EmbeddingProvider | None,
|
||||
api_key: str | None,
|
||||
):
|
||||
self.model_name = model_name
|
||||
@ -54,7 +55,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
normalize: bool,
|
||||
query_prefix: str | None,
|
||||
passage_prefix: str | None,
|
||||
provider_type: str | None = None,
|
||||
provider_type: EmbeddingProvider | None = None,
|
||||
api_key: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
|
@ -16,6 +16,7 @@ from danswer.natural_language_processing.utils import tokenizer_trim_content
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.model_server_models import Embedding
|
||||
from shared_configs.model_server_models import EmbedRequest
|
||||
@ -76,7 +77,7 @@ class EmbeddingModel:
|
||||
query_prefix: str | None,
|
||||
passage_prefix: str | None,
|
||||
api_key: str | None,
|
||||
provider_type: str | None,
|
||||
provider_type: EmbeddingProvider | None,
|
||||
# The following are globals are currently not configurable
|
||||
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
retrim_content: bool = False,
|
||||
|
@ -9,6 +9,7 @@ from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
logger = setup_logger()
|
||||
transformer_logging.set_verbosity_error()
|
||||
@ -114,7 +115,13 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
|
||||
_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
|
||||
|
||||
|
||||
def get_tokenizer(model_name: str | None, provider_type: str | None) -> BaseTokenizer:
|
||||
def get_tokenizer(
|
||||
model_name: str | None, provider_type: EmbeddingProvider | str | None
|
||||
) -> BaseTokenizer:
|
||||
# Currently all of the viable models use the same sentencepiece tokenizer
|
||||
# OpenAI uses a different one but currently it's not supported due to quality issues
|
||||
# the inconsistent chunking makes using the sentencepiece tokenizer default better for now
|
||||
# LLM tokenizers are specified by strings
|
||||
global _DEFAULT_TOKENIZER
|
||||
return _DEFAULT_TOKENIZER
|
||||
|
||||
|
@ -2,12 +2,14 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
|
||||
|
||||
class TestEmbeddingRequest(BaseModel):
|
||||
provider: str
|
||||
provider: EmbeddingProvider
|
||||
api_key: str | None = None
|
||||
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
from enum import Enum
|
||||
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
|
||||
|
||||
@ -10,13 +9,6 @@ DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
|
||||
DEFAULT_VERTEX_MODEL = "text-embedding-004"
|
||||
|
||||
|
||||
class EmbeddingProvider(Enum):
|
||||
OPENAI = "openai"
|
||||
COHERE = "cohere"
|
||||
VOYAGE = "voyage"
|
||||
GOOGLE = "google"
|
||||
|
||||
|
||||
class EmbeddingModelTextType:
|
||||
PROVIDER_TEXT_TYPE_MAP = {
|
||||
EmbeddingProvider.COHERE: {
|
||||
|
@ -76,14 +76,11 @@ class CloudEmbedding:
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
provider: str,
|
||||
provider: EmbeddingProvider,
|
||||
# Only for Google as is needed on client setup
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
try:
|
||||
self.provider = EmbeddingProvider(provider.lower())
|
||||
except ValueError:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
self.provider = provider
|
||||
self.client = _initialize_client(api_key, self.provider, model)
|
||||
|
||||
def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
|
||||
@ -193,7 +190,7 @@ class CloudEmbedding:
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
api_key: str, provider: str, model: str | None = None
|
||||
api_key: str, provider: EmbeddingProvider, model: str | None = None
|
||||
) -> "CloudEmbedding":
|
||||
logger.debug(f"Creating Embedding instance for provider: {provider}")
|
||||
return CloudEmbedding(api_key, provider, model)
|
||||
@ -254,7 +251,7 @@ def embed_text(
|
||||
max_context_length: int,
|
||||
normalize_embeddings: bool,
|
||||
api_key: str | None,
|
||||
provider_type: str | None,
|
||||
provider_type: EmbeddingProvider | None,
|
||||
prefix: str | None,
|
||||
) -> list[Embedding]:
|
||||
if not all(texts):
|
||||
|
@ -1,6 +1,13 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class EmbeddingProvider(str, Enum):
|
||||
OPENAI = "openai"
|
||||
COHERE = "cohere"
|
||||
VOYAGE = "voyage"
|
||||
GOOGLE = "google"
|
||||
|
||||
|
||||
class EmbedTextType(str, Enum):
|
||||
QUERY = "query"
|
||||
PASSAGE = "passage"
|
||||
|
@ -1,5 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
|
||||
Embedding = list[float]
|
||||
@ -12,7 +13,7 @@ class EmbedRequest(BaseModel):
|
||||
max_context_length: int
|
||||
normalize_embeddings: bool
|
||||
api_key: str | None
|
||||
provider_type: str | None
|
||||
provider_type: EmbeddingProvider | None
|
||||
text_type: EmbedTextType
|
||||
manual_query_prefix: str | None
|
||||
manual_passage_prefix: str | None
|
||||
|
@ -4,6 +4,7 @@ import pytest
|
||||
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.model_server_models import EmbeddingProvider
|
||||
|
||||
VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"]
|
||||
# openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't
|
||||
@ -30,7 +31,7 @@ def openai_embedding_model() -> EmbeddingModel:
|
||||
query_prefix=None,
|
||||
passage_prefix=None,
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
provider_type="openai",
|
||||
provider_type=EmbeddingProvider.OPENAI,
|
||||
)
|
||||
|
||||
|
||||
@ -49,7 +50,7 @@ def cohere_embedding_model() -> EmbeddingModel:
|
||||
query_prefix=None,
|
||||
passage_prefix=None,
|
||||
api_key=os.getenv("COHERE_API_KEY"),
|
||||
provider_type="cohere",
|
||||
provider_type=EmbeddingProvider.COHERE,
|
||||
)
|
||||
|
||||
|
||||
|
@ -15,7 +15,7 @@ export function SelectModelModal({
|
||||
return (
|
||||
<Modal
|
||||
onOutsideClick={onCancel}
|
||||
title={`Elevate Your Game with ${model.model_name}`}
|
||||
title={`Update model to ${model.model_name}`}
|
||||
>
|
||||
<div className="mb-4">
|
||||
<Text className="text-lg mb-2">
|
||||
|
Loading…
x
Reference in New Issue
Block a user