add third party embedding models (#1818)

This commit is contained in:
pablodanswer
2024-07-14 10:19:53 -07:00
committed by GitHub
parent b6bd818e60
commit e7f81d1688
46 changed files with 2293 additions and 453 deletions

View File

@@ -0,0 +1,65 @@
"""add cloud embedding model and update embedding_model
Revision ID: 44f856ae2a4a
Revises: d716b0791ddd
Create Date: 2024-06-28 20:01:05.927647
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "44f856ae2a4a"
down_revision = "d716b0791ddd"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Create embedding_provider table
op.create_table(
"embedding_provider",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("default_model_id", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
# Add cloud_provider_id to embedding_model table
op.add_column(
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
)
# Add foreign key constraints
op.create_foreign_key(
"fk_embedding_model_cloud_provider",
"embedding_model",
"embedding_provider",
["cloud_provider_id"],
["id"],
)
op.create_foreign_key(
"fk_embedding_provider_default_model",
"embedding_provider",
"embedding_model",
["default_model_id"],
["id"],
)
def downgrade() -> None:
# Remove foreign key constraints
op.drop_constraint(
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
)
op.drop_constraint(
"fk_embedding_provider_default_model", "embedding_provider", type_="foreignkey"
)
# Remove cloud_provider_id column
op.drop_column("embedding_model", "cloud_provider_id")
# Drop embedding_provider table
op.drop_table("embedding_provider")
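Note that the upgrade leaves a circular foreign-key pair in place (embedding_model.cloud_provider_id references embedding_provider.id, while embedding_provider.default_model_id references embedding_model.id), which is why the downgrade drops both constraints before dropping the table. Below is a minimal sketch for sanity-checking the applied schema with SQLAlchemy's inspector; the connection URL is a placeholder.

from sqlalchemy import create_engine, inspect

# Placeholder Postgres URL; point this at the instance Alembic migrated
engine = create_engine("postgresql://danswer:password@localhost:5432/danswer")
inspector = inspect(engine)

assert "embedding_provider" in inspector.get_table_names()
assert "cloud_provider_id" in {
    col["name"] for col in inspector.get_columns("embedding_model")
}
# Expect fk_embedding_model_cloud_provider among the reported constraints
print([fk["name"] for fk in inspector.get_foreign_keys("embedding_model")])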

View File

@@ -10,8 +10,8 @@ from alembic import op
# revision identifiers, used by Alembic.
revision = "d716b0791ddd"
down_revision = "7aea705850d5"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@@ -98,7 +98,6 @@ def _run_indexing(
3. Updates Postgres to record the indexed documents + the outcome of this run
"""
start_time = time.time()
db_embedding_model = index_attempt.embedding_model
index_name = db_embedding_model.index_name
@@ -116,6 +115,8 @@ def _run_indexing(
normalize=db_embedding_model.normalize,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
api_key=db_embedding_model.api_key,
provider_type=db_embedding_model.provider_type,
)
indexing_pipeline = build_indexing_pipeline(
@@ -287,6 +288,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
db_session=db_session,
index_attempt_id=index_attempt_id,
)
if attempt is None:
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")

View File

@@ -343,13 +343,15 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
        # So that first-time users aren't surprised by the really slow speed of the
        # first batch of documents indexed
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
if db_embedding_model.cloud_provider_id is None:
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient

View File

@@ -469,13 +469,13 @@ if __name__ == "__main__":
# or the tokens have updated (set up for the first time)
with Session(get_sqlalchemy_engine()) as db_session:
embedding_model = get_current_db_embedding_model(db_session)
warm_up_encoders(
model_name=embedding_model.model_name,
normalize=embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
if embedding_model.cloud_provider_id is None:
warm_up_encoders(
model_name=embedding_model.model_name,
normalize=embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
slack_bot_tokens = latest_slack_bot_tokens
# potentially may cause a message to be dropped, but it is complicated

View File

@@ -10,10 +10,15 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelDetail
from danswer.search.search_nlp_models import clean_model_name
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -31,6 +36,7 @@ def create_embedding_model(
query_prefix=model_details.query_prefix,
passage_prefix=model_details.passage_prefix,
status=status,
cloud_provider_id=model_details.cloud_provider_id,
        # Every embedding model except the initial one from migrations follows this naming scheme
        # The initial one from migrations is simply called "danswer_chunk"
index_name=f"danswer_chunk_{clean_model_name(model_details.model_name)}",
@@ -42,6 +48,42 @@ def create_embedding_model(
return embedding_model
def get_model_id_from_name(
db_session: Session, embedding_provider_name: str
) -> int | None:
query = select(CloudEmbeddingProvider).where(
CloudEmbeddingProvider.name == embedding_provider_name
)
provider = db_session.execute(query).scalars().first()
return provider.id if provider else None
def get_current_db_embedding_provider(
db_session: Session,
) -> ServerCloudEmbeddingProvider | None:
current_embedding_model = EmbeddingModelDetail.from_model(
get_current_db_embedding_model(db_session=db_session)
)
if (
current_embedding_model is None
or current_embedding_model.cloud_provider_id is None
):
return None
embedding_provider = fetch_embedding_provider(
db_session=db_session, provider_id=current_embedding_model.cloud_provider_id
)
if embedding_provider is None:
raise RuntimeError("No embedding provider exists for this model.")
current_embedding_provider = ServerCloudEmbeddingProvider.from_request(
cloud_provider_model=embedding_provider
)
return current_embedding_provider
def get_current_db_embedding_model(db_session: Session) -> EmbeddingModel:
query = (
select(EmbeddingModel)

View File

@@ -2,11 +2,34 @@ from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
def upsert_cloud_embedding_provider(
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
) -> CloudEmbeddingProvider:
existing_provider = (
db_session.query(CloudEmbeddingProviderModel)
.filter_by(name=provider.name)
.first()
)
if existing_provider:
for key, value in provider.dict().items():
setattr(existing_provider, key, value)
else:
new_provider = CloudEmbeddingProviderModel(**provider.dict())
db_session.add(new_provider)
existing_provider = new_provider
db_session.commit()
db_session.refresh(existing_provider)
return CloudEmbeddingProvider.from_request(existing_provider)
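# A hedged usage sketch (assumes a session factory such as danswer.db.engine's
# get_sqlalchemy_engine; values are illustrative placeholders):
#
#     with Session(get_sqlalchemy_engine()) as db_session:
#         provider = upsert_cloud_embedding_provider(
#             db_session,
#             CloudEmbeddingProviderCreationRequest(
#                 name="openai", api_key="sk-...", default_model_id=None
#             ),
#         )
#         print(provider.id, provider.name)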
def upsert_llm_provider(
db_session: Session, llm_provider: LLMProviderUpsertRequest
) -> FullLLMProvider:
@@ -26,7 +49,6 @@ def upsert_llm_provider(
existing_llm_provider.model_names = llm_provider.model_names
db_session.commit()
return FullLLMProvider.from_model(existing_llm_provider)
# if it does not exist, create a new entry
llm_provider_model = LLMProviderModel(
name=llm_provider.name,
@@ -46,10 +68,26 @@ def upsert_llm_provider(
return FullLLMProvider.from_model(llm_provider_model)
def fetch_existing_embedding_providers(
db_session: Session,
) -> list[CloudEmbeddingProviderModel]:
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
return list(db_session.scalars(select(LLMProviderModel)).all())
def fetch_embedding_provider(
db_session: Session, provider_id: int
) -> CloudEmbeddingProviderModel | None:
return db_session.scalar(
select(CloudEmbeddingProviderModel).where(
CloudEmbeddingProviderModel.id == provider_id
)
)
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
@@ -70,6 +108,16 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
return FullLLMProvider.from_model(provider_model)
def remove_embedding_provider(
db_session: Session, embedding_provider_name: str
) -> None:
db_session.execute(
delete(CloudEmbeddingProviderModel).where(
CloudEmbeddingProviderModel.name == embedding_provider_name
)
)
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
db_session.execute(
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)

View File

@@ -130,6 +130,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
chat_folders: Mapped[list["ChatFolder"]] = relationship(
"ChatFolder", back_populates="user"
)
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
# Personas owned by this user
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
@@ -469,7 +470,7 @@ class Credential(Base):
class EmbeddingModel(Base):
__tablename__ = "embedding_model"
    # The ID also indicates the order in which the models were configured by the admin
id: Mapped[int] = mapped_column(primary_key=True)
model_name: Mapped[str] = mapped_column(String)
model_dim: Mapped[int] = mapped_column(Integer)
@@ -481,6 +482,16 @@ class EmbeddingModel(Base):
)
index_name: Mapped[str] = mapped_column(String)
# New field for cloud provider relationship
cloud_provider_id: Mapped[int | None] = mapped_column(
ForeignKey("embedding_provider.id")
)
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
"CloudEmbeddingProvider",
back_populates="embedding_models",
foreign_keys=[cloud_provider_id],
)
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
"IndexAttempt", back_populates="embedding_model"
)
@@ -500,6 +511,18 @@ class EmbeddingModel(Base):
),
)
def __repr__(self) -> str:
        return (
            f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}', "
            f"cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
        )
@property
def api_key(self) -> str | None:
return self.cloud_provider.api_key if self.cloud_provider else None
@property
def provider_type(self) -> str | None:
return self.cloud_provider.name if self.cloud_provider else None
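    # Hedged illustration: these proxies keep calling code provider-agnostic, e.g.
    #   model = get_current_db_embedding_model(db_session)
    #   model.api_key        # the provider's key, or None for self-hosted models
    #   model.provider_type  # e.g. "openai", or None for self-hosted models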
class IndexAttempt(Base):
"""
@@ -519,6 +542,7 @@ class IndexAttempt(Base):
ForeignKey("credential.id"),
nullable=True,
)
    # Some index attempts that run from the beginning will still have this as False
    # This is only True for attempts explicitly marked as from the start via
    # the run-once API
@@ -879,11 +903,6 @@ class ChatMessageFeedback(Base):
)
"""
Structural, Organizational, and Configuration Tables
"""
class LLMProvider(Base):
__tablename__ = "llm_provider"
@@ -912,6 +931,29 @@ class LLMProvider(Base):
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
class CloudEmbeddingProvider(Base):
__tablename__ = "embedding_provider"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
api_key: Mapped[str | None] = mapped_column(EncryptedString())
default_model_id: Mapped[int | None] = mapped_column(
Integer, ForeignKey("embedding_model.id"), nullable=True
)
embedding_models: Mapped[list["EmbeddingModel"]] = relationship(
"EmbeddingModel",
back_populates="cloud_provider",
foreign_keys="EmbeddingModel.cloud_provider_id",
)
default_model: Mapped["EmbeddingModel"] = relationship(
"EmbeddingModel", foreign_keys=[default_model_id]
)
def __repr__(self) -> str:
return f"<EmbeddingProvider(name='{self.name}')>"
class DocumentSet(Base):
__tablename__ = "document_set"
@@ -1194,6 +1236,7 @@ class SlackBotConfig(Base):
response_type: Mapped[SlackBotResponseType] = mapped_column(
Enum(SlackBotResponseType, native_enum=False), nullable=False
)
enable_auto_filters: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)

View File

@@ -50,6 +50,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
normalize: bool,
query_prefix: str | None,
passage_prefix: str | None,
api_key: str | None = None,
provider_type: str | None = None,
):
super().__init__(model_name, normalize, query_prefix, passage_prefix)
self.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE # Currently not customizable
@@ -59,6 +61,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
query_prefix=query_prefix,
passage_prefix=passage_prefix,
normalize=normalize,
api_key=api_key,
provider_type=provider_type,
# The below are globally set, this flow always uses the indexing one
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,

View File

@@ -97,13 +97,19 @@ class EmbeddingModelDetail(BaseModel):
normalize: bool
query_prefix: str | None
passage_prefix: str | None
cloud_provider_id: int | None = None
cloud_provider_name: str | None = None
@classmethod
def from_model(cls, embedding_model: "EmbeddingModel") -> "EmbeddingModelDetail":
def from_model(
cls,
embedding_model: "EmbeddingModel",
) -> "EmbeddingModelDetail":
return cls(
model_name=embedding_model.model_name,
model_dim=embedding_model.model_dim,
normalize=embedding_model.normalize,
query_prefix=embedding_model.query_prefix,
passage_prefix=embedding_model.passage_prefix,
cloud_provider_id=embedding_model.cloud_provider_id,
)

View File

@@ -67,6 +67,8 @@ from danswer.server.features.tool.api import admin_router as admin_tool_router
from danswer.server.features.tool.api import router as tool_router
from danswer.server.gpts.api import router as gpts_router
from danswer.server.manage.administrative import router as admin_router
from danswer.server.manage.embedding.api import admin_router as embedding_admin_router
from danswer.server.manage.embedding.api import basic_router as embedding_router
from danswer.server.manage.get_state import router as state_router
from danswer.server.manage.llm.api import admin_router as llm_admin_router
from danswer.server.manage.llm.api import basic_router as llm_router
@@ -247,12 +249,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
time.sleep(wait_time)
logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
if db_embedding_model.cloud_provider_id is None:
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
yield
@@ -291,6 +294,8 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, settings_admin_router)
include_router_with_global_prefix_prepended(application, llm_admin_router)
include_router_with_global_prefix_prepended(application, llm_router)
include_router_with_global_prefix_prepended(application, embedding_admin_router)
include_router_with_global_prefix_prepended(application, embedding_router)
include_router_with_global_prefix_prepended(
application, token_rate_limit_settings_router
)

View File

@@ -168,6 +168,7 @@ def stream_answer_objects(
max_tokens=max_document_tokens,
use_sections=query_req.chunks_above > 0 or query_req.chunks_below > 0,
)
search_tool = SearchTool(
db_session=db_session,
user=user,

View File

@@ -131,6 +131,8 @@ def doc_index_retrieval(
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
normalize=db_embedding_model.normalize,
api_key=db_embedding_model.api_key,
provider_type=db_embedding_model.provider_type,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,

View File

@@ -84,20 +84,24 @@ def build_model_server_url(
class EmbeddingModel:
def __init__(
self,
model_name: str,
query_prefix: str | None,
passage_prefix: str | None,
normalize: bool,
server_host: str, # Changes depending on indexing or inference
server_port: int,
model_name: str | None,
normalize: bool,
query_prefix: str | None,
passage_prefix: str | None,
api_key: str | None,
provider_type: str | None,
        # The following globals are currently not configurable
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> None:
self.model_name = model_name
self.api_key = api_key
self.provider_type = provider_type
self.max_seq_length = max_seq_length
self.query_prefix = query_prefix
self.passage_prefix = passage_prefix
self.normalize = normalize
self.model_name = model_name
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@@ -111,10 +115,13 @@ class EmbeddingModel:
prefixed_texts = texts
embed_request = EmbedRequest(
texts=prefixed_texts,
model_name=self.model_name,
texts=prefixed_texts,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
)
response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
@@ -187,6 +194,8 @@ def warm_up_encoders(
passage_prefix=None,
server_host=model_server_host,
server_port=model_server_port,
api_key=None,
provider_type=None,
)
# First time downloading the models it may take even longer, but just in case,

View File

@@ -0,0 +1,93 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.db.embedding_model import get_current_db_embedding_provider
from danswer.db.engine import get_session
from danswer.db.llm import fetch_existing_embedding_providers
from danswer.db.llm import remove_embedding_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.models import User
from danswer.search.enums import EmbedTextType
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.embedding.models import TestEmbeddingRequest
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
admin_router = APIRouter(prefix="/admin/embedding")
basic_router = APIRouter(prefix="/embedding")
@admin_router.post("/test-embedding")
def test_embedding_configuration(
    test_embedding_request: TestEmbeddingRequest,
_: User | None = Depends(current_admin_user),
) -> None:
try:
test_model = EmbeddingModel(
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
            api_key=test_embedding_request.api_key,
            provider_type=test_embedding_request.provider,
normalize=False,
query_prefix=None,
passage_prefix=None,
model_name=None,
)
test_model.encode(["Test String"], text_type=EmbedTextType.QUERY)
except ValueError as e:
error_msg = f"Not a valid embedding model. Exception thrown: {e}"
logger.error(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)
except Exception as e:
error_msg = "An error occurred while testing your embedding model. Please check your configuration."
logger.error(f"{error_msg} Error message: {e}", exc_info=True)
raise HTTPException(status_code=400, detail=error_msg)
@admin_router.get("/embedding-provider")
def list_embedding_providers(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[CloudEmbeddingProvider]:
return [
CloudEmbeddingProvider.from_request(embedding_provider_model)
for embedding_provider_model in fetch_existing_embedding_providers(db_session)
]
@admin_router.delete("/embedding-provider/{embedding_provider_name}")
def delete_embedding_provider(
embedding_provider_name: str,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
embedding_provider = get_current_db_embedding_provider(db_session=db_session)
if (
embedding_provider is not None
and embedding_provider_name == embedding_provider.name
):
raise HTTPException(
            status_code=400, detail="You can't delete the currently active embedding provider"
)
remove_embedding_provider(db_session, embedding_provider_name)
@admin_router.put("/embedding-provider")
def put_cloud_embedding_provider(
provider: CloudEmbeddingProviderCreationRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> CloudEmbeddingProvider:
return upsert_cloud_embedding_provider(db_session, provider)
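A hedged sketch of exercising these endpoints with the requests library; the base URL is a placeholder, any global API prefix is omitted, and a real deployment must also satisfy the current_admin_user auth dependency.

import requests

BASE = "http://localhost:8080"  # placeholder Danswer API server

# Register (or update) a cloud embedding provider...
requests.put(
    f"{BASE}/admin/embedding/embedding-provider",
    json={"name": "openai", "api_key": "sk-...", "default_model_id": None},
).raise_for_status()

# ...then verify the key can actually produce an embedding
requests.post(
    f"{BASE}/admin/embedding/test-embedding",
    json={"provider": "openai", "api_key": "sk-..."},
).raise_for_status()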

View File

@@ -0,0 +1,35 @@
from typing import TYPE_CHECKING
from pydantic import BaseModel
if TYPE_CHECKING:
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
class TestEmbeddingRequest(BaseModel):
provider: str
api_key: str | None = None
class CloudEmbeddingProvider(BaseModel):
name: str
api_key: str | None = None
default_model_id: int | None = None
id: int
@classmethod
def from_request(
cls, cloud_provider_model: "CloudEmbeddingProviderModel"
) -> "CloudEmbeddingProvider":
return cls(
id=cloud_provider_model.id,
name=cloud_provider_model.name,
api_key=cloud_provider_model.api_key,
default_model_id=cloud_provider_model.default_model_id,
)
class CloudEmbeddingProviderCreationRequest(BaseModel):
name: str
api_key: str | None = None
default_model_id: int | None = None

View File

@@ -4,6 +4,7 @@ from pydantic import BaseModel
from danswer.llm.llm_provider_options import fetch_models_for_provider
if TYPE_CHECKING:
from danswer.db.models import LLMProvider as LLMProviderModel

View File

@@ -11,6 +11,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.embedding_model import create_embedding_model
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_model_id_from_name
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.embedding_model import update_embedding_model_status
from danswer.db.engine import get_session
@@ -38,6 +39,19 @@ def set_new_embedding_model(
"""
current_model = get_current_db_embedding_model(db_session)
if embed_model_details.cloud_provider_name is not None:
cloud_id = get_model_id_from_name(
db_session, embed_model_details.cloud_provider_name
)
if cloud_id is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="No ID exists for given provider name",
)
embed_model_details.cloud_provider_id = cloud_id
if embed_model_details.model_name == current_model.model_name:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,

View File

@@ -1 +1,38 @@
from enum import Enum
from danswer.search.enums import EmbedTextType
MODEL_WARM_UP_STRING = "hi " * 512
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
DEFAULT_VERTEX_MODEL = "text-embedding-004"
class EmbeddingProvider(Enum):
OPENAI = "openai"
COHERE = "cohere"
VOYAGE = "voyage"
GOOGLE = "google"
class EmbeddingModelTextType:
PROVIDER_TEXT_TYPE_MAP = {
EmbeddingProvider.COHERE: {
EmbedTextType.QUERY: "search_query",
EmbedTextType.PASSAGE: "search_document",
},
EmbeddingProvider.VOYAGE: {
EmbedTextType.QUERY: "query",
EmbedTextType.PASSAGE: "document",
},
EmbeddingProvider.GOOGLE: {
EmbedTextType.QUERY: "RETRIEVAL_QUERY",
EmbedTextType.PASSAGE: "RETRIEVAL_DOCUMENT",
},
}
@staticmethod
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
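OpenAI is deliberately absent from the map: the embed path short-circuits for OpenAI before a provider-specific text type is resolved (see the model server changes below). A quick illustration of the lookup, assuming the module layout above:

from danswer.search.enums import EmbedTextType
from model_server.constants import EmbeddingModelTextType, EmbeddingProvider

# Cohere distinguishes queries from documents at embedding time
assert (
    EmbeddingModelTextType.get_type(EmbeddingProvider.COHERE, EmbedTextType.QUERY)
    == "search_query"
)
assert (
    EmbeddingModelTextType.get_type(EmbeddingProvider.VOYAGE, EmbedTextType.PASSAGE)
    == "document"
)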

View File

@@ -1,12 +1,28 @@
import gc
import json
from typing import Any
from typing import Optional
import openai
import vertexai # type: ignore
import voyageai # type: ignore
from cohere import Client as CohereClient
from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from vertexai.language_models import TextEmbeddingInput # type: ignore
from vertexai.language_models import TextEmbeddingModel # type: ignore
from danswer.search.enums import EmbedTextType
from danswer.utils.logger import setup_logger
from model_server.constants import DEFAULT_COHERE_MODEL
from model_server.constants import DEFAULT_OPENAI_MODEL
from model_server.constants import DEFAULT_VERTEX_MODEL
from model_server.constants import DEFAULT_VOYAGE_MODEL
from model_server.constants import EmbeddingModelTextType
from model_server.constants import EmbeddingProvider
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.utils import simple_log_function_time
from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE
@@ -17,6 +33,7 @@ from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
logger = setup_logger()
router = APIRouter(prefix="/encoder")
@@ -25,6 +42,117 @@ _GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
class CloudEmbedding:
def __init__(self, api_key: str, provider: str, model: str | None = None):
self.api_key = api_key
        # Only needed for Google, since it is used during client setup
self.model = model
try:
self.provider = EmbeddingProvider(provider.lower())
except ValueError:
raise ValueError(f"Unsupported provider: {provider}")
self.client = self._initialize_client()
def _initialize_client(self) -> Any:
if self.provider == EmbeddingProvider.OPENAI:
return openai.OpenAI(api_key=self.api_key)
elif self.provider == EmbeddingProvider.COHERE:
return CohereClient(api_key=self.api_key)
elif self.provider == EmbeddingProvider.VOYAGE:
return voyageai.Client(api_key=self.api_key)
elif self.provider == EmbeddingProvider.GOOGLE:
            credentials_info = json.loads(self.api_key)
            credentials = service_account.Credentials.from_service_account_info(
                credentials_info
            )
            vertexai.init(
                project=credentials_info["project_id"], credentials=credentials
            )
return TextEmbeddingModel.from_pretrained(
self.model or DEFAULT_VERTEX_MODEL
)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
def encode(
self, texts: list[str], model_name: str | None, text_type: EmbedTextType
) -> list[list[float]]:
return [
self.embed(text=text, text_type=text_type, model=model_name)
for text in texts
]
def embed(
self, *, text: str, text_type: EmbedTextType, model: str | None = None
) -> list[float]:
logger.debug(f"Embedding text with provider: {self.provider}")
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(text, model)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(text, model, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(text, model, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(text, model, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
def _embed_openai(self, text: str, model: str | None) -> list[float]:
if model is None:
model = DEFAULT_OPENAI_MODEL
response = self.client.embeddings.create(input=text, model=model)
return response.data[0].embedding
def _embed_cohere(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
if model is None:
model = DEFAULT_COHERE_MODEL
response = self.client.embed(
texts=[text],
model=model,
input_type=embedding_type,
)
return response.embeddings[0]
def _embed_voyage(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
if model is None:
model = DEFAULT_VOYAGE_MODEL
response = self.client.embed(text, model=model, input_type=embedding_type)
return response.embeddings[0]
def _embed_vertex(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
if model is None:
model = DEFAULT_VERTEX_MODEL
embedding = self.client.get_embeddings(
[
TextEmbeddingInput(
text,
embedding_type,
)
]
)
return embedding[0].values
@staticmethod
def create(
api_key: str, provider: str, model: str | None = None
) -> "CloudEmbedding":
logger.debug(f"Creating Embedding instance for provider: {provider}")
return CloudEmbedding(api_key, provider, model)
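# Hedged usage sketch (placeholder key; note that encode() makes one provider
# API call per input text, so large batches fan out into many requests):
#
#     embedder = CloudEmbedding.create(api_key="sk-...", provider="openai")
#     vectors = embedder.encode(
#         ["hello world"], model_name=None, text_type=EmbedTextType.QUERY
#     )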
def get_embedding_model(
model_name: str,
max_context_length: int,
@@ -78,18 +206,35 @@ def warm_up_cross_encoders() -> None:
@simple_log_function_time()
def embed_text(
texts: list[str],
model_name: str,
text_type: EmbedTextType,
model_name: str | None,
max_context_length: int,
normalize_embeddings: bool,
api_key: str | None,
provider_type: str | None,
) -> list[list[float]]:
model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)
if provider_type is not None:
if api_key is None:
raise RuntimeError("API key not provided for cloud model")
cloud_model = CloudEmbedding(
api_key=api_key, provider=provider_type, model=model_name
)
embeddings = cloud_model.encode(texts, model_name, text_type)
    elif model_name is not None:
        hosted_model = get_embedding_model(
            model_name=model_name, max_context_length=max_context_length
        )
        embeddings = hosted_model.encode(
            texts, normalize_embeddings=normalize_embeddings
        )
    else:
        # Neither a cloud provider nor a self-hosted model was specified
        raise RuntimeError("Either a provider_type or a model_name must be provided")
if embeddings is None:
raise RuntimeError("Embeddings were not created")
if not isinstance(embeddings, list):
embeddings = embeddings.tolist()
return embeddings
@@ -113,6 +258,9 @@ async def process_embed_request(
model_name=embed_request.model_name,
max_context_length=embed_request.max_context_length,
normalize_embeddings=embed_request.normalize_embeddings,
api_key=embed_request.api_key,
provider_type=embed_request.provider_type,
text_type=embed_request.text_type,
)
return EmbedResponse(embeddings=embeddings)
except Exception as e:

View File

@@ -7,3 +7,7 @@ tensorflow==2.15.0
torch==2.0.1
transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
openai==1.14.3
cohere==5.5.8
google-cloud-aiplatform==1.58.0

View File

@@ -1,12 +1,19 @@
from pydantic import BaseModel
from danswer.search.enums import EmbedTextType
class EmbedRequest(BaseModel):
    # These already include any prefixes; the texts are passed directly to the model
texts: list[str]
model_name: str
    # Can be None for cloud embedding model requests; error handling logic exists for other cases
model_name: str | None
max_context_length: int
normalize_embeddings: bool
api_key: str | None
provider_type: str | None
text_type: EmbedTextType
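# Hedged example of a cloud-provider request body for /encoder/bi-encoder-embed
# (values are illustrative placeholders):
#
#     EmbedRequest(
#         texts=["What is Danswer?"],
#         model_name=None,
#         max_context_length=512,
#         normalize_embeddings=False,
#         api_key="sk-...",
#         provider_type="openai",
#         text_type=EmbedTextType.QUERY,
#     )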
class EmbedResponse(BaseModel):