mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-22 11:03:05 +02:00
Propagate Embedding Enum (#2108)
This commit is contained in:
@ -12,8 +12,8 @@ import sqlalchemy as sa
|
|||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision = "7477a5f5d728"
|
revision = "7477a5f5d728"
|
||||||
down_revision = "213fd978c6d8"
|
down_revision = "213fd978c6d8"
|
||||||
branch_labels = None
|
branch_labels: None = None
|
||||||
depends_on = None
|
depends_on: None = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
|
@ -54,6 +54,7 @@ from danswer.llm.override_models import PromptOverride
|
|||||||
from danswer.search.enums import RecencyBiasSetting
|
from danswer.search.enums import RecencyBiasSetting
|
||||||
from danswer.utils.encryption import decrypt_bytes_to_string
|
from danswer.utils.encryption import decrypt_bytes_to_string
|
||||||
from danswer.utils.encryption import encrypt_string_to_bytes
|
from danswer.utils.encryption import encrypt_string_to_bytes
|
||||||
|
from shared_configs.enums import EmbeddingProvider
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
@ -582,8 +583,12 @@ class EmbeddingModel(Base):
|
|||||||
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
|
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_type(self) -> str | None:
|
def provider_type(self) -> EmbeddingProvider | None:
|
||||||
return self.cloud_provider.name if self.cloud_provider is not None else None
|
return (
|
||||||
|
EmbeddingProvider(self.cloud_provider.name.lower())
|
||||||
|
if self.cloud_provider is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def api_key(self) -> str | None:
|
def api_key(self) -> str | None:
|
||||||
|
@ -18,6 +18,7 @@ from danswer.indexing.models import DocAwareChunk
|
|||||||
from danswer.natural_language_processing.utils import get_tokenizer
|
from danswer.natural_language_processing.utils import get_tokenizer
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
from danswer.utils.text_processing import shared_precompare_cleanup
|
from danswer.utils.text_processing import shared_precompare_cleanup
|
||||||
|
from shared_configs.enums import EmbeddingProvider
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from llama_index.text_splitter import SentenceSplitter # type:ignore
|
from llama_index.text_splitter import SentenceSplitter # type:ignore
|
||||||
@ -123,7 +124,7 @@ def _get_metadata_suffix_for_document_index(
|
|||||||
def chunk_document(
|
def chunk_document(
|
||||||
document: Document,
|
document: Document,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
provider_type: str | None,
|
provider_type: EmbeddingProvider | None,
|
||||||
enable_multipass: bool,
|
enable_multipass: bool,
|
||||||
chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||||
subsection_overlap: int = CHUNK_OVERLAP,
|
subsection_overlap: int = CHUNK_OVERLAP,
|
||||||
@ -301,7 +302,10 @@ class Chunker:
|
|||||||
|
|
||||||
class DefaultChunker(Chunker):
|
class DefaultChunker(Chunker):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model_name: str, provider_type: str | None, enable_multipass: bool
|
self,
|
||||||
|
model_name: str,
|
||||||
|
provider_type: EmbeddingProvider | None,
|
||||||
|
enable_multipass: bool,
|
||||||
):
|
):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.provider_type = provider_type
|
self.provider_type = provider_type
|
||||||
|
@ -15,6 +15,7 @@ from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
|||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||||
|
from shared_configs.enums import EmbeddingProvider
|
||||||
from shared_configs.enums import EmbedTextType
|
from shared_configs.enums import EmbedTextType
|
||||||
from shared_configs.model_server_models import Embedding
|
from shared_configs.model_server_models import Embedding
|
||||||
|
|
||||||
@ -29,7 +30,7 @@ class IndexingEmbedder(ABC):
|
|||||||
normalize: bool,
|
normalize: bool,
|
||||||
query_prefix: str | None,
|
query_prefix: str | None,
|
||||||
passage_prefix: str | None,
|
passage_prefix: str | None,
|
||||||
provider_type: str | None,
|
provider_type: EmbeddingProvider | None,
|
||||||
api_key: str | None,
|
api_key: str | None,
|
||||||
):
|
):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
@ -54,7 +55,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
|||||||
normalize: bool,
|
normalize: bool,
|
||||||
query_prefix: str | None,
|
query_prefix: str | None,
|
||||||
passage_prefix: str | None,
|
passage_prefix: str | None,
|
||||||
provider_type: str | None = None,
|
provider_type: EmbeddingProvider | None = None,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
@ -16,6 +16,7 @@ from danswer.natural_language_processing.utils import tokenizer_trim_content
|
|||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
from shared_configs.configs import MODEL_SERVER_HOST
|
from shared_configs.configs import MODEL_SERVER_HOST
|
||||||
from shared_configs.configs import MODEL_SERVER_PORT
|
from shared_configs.configs import MODEL_SERVER_PORT
|
||||||
|
from shared_configs.enums import EmbeddingProvider
|
||||||
from shared_configs.enums import EmbedTextType
|
from shared_configs.enums import EmbedTextType
|
||||||
from shared_configs.model_server_models import Embedding
|
from shared_configs.model_server_models import Embedding
|
||||||
from shared_configs.model_server_models import EmbedRequest
|
from shared_configs.model_server_models import EmbedRequest
|
||||||
@ -76,7 +77,7 @@ class EmbeddingModel:
|
|||||||
query_prefix: str | None,
|
query_prefix: str | None,
|
||||||
passage_prefix: str | None,
|
passage_prefix: str | None,
|
||||||
api_key: str | None,
|
api_key: str | None,
|
||||||
provider_type: str | None,
|
provider_type: EmbeddingProvider | None,
|
||||||
# The following globals are currently not configurable
|
# The following globals are currently not configurable
|
||||||
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||||
retrim_content: bool = False,
|
retrim_content: bool = False,
|
||||||
|
@ -9,6 +9,7 @@ from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
|||||||
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||||
from danswer.search.models import InferenceChunk
|
from danswer.search.models import InferenceChunk
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
|
from shared_configs.enums import EmbeddingProvider
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
transformer_logging.set_verbosity_error()
|
transformer_logging.set_verbosity_error()
|
||||||
@ -114,7 +115,13 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
|
|||||||
_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
|
_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
|
||||||
|
|
||||||
|
|
||||||
def get_tokenizer(model_name: str | None, provider_type: str | None) -> BaseTokenizer:
|
def get_tokenizer(
|
||||||
|
model_name: str | None, provider_type: EmbeddingProvider | str | None
|
||||||
|
) -> BaseTokenizer:
|
||||||
|
# Currently all of the viable models use the same sentencepiece tokenizer
|
||||||
|
# OpenAI uses a different tokenizer, but it is currently not supported due to quality issues;
|
||||||
|
# the inconsistent chunking makes defaulting to the sentencepiece tokenizer the better choice for now
|
||||||
|
# LLM tokenizers are specified by strings
|
||||||
global _DEFAULT_TOKENIZER
|
global _DEFAULT_TOKENIZER
|
||||||
return _DEFAULT_TOKENIZER
|
return _DEFAULT_TOKENIZER
|
||||||
|
|
||||||
|
@ -2,12 +2,14 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from shared_configs.enums import EmbeddingProvider
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||||
|
|
||||||
|
|
||||||
class TestEmbeddingRequest(BaseModel):
|
class TestEmbeddingRequest(BaseModel):
|
||||||
provider: str
|
provider: EmbeddingProvider
|
||||||
api_key: str | None = None
|
api_key: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
from enum import Enum
|
from shared_configs.enums import EmbeddingProvider
|
||||||
|
|
||||||
from shared_configs.enums import EmbedTextType
|
from shared_configs.enums import EmbedTextType
|
||||||
|
|
||||||
|
|
||||||
@ -10,13 +9,6 @@ DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
|
|||||||
DEFAULT_VERTEX_MODEL = "text-embedding-004"
|
DEFAULT_VERTEX_MODEL = "text-embedding-004"
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingProvider(Enum):
|
|
||||||
OPENAI = "openai"
|
|
||||||
COHERE = "cohere"
|
|
||||||
VOYAGE = "voyage"
|
|
||||||
GOOGLE = "google"
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingModelTextType:
|
class EmbeddingModelTextType:
|
||||||
PROVIDER_TEXT_TYPE_MAP = {
|
PROVIDER_TEXT_TYPE_MAP = {
|
||||||
EmbeddingProvider.COHERE: {
|
EmbeddingProvider.COHERE: {
|
||||||
|
@ -76,14 +76,11 @@ class CloudEmbedding:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
provider: str,
|
provider: EmbeddingProvider,
|
||||||
# Only for Google as is needed on client setup
|
# Only for Google as is needed on client setup
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
self.provider = provider
|
||||||
self.provider = EmbeddingProvider(provider.lower())
|
|
||||||
except ValueError:
|
|
||||||
raise ValueError(f"Unsupported provider: {provider}")
|
|
||||||
self.client = _initialize_client(api_key, self.provider, model)
|
self.client = _initialize_client(api_key, self.provider, model)
|
||||||
|
|
||||||
def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
|
def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
|
||||||
@ -193,7 +190,7 @@ class CloudEmbedding:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(
|
def create(
|
||||||
api_key: str, provider: str, model: str | None = None
|
api_key: str, provider: EmbeddingProvider, model: str | None = None
|
||||||
) -> "CloudEmbedding":
|
) -> "CloudEmbedding":
|
||||||
logger.debug(f"Creating Embedding instance for provider: {provider}")
|
logger.debug(f"Creating Embedding instance for provider: {provider}")
|
||||||
return CloudEmbedding(api_key, provider, model)
|
return CloudEmbedding(api_key, provider, model)
|
||||||
@ -254,7 +251,7 @@ def embed_text(
|
|||||||
max_context_length: int,
|
max_context_length: int,
|
||||||
normalize_embeddings: bool,
|
normalize_embeddings: bool,
|
||||||
api_key: str | None,
|
api_key: str | None,
|
||||||
provider_type: str | None,
|
provider_type: EmbeddingProvider | None,
|
||||||
prefix: str | None,
|
prefix: str | None,
|
||||||
) -> list[Embedding]:
|
) -> list[Embedding]:
|
||||||
if not all(texts):
|
if not all(texts):
|
||||||
|
@ -1,6 +1,13 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingProvider(str, Enum):
|
||||||
|
OPENAI = "openai"
|
||||||
|
COHERE = "cohere"
|
||||||
|
VOYAGE = "voyage"
|
||||||
|
GOOGLE = "google"
|
||||||
|
|
||||||
|
|
||||||
class EmbedTextType(str, Enum):
|
class EmbedTextType(str, Enum):
|
||||||
QUERY = "query"
|
QUERY = "query"
|
||||||
PASSAGE = "passage"
|
PASSAGE = "passage"
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from shared_configs.enums import EmbeddingProvider
|
||||||
from shared_configs.enums import EmbedTextType
|
from shared_configs.enums import EmbedTextType
|
||||||
|
|
||||||
Embedding = list[float]
|
Embedding = list[float]
|
||||||
@ -12,7 +13,7 @@ class EmbedRequest(BaseModel):
|
|||||||
max_context_length: int
|
max_context_length: int
|
||||||
normalize_embeddings: bool
|
normalize_embeddings: bool
|
||||||
api_key: str | None
|
api_key: str | None
|
||||||
provider_type: str | None
|
provider_type: EmbeddingProvider | None
|
||||||
text_type: EmbedTextType
|
text_type: EmbedTextType
|
||||||
manual_query_prefix: str | None
|
manual_query_prefix: str | None
|
||||||
manual_passage_prefix: str | None
|
manual_passage_prefix: str | None
|
||||||
|
@ -4,6 +4,7 @@ import pytest
|
|||||||
|
|
||||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||||
from shared_configs.enums import EmbedTextType
|
from shared_configs.enums import EmbedTextType
|
||||||
|
from shared_configs.model_server_models import EmbeddingProvider
|
||||||
|
|
||||||
VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"]
|
VALID_SAMPLE = ["hi", "hello my name is bob", "woah there!!!. 😃"]
|
||||||
# openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't
|
# openai limit is 2048, cohere is supposed to be 96 but in practice that doesn't
|
||||||
@ -30,7 +31,7 @@ def openai_embedding_model() -> EmbeddingModel:
|
|||||||
query_prefix=None,
|
query_prefix=None,
|
||||||
passage_prefix=None,
|
passage_prefix=None,
|
||||||
api_key=os.getenv("OPENAI_API_KEY"),
|
api_key=os.getenv("OPENAI_API_KEY"),
|
||||||
provider_type="openai",
|
provider_type=EmbeddingProvider.OPENAI,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -49,7 +50,7 @@ def cohere_embedding_model() -> EmbeddingModel:
|
|||||||
query_prefix=None,
|
query_prefix=None,
|
||||||
passage_prefix=None,
|
passage_prefix=None,
|
||||||
api_key=os.getenv("COHERE_API_KEY"),
|
api_key=os.getenv("COHERE_API_KEY"),
|
||||||
provider_type="cohere",
|
provider_type=EmbeddingProvider.COHERE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ export function SelectModelModal({
|
|||||||
return (
|
return (
|
||||||
<Modal
|
<Modal
|
||||||
onOutsideClick={onCancel}
|
onOutsideClick={onCancel}
|
||||||
title={`Elevate Your Game with ${model.model_name}`}
|
title={`Update model to ${model.model_name}`}
|
||||||
>
|
>
|
||||||
<div className="mb-4">
|
<div className="mb-4">
|
||||||
<Text className="text-lg mb-2">
|
<Text className="text-lg mb-2">
|
||||||
|
Reference in New Issue
Block a user