mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-28 13:53:28 +02:00
add third party embedding models (#1818)
This commit is contained in:
@@ -0,0 +1,65 @@
|
||||
"""add cloud embedding model and update embedding_model
|
||||
|
||||
Revision ID: 44f856ae2a4a
|
||||
Revises: d716b0791ddd
|
||||
Create Date: 2024-06-28 20:01:05.927647
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# Revision identifiers consumed by Alembic's migration graph.
revision = "44f856ae2a4a"  # this migration
down_revision = "d716b0791ddd"  # parent revision
branch_labels: None = None  # no named branches
depends_on: None = None  # no cross-branch dependencies
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the embedding_provider table and link embedding_model to it."""
    # Table holding third-party (cloud) embedding providers.
    op.create_table(
        "embedding_provider",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        # LargeBinary: the ORM maps this through an encrypted-string type,
        # so ciphertext (bytes) is what lands in the column.
        sa.Column("api_key", sa.LargeBinary(), nullable=True),
        sa.Column("default_model_id", sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("name"),
    )
    # Nullable FK column: locally hosted models have no cloud provider.
    op.add_column(
        "embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
    )
    # Wire up both directions of the provider <-> model relationship.
    op.create_foreign_key(
        "fk_embedding_model_cloud_provider",
        "embedding_model",
        "embedding_provider",
        ["cloud_provider_id"],
        ["id"],
    )
    op.create_foreign_key(
        "fk_embedding_provider_default_model",
        "embedding_provider",
        "embedding_model",
        ["default_model_id"],
        ["id"],
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Reverse upgrade(): drop the FKs, the new column, then the table."""
    # Constraints must go first or the column/table drops would fail.
    op.drop_constraint(
        "fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
    )
    op.drop_constraint(
        "fk_embedding_provider_default_model", "embedding_provider", type_="foreignkey"
    )
    # Remove the provider link from embedding_model.
    op.drop_column("embedding_model", "cloud_provider_id")
    # Finally drop the provider table itself.
    op.drop_table("embedding_provider")
|
@@ -10,8 +10,8 @@ from alembic import op
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d716b0791ddd"
|
||||
down_revision = "7aea705850d5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@@ -98,7 +98,6 @@ def _run_indexing(
|
||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
db_embedding_model = index_attempt.embedding_model
|
||||
index_name = db_embedding_model.index_name
|
||||
|
||||
@@ -116,6 +115,8 @@ def _run_indexing(
|
||||
normalize=db_embedding_model.normalize,
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
api_key=db_embedding_model.api_key,
|
||||
provider_type=db_embedding_model.provider_type,
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
@@ -287,6 +288,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
)
|
||||
|
||||
if attempt is None:
|
||||
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")
|
||||
|
||||
|
@@ -343,13 +343,15 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
logger.info("Running a first inference to warm up embedding model")
|
||||
warm_up_encoders(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
model_server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
if db_embedding_model.cloud_provider_id is None:
|
||||
logger.info("Running a first inference to warm up embedding model")
|
||||
warm_up_encoders(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
model_server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
client_primary: Client | SimpleJobClient
|
||||
client_secondary: Client | SimpleJobClient
|
||||
|
@@ -469,13 +469,13 @@ if __name__ == "__main__":
|
||||
# or the tokens have updated (set up for the first time)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
embedding_model = get_current_db_embedding_model(db_session)
|
||||
|
||||
warm_up_encoders(
|
||||
model_name=embedding_model.model_name,
|
||||
normalize=embedding_model.normalize,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
if embedding_model.cloud_provider_id is None:
|
||||
warm_up_encoders(
|
||||
model_name=embedding_model.model_name,
|
||||
normalize=embedding_model.normalize,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
slack_bot_tokens = latest_slack_bot_tokens
|
||||
# potentially may cause a message to be dropped, but it is complicated
|
||||
|
@@ -10,10 +10,15 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
from danswer.db.llm import fetch_embedding_provider
|
||||
from danswer.db.models import CloudEmbeddingProvider
|
||||
from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.indexing.models import EmbeddingModelDetail
|
||||
from danswer.search.search_nlp_models import clean_model_name
|
||||
from danswer.server.manage.embedding.models import (
|
||||
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
|
||||
)
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -31,6 +36,7 @@ def create_embedding_model(
|
||||
query_prefix=model_details.query_prefix,
|
||||
passage_prefix=model_details.passage_prefix,
|
||||
status=status,
|
||||
cloud_provider_id=model_details.cloud_provider_id,
|
||||
# Every single embedding model except the initial one from migrations has this name
|
||||
# The initial one from migration is called "danswer_chunk"
|
||||
index_name=f"danswer_chunk_{clean_model_name(model_details.model_name)}",
|
||||
@@ -42,6 +48,42 @@ def create_embedding_model(
|
||||
return embedding_model
|
||||
|
||||
|
||||
def get_model_id_from_name(
    db_session: Session, embedding_provider_name: str
) -> int | None:
    """Look up a cloud embedding provider by name and return its primary key.

    Returns None when no provider with that name exists.
    """
    provider = db_session.scalar(
        select(CloudEmbeddingProvider).where(
            CloudEmbeddingProvider.name == embedding_provider_name
        )
    )
    return None if provider is None else provider.id
|
||||
|
||||
|
||||
def get_current_db_embedding_provider(
    db_session: Session,
) -> ServerCloudEmbeddingProvider | None:
    """Return the cloud provider backing the active embedding model, if any.

    Returns None when the active model is self-hosted (no cloud provider id).
    Raises RuntimeError if the model references a provider row that no
    longer exists.
    """
    model_detail = EmbeddingModelDetail.from_model(
        get_current_db_embedding_model(db_session=db_session)
    )

    # NOTE(review): from_model always returns an instance, so only the
    # provider-id check is effective here; the None check is kept for parity.
    if model_detail is None or model_detail.cloud_provider_id is None:
        return None

    provider_model = fetch_embedding_provider(
        db_session=db_session, provider_id=model_detail.cloud_provider_id
    )
    if provider_model is None:
        raise RuntimeError("No embedding provider exists for this model.")

    return ServerCloudEmbeddingProvider.from_request(
        cloud_provider_model=provider_model
    )
|
||||
|
||||
|
||||
def get_current_db_embedding_model(db_session: Session) -> EmbeddingModel:
|
||||
query = (
|
||||
select(EmbeddingModel)
|
||||
|
@@ -2,11 +2,34 @@ from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from danswer.server.manage.llm.models import FullLLMProvider
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
|
||||
|
||||
def upsert_cloud_embedding_provider(
    db_session: Session, provider: CloudEmbeddingProviderCreationRequest
) -> CloudEmbeddingProvider:
    """Create or update a cloud embedding provider, keyed by its unique name."""
    provider_data = provider.dict()
    row = (
        db_session.query(CloudEmbeddingProviderModel)
        .filter_by(name=provider.name)
        .first()
    )
    if row is None:
        # No provider with this name yet -- insert a fresh row.
        row = CloudEmbeddingProviderModel(**provider_data)
        db_session.add(row)
    else:
        # Overwrite every request field on the existing row.
        for field_name, field_value in provider_data.items():
            setattr(row, field_name, field_value)
    db_session.commit()
    db_session.refresh(row)
    return CloudEmbeddingProvider.from_request(row)
|
||||
|
||||
|
||||
def upsert_llm_provider(
|
||||
db_session: Session, llm_provider: LLMProviderUpsertRequest
|
||||
) -> FullLLMProvider:
|
||||
@@ -26,7 +49,6 @@ def upsert_llm_provider(
|
||||
existing_llm_provider.model_names = llm_provider.model_names
|
||||
db_session.commit()
|
||||
return FullLLMProvider.from_model(existing_llm_provider)
|
||||
|
||||
# if it does not exist, create a new entry
|
||||
llm_provider_model = LLMProviderModel(
|
||||
name=llm_provider.name,
|
||||
@@ -46,10 +68,26 @@ def upsert_llm_provider(
|
||||
return FullLLMProvider.from_model(llm_provider_model)
|
||||
|
||||
|
||||
def fetch_existing_embedding_providers(
    db_session: Session,
) -> list[CloudEmbeddingProviderModel]:
    """Return every configured cloud embedding provider row."""
    providers = db_session.scalars(select(CloudEmbeddingProviderModel))
    return list(providers)
|
||||
|
||||
|
||||
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
    """Return every configured LLM provider row."""
    providers = db_session.scalars(select(LLMProviderModel))
    return list(providers)
|
||||
|
||||
|
||||
def fetch_embedding_provider(
    db_session: Session, provider_id: int
) -> CloudEmbeddingProviderModel | None:
    """Fetch a cloud embedding provider by primary key, or None if absent."""
    stmt = select(CloudEmbeddingProviderModel).where(
        CloudEmbeddingProviderModel.id == provider_id
    )
    return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
@@ -70,6 +108,16 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
|
||||
return FullLLMProvider.from_model(provider_model)
|
||||
|
||||
|
||||
def remove_embedding_provider(
    db_session: Session, embedding_provider_name: str
) -> None:
    """Delete the provider with the given name; a no-op if none matches.

    NOTE(review): does not commit -- the caller owns the transaction.
    """
    stmt = delete(CloudEmbeddingProviderModel).where(
        CloudEmbeddingProviderModel.name == embedding_provider_name
    )
    db_session.execute(stmt)
|
||||
|
||||
|
||||
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
||||
db_session.execute(
|
||||
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
|
||||
|
@@ -130,6 +130,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
chat_folders: Mapped[list["ChatFolder"]] = relationship(
|
||||
"ChatFolder", back_populates="user"
|
||||
)
|
||||
|
||||
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
|
||||
# Personas owned by this user
|
||||
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
|
||||
@@ -469,7 +470,7 @@ class Credential(Base):
|
||||
|
||||
class EmbeddingModel(Base):
|
||||
__tablename__ = "embedding_model"
|
||||
# ID is used also to indicate the order that the models are configured by the admin
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
model_name: Mapped[str] = mapped_column(String)
|
||||
model_dim: Mapped[int] = mapped_column(Integer)
|
||||
@@ -481,6 +482,16 @@ class EmbeddingModel(Base):
|
||||
)
|
||||
index_name: Mapped[str] = mapped_column(String)
|
||||
|
||||
# New field for cloud provider relationship
|
||||
cloud_provider_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("embedding_provider.id")
|
||||
)
|
||||
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
|
||||
"CloudEmbeddingProvider",
|
||||
back_populates="embedding_models",
|
||||
foreign_keys=[cloud_provider_id],
|
||||
)
|
||||
|
||||
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
|
||||
"IndexAttempt", back_populates="embedding_model"
|
||||
)
|
||||
@@ -500,6 +511,18 @@ class EmbeddingModel(Base):
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
    """Debug representation showing model name, status, and provider.

    NOTE(review): the original built this with a backslash continuation
    inside the f-string literal, which splices the continuation line's
    leading whitespace into the output; implicit string concatenation
    produces the intended single-space separator instead.
    """
    provider_name = self.cloud_provider.name if self.cloud_provider else "None"
    return (
        f"<EmbeddingModel(model_name='{self.model_name}', "
        f"status='{self.status}', cloud_provider='{provider_name}')>"
    )
|
||||
|
||||
@property
def api_key(self) -> str | None:
    """API key inherited from the linked cloud provider; None if self-hosted."""
    if self.cloud_provider is None:
        return None
    return self.cloud_provider.api_key
|
||||
|
||||
@property
def provider_type(self) -> str | None:
    """Name of the linked cloud provider; None for self-hosted models."""
    if self.cloud_provider is None:
        return None
    return self.cloud_provider.name
|
||||
|
||||
|
||||
class IndexAttempt(Base):
|
||||
"""
|
||||
@@ -519,6 +542,7 @@ class IndexAttempt(Base):
|
||||
ForeignKey("credential.id"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Some index attempts that run from beginning will still have this as False
|
||||
# This is only for attempts that are explicitly marked as from the start via
|
||||
# the run once API
|
||||
@@ -879,11 +903,6 @@ class ChatMessageFeedback(Base):
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
Structures, Organizational, Configurations Tables
|
||||
"""
|
||||
|
||||
|
||||
class LLMProvider(Base):
|
||||
__tablename__ = "llm_provider"
|
||||
|
||||
@@ -912,6 +931,29 @@ class LLMProvider(Base):
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(Base):
    """A third-party embedding service (e.g. OpenAI, Cohere) and its API key."""

    __tablename__ = "embedding_provider"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Provider names are unique and double as the human-facing identifier.
    name: Mapped[str] = mapped_column(String, unique=True)
    # EncryptedString suggests the key is encrypted at rest -- see type definition.
    api_key: Mapped[str | None] = mapped_column(EncryptedString())
    # Optional pointer to this provider's preferred embedding model.
    default_model_id: Mapped[int | None] = mapped_column(
        Integer, ForeignKey("embedding_model.id"), nullable=True
    )

    # All models served by this provider (FK lives on embedding_model).
    embedding_models: Mapped[list["EmbeddingModel"]] = relationship(
        "EmbeddingModel",
        back_populates="cloud_provider",
        foreign_keys="EmbeddingModel.cloud_provider_id",
    )
    # Separate FK path from embedding_models, hence the explicit foreign_keys.
    default_model: Mapped["EmbeddingModel"] = relationship(
        "EmbeddingModel", foreign_keys=[default_model_id]
    )

    def __repr__(self) -> str:
        return f"<EmbeddingProvider(name='{self.name}')>"
|
||||
|
||||
|
||||
class DocumentSet(Base):
|
||||
__tablename__ = "document_set"
|
||||
|
||||
@@ -1194,6 +1236,7 @@ class SlackBotConfig(Base):
|
||||
response_type: Mapped[SlackBotResponseType] = mapped_column(
|
||||
Enum(SlackBotResponseType, native_enum=False), nullable=False
|
||||
)
|
||||
|
||||
enable_auto_filters: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
|
@@ -50,6 +50,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
normalize: bool,
|
||||
query_prefix: str | None,
|
||||
passage_prefix: str | None,
|
||||
api_key: str | None = None,
|
||||
provider_type: str | None = None,
|
||||
):
|
||||
super().__init__(model_name, normalize, query_prefix, passage_prefix)
|
||||
self.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE # Currently not customizable
|
||||
@@ -59,6 +61,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
query_prefix=query_prefix,
|
||||
passage_prefix=passage_prefix,
|
||||
normalize=normalize,
|
||||
api_key=api_key,
|
||||
provider_type=provider_type,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
|
@@ -97,13 +97,19 @@ class EmbeddingModelDetail(BaseModel):
|
||||
normalize: bool
|
||||
query_prefix: str | None
|
||||
passage_prefix: str | None
|
||||
cloud_provider_id: int | None = None
|
||||
cloud_provider_name: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, embedding_model: "EmbeddingModel") -> "EmbeddingModelDetail":
|
||||
def from_model(
|
||||
cls,
|
||||
embedding_model: "EmbeddingModel",
|
||||
) -> "EmbeddingModelDetail":
|
||||
return cls(
|
||||
model_name=embedding_model.model_name,
|
||||
model_dim=embedding_model.model_dim,
|
||||
normalize=embedding_model.normalize,
|
||||
query_prefix=embedding_model.query_prefix,
|
||||
passage_prefix=embedding_model.passage_prefix,
|
||||
cloud_provider_id=embedding_model.cloud_provider_id,
|
||||
)
|
||||
|
@@ -67,6 +67,8 @@ from danswer.server.features.tool.api import admin_router as admin_tool_router
|
||||
from danswer.server.features.tool.api import router as tool_router
|
||||
from danswer.server.gpts.api import router as gpts_router
|
||||
from danswer.server.manage.administrative import router as admin_router
|
||||
from danswer.server.manage.embedding.api import admin_router as embedding_admin_router
|
||||
from danswer.server.manage.embedding.api import basic_router as embedding_router
|
||||
from danswer.server.manage.get_state import router as state_router
|
||||
from danswer.server.manage.llm.api import admin_router as llm_admin_router
|
||||
from danswer.server.manage.llm.api import basic_router as llm_router
|
||||
@@ -247,12 +249,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
time.sleep(wait_time)
|
||||
|
||||
logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
|
||||
warm_up_encoders(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
if db_embedding_model.cloud_provider_id is None:
|
||||
warm_up_encoders(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
|
||||
yield
|
||||
@@ -291,6 +294,8 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, settings_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, llm_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, llm_router)
|
||||
include_router_with_global_prefix_prepended(application, embedding_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, embedding_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, token_rate_limit_settings_router
|
||||
)
|
||||
|
@@ -168,6 +168,7 @@ def stream_answer_objects(
|
||||
max_tokens=max_document_tokens,
|
||||
use_sections=query_req.chunks_above > 0 or query_req.chunks_below > 0,
|
||||
)
|
||||
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
|
@@ -131,6 +131,8 @@ def doc_index_retrieval(
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
normalize=db_embedding_model.normalize,
|
||||
api_key=db_embedding_model.api_key,
|
||||
provider_type=db_embedding_model.provider_type,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
|
@@ -84,20 +84,24 @@ def build_model_server_url(
|
||||
class EmbeddingModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
query_prefix: str | None,
|
||||
passage_prefix: str | None,
|
||||
normalize: bool,
|
||||
server_host: str, # Changes depending on indexing or inference
|
||||
server_port: int,
|
||||
model_name: str | None,
|
||||
normalize: bool,
|
||||
query_prefix: str | None,
|
||||
passage_prefix: str | None,
|
||||
api_key: str | None,
|
||||
provider_type: str | None,
|
||||
# The following are globals are currently not configurable
|
||||
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self.api_key = api_key
|
||||
self.provider_type = provider_type
|
||||
self.max_seq_length = max_seq_length
|
||||
self.query_prefix = query_prefix
|
||||
self.passage_prefix = passage_prefix
|
||||
self.normalize = normalize
|
||||
self.model_name = model_name
|
||||
|
||||
model_server_url = build_model_server_url(server_host, server_port)
|
||||
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
|
||||
@@ -111,10 +115,13 @@ class EmbeddingModel:
|
||||
prefixed_texts = texts
|
||||
|
||||
embed_request = EmbedRequest(
|
||||
texts=prefixed_texts,
|
||||
model_name=self.model_name,
|
||||
texts=prefixed_texts,
|
||||
max_context_length=self.max_seq_length,
|
||||
normalize_embeddings=self.normalize,
|
||||
api_key=self.api_key,
|
||||
provider_type=self.provider_type,
|
||||
text_type=text_type,
|
||||
)
|
||||
|
||||
response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
|
||||
@@ -187,6 +194,8 @@ def warm_up_encoders(
|
||||
passage_prefix=None,
|
||||
server_host=model_server_host,
|
||||
server_port=model_server_port,
|
||||
api_key=None,
|
||||
provider_type=None,
|
||||
)
|
||||
|
||||
# First time downloading the models it may take even longer, but just in case,
|
||||
|
93
backend/danswer/server/manage/embedding/api.py
Normal file
93
backend/danswer/server/manage/embedding/api.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.db.embedding_model import get_current_db_embedding_provider
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.llm import fetch_existing_embedding_providers
|
||||
from danswer.db.llm import remove_embedding_provider
|
||||
from danswer.db.llm import upsert_cloud_embedding_provider
|
||||
from danswer.db.models import User
|
||||
from danswer.search.enums import EmbedTextType
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from danswer.server.manage.embedding.models import TestEmbeddingRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# Routers mounted by the application factory; admin endpoints each depend on
# current_admin_user individually.
admin_router = APIRouter(prefix="/admin/embedding")
basic_router = APIRouter(prefix="/embedding")
|
||||
|
||||
|
||||
@admin_router.post("/test-embedding")
def test_embedding_configuration(
    test_llm_request: TestEmbeddingRequest,
    _: User | None = Depends(current_admin_user),
) -> None:
    """Smoke-test an embedding provider configuration.

    Builds a throwaway EmbeddingModel pointed at the given provider/API key
    and encodes a single query string. Raises HTTPException(400) when the
    configuration is rejected or the test call fails.
    """
    try:
        test_model = EmbeddingModel(
            server_host=MODEL_SERVER_HOST,
            server_port=MODEL_SERVER_PORT,
            api_key=test_llm_request.api_key,
            provider_type=test_llm_request.provider,
            normalize=False,
            query_prefix=None,
            passage_prefix=None,
            model_name=None,
        )
        test_model.encode(["Test String"], text_type=EmbedTextType.QUERY)
    except ValueError as e:
        error_msg = f"Not a valid embedding model. Exception thrown: {e}"
        logger.error(error_msg)
        # Previously re-raised as a bare ValueError, which FastAPI surfaces
        # as an opaque 500; a 400 with detail is the intended client contract.
        raise HTTPException(status_code=400, detail=error_msg) from e
    except Exception as e:
        error_msg = "An error occurred while testing your embedding model. Please check your configuration."
        logger.error(f"{error_msg} Error message: {e}", exc_info=True)
        raise HTTPException(status_code=400, detail=error_msg) from e
|
||||
|
||||
|
||||
@admin_router.get("/embedding-provider")
def list_embedding_providers(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[CloudEmbeddingProvider]:
    """List every configured cloud embedding provider (admin only)."""
    provider_rows = fetch_existing_embedding_providers(db_session)
    return [CloudEmbeddingProvider.from_request(row) for row in provider_rows]
|
||||
|
||||
|
||||
@admin_router.delete("/embedding-provider/{embedding_provider_name}")
def delete_embedding_provider(
    embedding_provider_name: str,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Delete a cloud embedding provider unless it backs the active model."""
    active_provider = get_current_db_embedding_provider(db_session=db_session)
    provider_is_active = (
        active_provider is not None
        and active_provider.name == embedding_provider_name
    )
    if provider_is_active:
        raise HTTPException(
            status_code=400, detail="You can't delete a currently active model"
        )

    remove_embedding_provider(db_session, embedding_provider_name)
|
||||
|
||||
|
||||
@admin_router.put("/embedding-provider")
def put_cloud_embedding_provider(
    provider: CloudEmbeddingProviderCreationRequest,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> CloudEmbeddingProvider:
    """Create or update a cloud embedding provider (admin only)."""
    upserted_provider = upsert_cloud_embedding_provider(db_session, provider)
    return upserted_provider
|
35
backend/danswer/server/manage/embedding/models.py
Normal file
35
backend/danswer/server/manage/embedding/models.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
|
||||
|
||||
class TestEmbeddingRequest(BaseModel):
    """Payload for the admin endpoint that smoke-tests a provider config."""

    provider: str  # provider identifier, e.g. "openai"
    api_key: str | None = None  # omitted for providers that need no key
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(BaseModel):
    """API-facing view of a stored cloud embedding provider row."""

    # Field order preserved: it determines serialization order.
    name: str
    api_key: str | None = None
    default_model_id: int | None = None
    id: int

    @classmethod
    def from_request(
        cls, cloud_provider_model: "CloudEmbeddingProviderModel"
    ) -> "CloudEmbeddingProvider":
        """Build the API model from the corresponding ORM row."""
        return cls(
            id=cloud_provider_model.id,
            name=cloud_provider_model.name,
            api_key=cloud_provider_model.api_key,
            default_model_id=cloud_provider_model.default_model_id,
        )
|
||||
|
||||
|
||||
class CloudEmbeddingProviderCreationRequest(BaseModel):
    """Request body for creating or updating a cloud embedding provider."""

    name: str  # unique provider name
    api_key: str | None = None
    default_model_id: int | None = None
|
@@ -4,6 +4,7 @@ from pydantic import BaseModel
|
||||
|
||||
from danswer.llm.llm_provider_options import fetch_models_for_provider
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||
|
||||
|
@@ -11,6 +11,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import resync_cc_pair
|
||||
from danswer.db.embedding_model import create_embedding_model
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.embedding_model import get_model_id_from_name
|
||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||
from danswer.db.embedding_model import update_embedding_model_status
|
||||
from danswer.db.engine import get_session
|
||||
@@ -38,6 +39,19 @@ def set_new_embedding_model(
|
||||
"""
|
||||
current_model = get_current_db_embedding_model(db_session)
|
||||
|
||||
if embed_model_details.cloud_provider_name is not None:
|
||||
cloud_id = get_model_id_from_name(
|
||||
db_session, embed_model_details.cloud_provider_name
|
||||
)
|
||||
|
||||
if cloud_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="No ID exists for given provider name",
|
||||
)
|
||||
|
||||
embed_model_details.cloud_provider_id = cloud_id
|
||||
|
||||
if embed_model_details.model_name == current_model.model_name:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
@@ -1 +1,38 @@
|
||||
from enum import Enum
|
||||
|
||||
from danswer.search.enums import EmbedTextType
|
||||
|
||||
|
||||
# Long throwaway input used to force model weights to load on first call.
MODEL_WARM_UP_STRING = "hi " * 512

# Fallback model per cloud provider when the caller does not name one.
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
DEFAULT_VERTEX_MODEL = "text-embedding-004"
|
||||
|
||||
|
||||
class EmbeddingProvider(Enum):
    """Supported third-party embedding backends."""

    OPENAI = "openai"
    COHERE = "cohere"
    VOYAGE = "voyage"
    GOOGLE = "google"
|
||||
|
||||
|
||||
class EmbeddingModelTextType:
    """Maps (provider, generic text type) to the provider's own label.

    Cloud APIs distinguish query vs. passage embeddings with provider-specific
    strings; OpenAI takes no such hint and is absent from the map.
    """

    PROVIDER_TEXT_TYPE_MAP = {
        EmbeddingProvider.COHERE: {
            EmbedTextType.QUERY: "search_query",
            EmbedTextType.PASSAGE: "search_document",
        },
        EmbeddingProvider.VOYAGE: {
            EmbedTextType.QUERY: "query",
            EmbedTextType.PASSAGE: "document",
        },
        EmbeddingProvider.GOOGLE: {
            EmbedTextType.QUERY: "RETRIEVAL_QUERY",
            EmbedTextType.PASSAGE: "RETRIEVAL_DOCUMENT",
        },
    }

    @staticmethod
    def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
        """Return the provider-specific label; raises KeyError if unmapped."""
        return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
|
||||
|
@@ -1,12 +1,28 @@
|
||||
import gc
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
import openai
|
||||
import vertexai # type: ignore
|
||||
import voyageai # type: ignore
|
||||
from cohere import Client as CohereClient
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from google.oauth2 import service_account
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
from vertexai.language_models import TextEmbeddingInput # type: ignore
|
||||
from vertexai.language_models import TextEmbeddingModel # type: ignore
|
||||
|
||||
from danswer.search.enums import EmbedTextType
|
||||
from danswer.utils.logger import setup_logger
|
||||
from model_server.constants import DEFAULT_COHERE_MODEL
|
||||
from model_server.constants import DEFAULT_OPENAI_MODEL
|
||||
from model_server.constants import DEFAULT_VERTEX_MODEL
|
||||
from model_server.constants import DEFAULT_VOYAGE_MODEL
|
||||
from model_server.constants import EmbeddingModelTextType
|
||||
from model_server.constants import EmbeddingProvider
|
||||
from model_server.constants import MODEL_WARM_UP_STRING
|
||||
from model_server.utils import simple_log_function_time
|
||||
from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE
|
||||
@@ -17,6 +33,7 @@ from shared_configs.model_server_models import EmbedResponse
|
||||
from shared_configs.model_server_models import RerankRequest
|
||||
from shared_configs.model_server_models import RerankResponse
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/encoder")
|
||||
@@ -25,6 +42,117 @@ _GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
|
||||
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
|
||||
|
||||
|
||||
class CloudEmbedding:
    """Thin wrapper around third-party embedding APIs (OpenAI, Cohere,
    Voyage, Google Vertex AI).

    Normalizes the different client libraries behind a single
    ``encode`` / ``embed`` interface keyed off the ``EmbeddingProvider`` enum.
    """

    def __init__(self, api_key: str, provider: str, model: str | None = None) -> None:
        self.api_key = api_key

        # Only needed for Google (Vertex AI), where the model must be known
        # at client setup time; other providers pass the model per request.
        self.model = model
        try:
            self.provider = EmbeddingProvider(provider.lower())
        except ValueError as e:
            # Chain the original error so the bad enum value isn't lost.
            raise ValueError(f"Unsupported provider: {provider}") from e
        self.client = self._initialize_client()

    def _initialize_client(self) -> Any:
        """Build and return the provider-specific API client."""
        if self.provider == EmbeddingProvider.OPENAI:
            return openai.OpenAI(api_key=self.api_key)
        elif self.provider == EmbeddingProvider.COHERE:
            return CohereClient(api_key=self.api_key)
        elif self.provider == EmbeddingProvider.VOYAGE:
            return voyageai.Client(api_key=self.api_key)
        elif self.provider == EmbeddingProvider.GOOGLE:
            # For Vertex AI the "api_key" is a service-account JSON blob;
            # parse it once and reuse for both credentials and project id.
            service_account_info = json.loads(self.api_key)
            credentials = service_account.Credentials.from_service_account_info(
                service_account_info
            )
            project_id = service_account_info["project_id"]
            vertexai.init(project=project_id, credentials=credentials)
            return TextEmbeddingModel.from_pretrained(
                self.model or DEFAULT_VERTEX_MODEL
            )

        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    def encode(
        self, texts: list[str], model_name: str | None, text_type: EmbedTextType
    ) -> list[list[float]]:
        """Embed each text individually; returns one vector per input text.

        NOTE(review): this issues one API call per text — presumably fine for
        current batch sizes, but batching would cut request volume.
        """
        return [
            self.embed(text=text, text_type=text_type, model=model_name)
            for text in texts
        ]

    def embed(
        self, *, text: str, text_type: EmbedTextType, model: str | None = None
    ) -> list[float]:
        """Embed a single text, dispatching on the configured provider."""
        logger.debug(f"Embedding text with provider: {self.provider}")
        if self.provider == EmbeddingProvider.OPENAI:
            # OpenAI's API takes no query/passage distinction, so text_type
            # is intentionally ignored on this path.
            return self._embed_openai(text, model)

        embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)

        if self.provider == EmbeddingProvider.COHERE:
            return self._embed_cohere(text, model, embedding_type)
        elif self.provider == EmbeddingProvider.VOYAGE:
            return self._embed_voyage(text, model, embedding_type)
        elif self.provider == EmbeddingProvider.GOOGLE:
            return self._embed_vertex(text, model, embedding_type)
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    def _embed_openai(self, text: str, model: str | None) -> list[float]:
        if model is None:
            model = DEFAULT_OPENAI_MODEL

        response = self.client.embeddings.create(input=text, model=model)
        return response.data[0].embedding

    def _embed_cohere(
        self, text: str, model: str | None, embedding_type: str
    ) -> list[float]:
        if model is None:
            model = DEFAULT_COHERE_MODEL

        response = self.client.embed(
            texts=[text],
            model=model,
            input_type=embedding_type,
        )
        return response.embeddings[0]

    def _embed_voyage(
        self, text: str, model: str | None, embedding_type: str
    ) -> list[float]:
        if model is None:
            model = DEFAULT_VOYAGE_MODEL

        response = self.client.embed(text, model=model, input_type=embedding_type)
        return response.embeddings[0]

    def _embed_vertex(
        self, text: str, model: str | None, embedding_type: str
    ) -> list[float]:
        if model is None:
            model = DEFAULT_VERTEX_MODEL

        embedding = self.client.get_embeddings(
            [
                TextEmbeddingInput(
                    text,
                    embedding_type,
                )
            ]
        )
        return embedding[0].values

    @staticmethod
    def create(
        api_key: str, provider: str, model: str | None = None
    ) -> "CloudEmbedding":
        """Factory alias for the constructor, with a debug log line."""
        logger.debug(f"Creating Embedding instance for provider: {provider}")
        return CloudEmbedding(api_key, provider, model)
|
||||
|
||||
|
||||
def get_embedding_model(
|
||||
model_name: str,
|
||||
max_context_length: int,
|
||||
@@ -78,18 +206,35 @@ def warm_up_cross_encoders() -> None:
|
||||
@simple_log_function_time()
def embed_text(
    texts: list[str],
    text_type: EmbedTextType,
    model_name: str | None,
    max_context_length: int,
    normalize_embeddings: bool,
    api_key: str | None,
    provider_type: str | None,
) -> list[list[float]]:
    """Embed texts via a cloud provider (when provider_type is set) or a
    locally hosted SentenceTransformer model (when model_name is set).

    Raises:
        RuntimeError: if the cloud path is selected without an API key, or
            if neither path produced embeddings.
    """
    # Initialize so that a call with neither provider_type nor model_name
    # fails with the clear RuntimeError below instead of UnboundLocalError.
    embeddings = None

    if provider_type is not None:
        if api_key is None:
            raise RuntimeError("API key not provided for cloud model")

        cloud_model = CloudEmbedding(
            api_key=api_key, provider=provider_type, model=model_name
        )
        embeddings = cloud_model.encode(texts, model_name, text_type)

    elif model_name is not None:
        hosted_model = get_embedding_model(
            model_name=model_name, max_context_length=max_context_length
        )
        embeddings = hosted_model.encode(
            texts, normalize_embeddings=normalize_embeddings
        )

    if embeddings is None:
        raise RuntimeError("Embeddings were not created")

    if not isinstance(embeddings, list):
        # SentenceTransformer returns a numpy array; normalize to plain lists.
        embeddings = embeddings.tolist()

    return embeddings
|
||||
|
||||
|
||||
@@ -113,6 +258,9 @@ async def process_embed_request(
|
||||
model_name=embed_request.model_name,
|
||||
max_context_length=embed_request.max_context_length,
|
||||
normalize_embeddings=embed_request.normalize_embeddings,
|
||||
api_key=embed_request.api_key,
|
||||
provider_type=embed_request.provider_type,
|
||||
text_type=embed_request.text_type,
|
||||
)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
except Exception as e:
|
||||
|
@@ -7,3 +7,7 @@ tensorflow==2.15.0
|
||||
torch==2.0.1
|
||||
transformers==4.39.2
|
||||
uvicorn==0.21.1
|
||||
voyageai==0.2.3
|
||||
openai==1.14.3
|
||||
cohere==5.5.8
|
||||
google-cloud-aiplatform==1.58.0
|
@@ -1,12 +1,19 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.search.enums import EmbedTextType
|
||||
|
||||
|
||||
class EmbedRequest(BaseModel):
    """Request payload for the model server's embedding endpoint."""

    # This already includes any prefixes, the text is just passed directly to the model
    texts: list[str]
    # Can be none for cloud embedding model requests, error handling logic exists for other cases
    model_name: str | None
    max_context_length: int
    normalize_embeddings: bool
    api_key: str | None
    provider_type: str | None
    text_type: EmbedTextType
|
||||
|
||||
|
||||
class EmbedResponse(BaseModel):
|
||||
|
Reference in New Issue
Block a user