mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-09 12:30:49 +02:00
Update embedding interface (#2205)
* squash * simplify interface * some updates to typing * cloud provider type * update typing to be even clearer * push local commit (squash) * cleaner interfaces * another quick pass * squash * cleaner alembic * cleaner * remove trailing whitespace * add sequence * quick circle back to double check * update * update naming * update naming
This commit is contained in:
parent
7da6d33451
commit
e89dc67e5d
@ -0,0 +1,163 @@
|
||||
"""embedding provider by provider type
|
||||
|
||||
Revision ID: f17bf3b0d9f1
|
||||
Revises: ee3f4b47fad5
|
||||
Create Date: 2024-08-21 13:13:31.120460
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f17bf3b0d9f1"
|
||||
down_revision = "ee3f4b47fad5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add provider_type column to embedding_provider
|
||||
op.add_column(
|
||||
"embedding_provider",
|
||||
sa.Column("provider_type", sa.String(50), nullable=True),
|
||||
)
|
||||
|
||||
# Update provider_type with existing name values
|
||||
op.execute("UPDATE embedding_provider SET provider_type = UPPER(name)")
|
||||
|
||||
# Make provider_type not nullable
|
||||
op.alter_column("embedding_provider", "provider_type", nullable=False)
|
||||
|
||||
# Drop the foreign key constraint in embedding_model table
|
||||
op.drop_constraint(
|
||||
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Drop the existing primary key constraint
|
||||
op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
|
||||
|
||||
# Create a new primary key constraint on provider_type
|
||||
op.create_primary_key(
|
||||
"embedding_provider_pkey", "embedding_provider", ["provider_type"]
|
||||
)
|
||||
|
||||
# Add provider_type column to embedding_model
|
||||
op.add_column(
|
||||
"embedding_model",
|
||||
sa.Column("provider_type", sa.String(50), nullable=True),
|
||||
)
|
||||
|
||||
# Update provider_type for existing embedding models
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_model
|
||||
SET provider_type = (
|
||||
SELECT provider_type
|
||||
FROM embedding_provider
|
||||
WHERE embedding_provider.id = embedding_model.cloud_provider_id
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the old id column from embedding_provider
|
||||
op.drop_column("embedding_provider", "id")
|
||||
|
||||
# Drop the name column from embedding_provider
|
||||
op.drop_column("embedding_provider", "name")
|
||||
|
||||
# Drop the default_model_id column from embedding_provider
|
||||
op.drop_column("embedding_provider", "default_model_id")
|
||||
|
||||
# Drop the old cloud_provider_id column from embedding_model
|
||||
op.drop_column("embedding_model", "cloud_provider_id")
|
||||
|
||||
# Create the new foreign key constraint
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_model_cloud_provider",
|
||||
"embedding_model",
|
||||
"embedding_provider",
|
||||
["provider_type"],
|
||||
["provider_type"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the foreign key constraint in embedding_model table
|
||||
op.drop_constraint(
|
||||
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Add back the cloud_provider_id column to embedding_model
|
||||
op.add_column(
|
||||
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.add_column("embedding_provider", sa.Column("id", sa.Integer(), nullable=True))
|
||||
|
||||
# Assign incrementing IDs to embedding providers
|
||||
op.execute(
|
||||
"""
|
||||
CREATE SEQUENCE IF NOT EXISTS embedding_provider_id_seq;"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_provider SET id = nextval('embedding_provider_id_seq');
|
||||
"""
|
||||
)
|
||||
|
||||
# Update cloud_provider_id based on provider_type
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_model
|
||||
SET cloud_provider_id = CASE
|
||||
WHEN provider_type IS NULL THEN NULL
|
||||
ELSE (
|
||||
SELECT id
|
||||
FROM embedding_provider
|
||||
WHERE embedding_provider.provider_type = embedding_model.provider_type
|
||||
)
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the provider_type column from embedding_model
|
||||
op.drop_column("embedding_model", "provider_type")
|
||||
|
||||
# Add back the columns to embedding_provider
|
||||
op.add_column("embedding_provider", sa.Column("name", sa.String(50), nullable=True))
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("default_model_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Drop the existing primary key constraint on provider_type
|
||||
op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
|
||||
|
||||
# Create the original primary key constraint on id
|
||||
op.create_primary_key("embedding_provider_pkey", "embedding_provider", ["id"])
|
||||
|
||||
# Update name with existing provider_type values
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE embedding_provider
|
||||
SET name = CASE
|
||||
WHEN provider_type = 'OPENAI' THEN 'OpenAI'
|
||||
WHEN provider_type = 'COHERE' THEN 'Cohere'
|
||||
WHEN provider_type = 'GOOGLE' THEN 'Google'
|
||||
WHEN provider_type = 'VOYAGE' THEN 'Voyage'
|
||||
ELSE provider_type
|
||||
END
|
||||
"""
|
||||
)
|
||||
|
||||
# Drop the provider_type column from embedding_provider
|
||||
op.drop_column("embedding_provider", "provider_type")
|
||||
|
||||
# Recreate the foreign key constraint in embedding_model table
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_model_cloud_provider",
|
||||
"embedding_model",
|
||||
"embedding_provider",
|
||||
["cloud_provider_id"],
|
||||
["id"],
|
||||
)
|
@ -378,7 +378,7 @@ def update_loop(
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
|
||||
if db_embedding_model.cloud_provider_id is None:
|
||||
if db_embedding_model.provider_type is None:
|
||||
logger.notice("Running a first inference to warm up embedding model")
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=db_embedding_model,
|
||||
|
@ -469,7 +469,7 @@ if __name__ == "__main__":
|
||||
# or the tokens have updated (set up for the first time)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
embedding_model = get_current_db_embedding_model(db_session)
|
||||
if embedding_model.cloud_provider_id is None:
|
||||
if embedding_model.provider_type is None:
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
|
@ -14,32 +14,34 @@ from danswer.db.llm import fetch_embedding_provider
|
||||
from danswer.db.models import CloudEmbeddingProvider
|
||||
from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.indexing.models import EmbeddingModelCreateRequest
|
||||
from danswer.indexing.models import EmbeddingModelDetail
|
||||
from danswer.natural_language_processing.search_nlp_models import clean_model_name
|
||||
from danswer.server.manage.embedding.models import (
|
||||
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
|
||||
)
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def create_embedding_model(
|
||||
model_details: EmbeddingModelDetail,
|
||||
create_embed_model_details: EmbeddingModelCreateRequest,
|
||||
db_session: Session,
|
||||
status: IndexModelStatus = IndexModelStatus.FUTURE,
|
||||
) -> EmbeddingModel:
|
||||
embedding_model = EmbeddingModel(
|
||||
model_name=model_details.model_name,
|
||||
model_dim=model_details.model_dim,
|
||||
normalize=model_details.normalize,
|
||||
query_prefix=model_details.query_prefix,
|
||||
passage_prefix=model_details.passage_prefix,
|
||||
model_name=create_embed_model_details.model_name,
|
||||
model_dim=create_embed_model_details.model_dim,
|
||||
normalize=create_embed_model_details.normalize,
|
||||
query_prefix=create_embed_model_details.query_prefix,
|
||||
passage_prefix=create_embed_model_details.passage_prefix,
|
||||
status=status,
|
||||
cloud_provider_id=model_details.cloud_provider_id,
|
||||
provider_type=create_embed_model_details.provider_type,
|
||||
# Every single embedding model except the initial one from migrations has this name
|
||||
# The initial one from migration is called "danswer_chunk"
|
||||
index_name=model_details.index_name,
|
||||
index_name=create_embed_model_details.index_name,
|
||||
)
|
||||
|
||||
db_session.add(embedding_model)
|
||||
@ -48,14 +50,14 @@ def create_embedding_model(
|
||||
return embedding_model
|
||||
|
||||
|
||||
def get_model_id_from_name(
|
||||
db_session: Session, embedding_provider_name: str
|
||||
) -> int | None:
|
||||
def get_embedding_provider_from_provider_type(
|
||||
db_session: Session, provider_type: EmbeddingProvider
|
||||
) -> CloudEmbeddingProvider | None:
|
||||
query = select(CloudEmbeddingProvider).where(
|
||||
CloudEmbeddingProvider.name == embedding_provider_name
|
||||
CloudEmbeddingProvider.provider_type == provider_type
|
||||
)
|
||||
provider = db_session.execute(query).scalars().first()
|
||||
return provider.id if provider else None
|
||||
return provider if provider else None
|
||||
|
||||
|
||||
def get_current_db_embedding_provider(
|
||||
@ -65,14 +67,12 @@ def get_current_db_embedding_provider(
|
||||
get_current_db_embedding_model(db_session=db_session)
|
||||
)
|
||||
|
||||
if (
|
||||
current_embedding_model is None
|
||||
or current_embedding_model.cloud_provider_id is None
|
||||
):
|
||||
if current_embedding_model is None or current_embedding_model.provider_type is None:
|
||||
return None
|
||||
|
||||
embedding_provider = fetch_embedding_provider(
|
||||
db_session=db_session, provider_id=current_embedding_model.cloud_provider_id
|
||||
db_session=db_session,
|
||||
provider_type=current_embedding_model.provider_type,
|
||||
)
|
||||
if embedding_provider is None:
|
||||
raise RuntimeError("No embedding provider exists for this model.")
|
||||
|
@ -12,6 +12,7 @@ from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from danswer.server.manage.llm.models import FullLLMProvider
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
|
||||
def update_group_llm_provider_relationships__no_commit(
|
||||
@ -41,7 +42,7 @@ def upsert_cloud_embedding_provider(
|
||||
) -> CloudEmbeddingProvider:
|
||||
existing_provider = (
|
||||
db_session.query(CloudEmbeddingProviderModel)
|
||||
.filter_by(name=provider.name)
|
||||
.filter_by(provider_type=provider.provider_type)
|
||||
.first()
|
||||
)
|
||||
if existing_provider:
|
||||
@ -124,11 +125,11 @@ def fetch_existing_llm_providers(
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
db_session: Session, provider_id: int
|
||||
db_session: Session, provider_type: EmbeddingProvider
|
||||
) -> CloudEmbeddingProviderModel | None:
|
||||
return db_session.scalar(
|
||||
select(CloudEmbeddingProviderModel).where(
|
||||
CloudEmbeddingProviderModel.id == provider_id
|
||||
CloudEmbeddingProviderModel.provider_type == provider_type
|
||||
)
|
||||
)
|
||||
|
||||
@ -154,11 +155,11 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
|
||||
|
||||
|
||||
def remove_embedding_provider(
|
||||
db_session: Session, embedding_provider_name: str
|
||||
db_session: Session, provider_type: EmbeddingProvider
|
||||
) -> None:
|
||||
db_session.execute(
|
||||
delete(CloudEmbeddingProviderModel).where(
|
||||
CloudEmbeddingProviderModel.name == embedding_provider_name
|
||||
CloudEmbeddingProviderModel.provider_type == provider_type
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -558,13 +558,14 @@ class EmbeddingModel(Base):
|
||||
index_name: Mapped[str] = mapped_column(String)
|
||||
|
||||
# New field for cloud provider relationship
|
||||
cloud_provider_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("embedding_provider.id")
|
||||
provider_type: Mapped[EmbeddingProvider | None] = mapped_column(
|
||||
ForeignKey("embedding_provider.provider_type"), nullable=True
|
||||
)
|
||||
|
||||
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
|
||||
"CloudEmbeddingProvider",
|
||||
back_populates="embedding_models",
|
||||
foreign_keys=[cloud_provider_id],
|
||||
foreign_keys=[provider_type],
|
||||
)
|
||||
|
||||
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
|
||||
@ -588,15 +589,7 @@ class EmbeddingModel(Base):
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
|
||||
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
|
||||
|
||||
@property
|
||||
def provider_type(self) -> EmbeddingProvider | None:
|
||||
return (
|
||||
EmbeddingProvider(self.cloud_provider.name.lower())
|
||||
if self.cloud_provider is not None
|
||||
else None
|
||||
)
|
||||
cloud_provider='{self.cloud_provider.provider_type if self.cloud_provider else 'None'}')>"
|
||||
|
||||
@property
|
||||
def api_key(self) -> str | None:
|
||||
@ -1073,24 +1066,18 @@ class LLMProvider(Base):
|
||||
class CloudEmbeddingProvider(Base):
|
||||
__tablename__ = "embedding_provider"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString())
|
||||
default_model_id: Mapped[int | None] = mapped_column(
|
||||
Integer, ForeignKey("embedding_model.id"), nullable=True
|
||||
provider_type: Mapped[EmbeddingProvider] = mapped_column(
|
||||
Enum(EmbeddingProvider), primary_key=True
|
||||
)
|
||||
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString())
|
||||
embedding_models: Mapped[list["EmbeddingModel"]] = relationship(
|
||||
"EmbeddingModel",
|
||||
back_populates="cloud_provider",
|
||||
foreign_keys="EmbeddingModel.cloud_provider_id",
|
||||
)
|
||||
default_model: Mapped["EmbeddingModel"] = relationship(
|
||||
"EmbeddingModel", foreign_keys=[default_model_id]
|
||||
foreign_keys="EmbeddingModel.provider_type",
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<EmbeddingProvider(name='{self.name}')>"
|
||||
return f"<EmbeddingProvider(type='{self.provider_type}')>"
|
||||
|
||||
|
||||
class DocumentSet(Base):
|
||||
|
@ -5,6 +5,7 @@ from pydantic import BaseModel
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -99,9 +100,7 @@ class EmbeddingModelDetail(BaseModel):
|
||||
normalize: bool
|
||||
query_prefix: str | None
|
||||
passage_prefix: str | None
|
||||
cloud_provider_id: int | None = None
|
||||
cloud_provider_name: str | None = None
|
||||
index_name: str | None = None
|
||||
provider_type: EmbeddingProvider | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
@ -114,6 +113,9 @@ class EmbeddingModelDetail(BaseModel):
|
||||
normalize=embedding_model.normalize,
|
||||
query_prefix=embedding_model.query_prefix,
|
||||
passage_prefix=embedding_model.passage_prefix,
|
||||
cloud_provider_id=embedding_model.cloud_provider_id,
|
||||
index_name=embedding_model.index_name,
|
||||
provider_type=embedding_model.provider_type,
|
||||
)
|
||||
|
||||
|
||||
class EmbeddingModelCreateRequest(EmbeddingModelDetail):
|
||||
index_name: str
|
||||
|
@ -343,7 +343,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
setup_vespa(document_index, db_embedding_model, secondary_db_embedding_model)
|
||||
|
||||
logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
|
||||
if db_embedding_model.cloud_provider_id is None:
|
||||
if db_embedding_model.provider_type is None:
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=db_embedding_model,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
|
@ -17,6 +17,7 @@ from danswer.server.manage.embedding.models import TestEmbeddingRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
|
||||
logger = setup_logger()
|
||||
@ -36,7 +37,7 @@ def test_embedding_configuration(
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
api_key=test_llm_request.api_key,
|
||||
provider_type=test_llm_request.provider,
|
||||
provider_type=test_llm_request.provider_type,
|
||||
normalize=False,
|
||||
query_prefix=None,
|
||||
passage_prefix=None,
|
||||
@ -66,22 +67,22 @@ def list_embedding_providers(
|
||||
]
|
||||
|
||||
|
||||
@admin_router.delete("/embedding-provider/{embedding_provider_name}")
|
||||
@admin_router.delete("/embedding-provider/{provider_type}")
|
||||
def delete_embedding_provider(
|
||||
embedding_provider_name: str,
|
||||
provider_type: EmbeddingProvider,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
embedding_provider = get_current_db_embedding_provider(db_session=db_session)
|
||||
if (
|
||||
embedding_provider is not None
|
||||
and embedding_provider_name == embedding_provider.name
|
||||
and provider_type == embedding_provider.provider_type
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="You can't delete a currently active model"
|
||||
)
|
||||
|
||||
remove_embedding_provider(db_session, embedding_provider_name)
|
||||
remove_embedding_provider(db_session, provider_type=provider_type)
|
||||
|
||||
|
||||
@admin_router.put("/embedding-provider")
|
||||
|
@ -9,29 +9,24 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class TestEmbeddingRequest(BaseModel):
|
||||
provider: EmbeddingProvider
|
||||
provider_type: EmbeddingProvider
|
||||
api_key: str | None = None
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(BaseModel):
|
||||
name: str
|
||||
provider_type: EmbeddingProvider
|
||||
api_key: str | None = None
|
||||
default_model_id: int | None = None
|
||||
id: int
|
||||
|
||||
@classmethod
|
||||
def from_request(
|
||||
cls, cloud_provider_model: "CloudEmbeddingProviderModel"
|
||||
) -> "CloudEmbeddingProvider":
|
||||
return cls(
|
||||
id=cloud_provider_model.id,
|
||||
name=cloud_provider_model.name,
|
||||
provider_type=cloud_provider_model.provider_type,
|
||||
api_key=cloud_provider_model.api_key,
|
||||
default_model_id=cloud_provider_model.default_model_id,
|
||||
)
|
||||
|
||||
|
||||
class CloudEmbeddingProviderCreationRequest(BaseModel):
|
||||
name: str
|
||||
provider_type: EmbeddingProvider
|
||||
api_key: str | None = None
|
||||
default_model_id: int | None = None
|
||||
|
@ -11,7 +11,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import resync_cc_pair
|
||||
from danswer.db.embedding_model import create_embedding_model
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.embedding_model import get_model_id_from_name
|
||||
from danswer.db.embedding_model import get_embedding_provider_from_provider_type
|
||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||
from danswer.db.embedding_model import update_embedding_model_status
|
||||
from danswer.db.engine import get_session
|
||||
@ -19,6 +19,7 @@ from danswer.db.index_attempt import expire_index_attempts
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.models import User
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.models import EmbeddingModelCreateRequest
|
||||
from danswer.indexing.models import EmbeddingModelDetail
|
||||
from danswer.natural_language_processing.search_nlp_models import clean_model_name
|
||||
from danswer.search.models import SavedSearchSettings
|
||||
@ -42,30 +43,32 @@ def set_new_embedding_model(
|
||||
"""Creates a new EmbeddingModel row and cancels the previous secondary indexing if any
|
||||
Gives an error if the same model name is used as the current or secondary index
|
||||
"""
|
||||
current_model = get_current_db_embedding_model(db_session)
|
||||
|
||||
if embed_model_details.cloud_provider_name is not None:
|
||||
cloud_id = get_model_id_from_name(
|
||||
db_session, embed_model_details.cloud_provider_name
|
||||
# Validate cloud provider exists
|
||||
if embed_model_details.provider_type is not None:
|
||||
cloud_provider = get_embedding_provider_from_provider_type(
|
||||
db_session, provider_type=embed_model_details.provider_type
|
||||
)
|
||||
|
||||
if cloud_id is None:
|
||||
if cloud_provider is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="No ID exists for given provider name",
|
||||
detail=f"No embedding provider exists for cloud embedding type {embed_model_details.provider_type}",
|
||||
)
|
||||
|
||||
embed_model_details.cloud_provider_id = cloud_id
|
||||
current_model = get_current_db_embedding_model(db_session)
|
||||
|
||||
embed_model_details.index_name = (
|
||||
f"danswer_chunk_{clean_model_name(embed_model_details.model_name)}"
|
||||
)
|
||||
# account for same model name being indexed with two different configurations
|
||||
# We define index name here
|
||||
index_name = f"danswer_chunk_{clean_model_name(embed_model_details.model_name)}"
|
||||
if (
|
||||
embed_model_details.model_name == current_model.model_name
|
||||
and not current_model.index_name.endswith(ALT_INDEX_SUFFIX)
|
||||
):
|
||||
embed_model_details.index_name += ALT_INDEX_SUFFIX
|
||||
index_name += ALT_INDEX_SUFFIX
|
||||
|
||||
create_embed_model_details = EmbeddingModelCreateRequest(
|
||||
**embed_model_details.dict(), index_name=index_name
|
||||
)
|
||||
|
||||
secondary_model = get_secondary_db_embedding_model(db_session)
|
||||
|
||||
@ -89,8 +92,7 @@ def set_new_embedding_model(
|
||||
)
|
||||
|
||||
new_model = create_embedding_model(
|
||||
model_details=embed_model_details,
|
||||
db_session=db_session,
|
||||
create_embed_model_details=create_embed_model_details, db_session=db_session
|
||||
)
|
||||
|
||||
# Ensure Vespa has the new index immediately
|
||||
|
@ -3,7 +3,6 @@ import { Modal } from "@/components/Modal";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { ConnectorIndexingStatus } from "@/lib/types";
|
||||
import { Button, Text, Title } from "@tremor/react";
|
||||
import Link from "next/link";
|
||||
import { useState } from "react";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import { ReindexingProgressTable } from "../../../../components/embedding/ReindexingProgressTable";
|
||||
|
@ -79,7 +79,7 @@ function Main() {
|
||||
(provider) =>
|
||||
provider.embedding_models.map((model) => ({
|
||||
...model,
|
||||
cloud_provider_id: provider.id,
|
||||
provider_type: provider.provider_type,
|
||||
model_name: model.model_name, // Ensure model_name is set for consistency
|
||||
}))
|
||||
);
|
||||
|
@ -11,6 +11,7 @@ import {
|
||||
INVALID_OLD_MODEL,
|
||||
HostedEmbeddingModel,
|
||||
EmbeddingModelDescriptor,
|
||||
EmbeddingProvider,
|
||||
} from "../../../components/embedding/interfaces";
|
||||
import { Connector } from "@/lib/connectors/connectors";
|
||||
import OpenEmbeddingPage from "./pages/OpenEmbeddingPage";
|
||||
@ -28,8 +29,7 @@ import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../configuration/llm/constants";
|
||||
export interface EmbeddingDetails {
|
||||
api_key: string;
|
||||
custom_config: any;
|
||||
default_model_id?: number;
|
||||
name: string;
|
||||
provider_type: EmbeddingProvider;
|
||||
}
|
||||
|
||||
export function EmbeddingModelSelection({
|
||||
@ -122,28 +122,28 @@ export function EmbeddingModelSelection({
|
||||
};
|
||||
|
||||
const clientsideAddProvider = (provider: CloudEmbeddingProvider) => {
|
||||
const providerName = provider.name;
|
||||
const providerType = provider.provider_type;
|
||||
setNewEnabledProviders((newEnabledProviders) => [
|
||||
...newEnabledProviders,
|
||||
providerName,
|
||||
providerType,
|
||||
]);
|
||||
setNewUnenabledProviders((newUnenabledProviders) =>
|
||||
newUnenabledProviders.filter(
|
||||
(givenProvidername) => givenProvidername != providerName
|
||||
(givenProviderType) => givenProviderType != providerType
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
const clientsideRemoveProvider = (provider: CloudEmbeddingProvider) => {
|
||||
const providerName = provider.name;
|
||||
const providerType = provider.provider_type;
|
||||
setNewEnabledProviders((newEnabledProviders) =>
|
||||
newEnabledProviders.filter(
|
||||
(givenProvidername) => givenProvidername != providerName
|
||||
(givenProviderType) => givenProviderType != providerType
|
||||
)
|
||||
);
|
||||
setNewUnenabledProviders((newUnenabledProviders) => [
|
||||
...newUnenabledProviders,
|
||||
providerName,
|
||||
providerType,
|
||||
]);
|
||||
};
|
||||
|
||||
@ -191,7 +191,7 @@ export function EmbeddingModelSelection({
|
||||
)}
|
||||
{changeCredentialsProvider && (
|
||||
<ChangeCredentialsModal
|
||||
useFileUpload={changeCredentialsProvider.name == "Google"}
|
||||
useFileUpload={changeCredentialsProvider.provider_type == "Google"}
|
||||
onDeleted={() => {
|
||||
clientsideRemoveProvider(changeCredentialsProvider);
|
||||
setChangeCredentialsProvider(null);
|
||||
|
@ -74,7 +74,7 @@ export function ChangeCredentialsModal({
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.name}`,
|
||||
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type}`,
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
@ -99,19 +99,12 @@ export function ChangeCredentialsModal({
|
||||
|
||||
const handleSubmit = async () => {
|
||||
setTestError("");
|
||||
|
||||
try {
|
||||
const body = JSON.stringify({
|
||||
api_key: apiKey,
|
||||
provider: provider.name.toLowerCase().split(" ")[0],
|
||||
default_model_id: provider.name,
|
||||
});
|
||||
|
||||
const testResponse = await fetch("/api/admin/embedding/test-embedding", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider: provider.name.toLowerCase().split(" ")[0],
|
||||
provider_type: provider.provider_type.toLowerCase().split(" ")[0],
|
||||
api_key: apiKey,
|
||||
}),
|
||||
});
|
||||
@ -125,7 +118,7 @@ export function ChangeCredentialsModal({
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
name: provider.name,
|
||||
provider_type: provider.provider_type.toLowerCase().split(" ")[0],
|
||||
api_key: apiKey,
|
||||
is_default_provider: false,
|
||||
is_configured: true,
|
||||
@ -151,7 +144,7 @@ export function ChangeCredentialsModal({
|
||||
<Modal
|
||||
width="max-w-3xl"
|
||||
icon={provider.icon}
|
||||
title={`Modify your ${provider.name} key`}
|
||||
title={`Modify your ${provider.provider_type} key`}
|
||||
onOutsideClick={onCancel}
|
||||
>
|
||||
<div className="mb-4">
|
||||
|
@ -15,13 +15,13 @@ export function DeleteCredentialsModal({
|
||||
return (
|
||||
<Modal
|
||||
width="max-w-3xl"
|
||||
title={`Nuke ${modelProvider.name} Credentials?`}
|
||||
title={`Delete ${modelProvider.provider_type} Credentials?`}
|
||||
onOutsideClick={onCancel}
|
||||
>
|
||||
<div className="mb-4">
|
||||
<Text className="text-lg mb-2">
|
||||
You're about to delete your {modelProvider.name} credentials. Are
|
||||
you sure?
|
||||
You're about to delete your {modelProvider.provider_type}{" "}
|
||||
credentials. Are you sure?
|
||||
</Text>
|
||||
<Callout
|
||||
title="Point of No Return"
|
||||
|
@ -19,24 +19,24 @@ export function ProviderCreationModal({
|
||||
onCancel: () => void;
|
||||
existingProvider?: CloudEmbeddingProvider;
|
||||
}) {
|
||||
const useFileUpload = selectedProvider.name == "Google";
|
||||
const useFileUpload = selectedProvider.provider_type == "Google";
|
||||
|
||||
const [isProcessing, setIsProcessing] = useState(false);
|
||||
const [errorMsg, setErrorMsg] = useState<string>("");
|
||||
const [fileName, setFileName] = useState<string>("");
|
||||
|
||||
const initialValues = {
|
||||
name: existingProvider?.name || selectedProvider.name,
|
||||
provider_type:
|
||||
existingProvider?.provider_type || selectedProvider.provider_type,
|
||||
api_key: existingProvider?.api_key || "",
|
||||
custom_config: existingProvider?.custom_config
|
||||
? Object.entries(existingProvider.custom_config)
|
||||
: [],
|
||||
default_model_name: "",
|
||||
model_id: 0,
|
||||
};
|
||||
|
||||
const validationSchema = Yup.object({
|
||||
name: Yup.string().required("Name is required"),
|
||||
provider_type: Yup.string().required("Provider type is required"),
|
||||
api_key: useFileUpload
|
||||
? Yup.string()
|
||||
: Yup.string().required("API Key is required"),
|
||||
@ -76,7 +76,6 @@ export function ProviderCreationModal({
|
||||
) => {
|
||||
setIsProcessing(true);
|
||||
setErrorMsg("");
|
||||
|
||||
try {
|
||||
const customConfig = Object.fromEntries(values.custom_config);
|
||||
|
||||
@ -86,7 +85,7 @@ export function ProviderCreationModal({
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider: values.name.toLowerCase().split(" ")[0],
|
||||
provider_type: values.provider_type.toLowerCase().split(" ")[0],
|
||||
api_key: values.api_key,
|
||||
}),
|
||||
}
|
||||
@ -105,6 +104,7 @@ export function ProviderCreationModal({
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
...values,
|
||||
provider_type: values.provider_type.toLowerCase().split(" ")[0],
|
||||
custom_config: customConfig,
|
||||
is_default_provider: false,
|
||||
is_configured: true,
|
||||
@ -134,7 +134,7 @@ export function ProviderCreationModal({
|
||||
return (
|
||||
<Modal
|
||||
width="max-w-3xl"
|
||||
title={`Configure ${selectedProvider.name}`}
|
||||
title={`Configure ${selectedProvider.provider_type}`}
|
||||
onOutsideClick={onCancel}
|
||||
icon={selectedProvider.icon}
|
||||
>
|
||||
|
@ -39,12 +39,12 @@ export default function CloudEmbeddingPage({
|
||||
React.SetStateAction<CloudEmbeddingProvider | null>
|
||||
>;
|
||||
}) {
|
||||
function hasNameInArray(
|
||||
arr: Array<{ name: string }>,
|
||||
function hasProviderTypeinArray(
|
||||
arr: Array<{ provider_type: string }>,
|
||||
searchName: string
|
||||
): boolean {
|
||||
return arr.some(
|
||||
(item) => item.name.toLowerCase() === searchName.toLowerCase()
|
||||
(item) => item.provider_type.toLowerCase() === searchName.toLowerCase()
|
||||
);
|
||||
}
|
||||
|
||||
@ -52,10 +52,13 @@ export default function CloudEmbeddingPage({
|
||||
(model) => ({
|
||||
...model,
|
||||
configured:
|
||||
!newUnenabledProviders.includes(model.name) &&
|
||||
(newEnabledProviders.includes(model.name) ||
|
||||
!newUnenabledProviders.includes(model.provider_type) &&
|
||||
(newEnabledProviders.includes(model.provider_type) ||
|
||||
(embeddingProviderDetails &&
|
||||
hasNameInArray(embeddingProviderDetails, model.name))!),
|
||||
hasProviderTypeinArray(
|
||||
embeddingProviderDetails,
|
||||
model.provider_type
|
||||
))!),
|
||||
})
|
||||
);
|
||||
|
||||
@ -71,11 +74,12 @@ export default function CloudEmbeddingPage({
|
||||
|
||||
<div className="gap-4 mt-2 pb-10 flex content-start flex-wrap">
|
||||
{providers.map((provider) => (
|
||||
<div key={provider.name} className="mt-4 w-full">
|
||||
<div key={provider.provider_type} className="mt-4 w-full">
|
||||
<div className="flex items-center mb-2">
|
||||
{provider.icon({ size: 40 })}
|
||||
<h2 className="ml-2 mt-2 text-xl font-bold">
|
||||
{provider.name} {provider.name == "Cohere" && "(recommended)"}
|
||||
{provider.provider_type}{" "}
|
||||
{provider.provider_type == "Cohere" && "(recommended)"}
|
||||
</h2>
|
||||
<HoverPopup
|
||||
mainContent={
|
||||
|
@ -167,12 +167,14 @@ export default function EmbeddingForm() {
|
||||
const onConfirm = async () => {
|
||||
let newModel: EmbeddingModelDescriptor;
|
||||
|
||||
if ("cloud_provider_name" in selectedProvider) {
|
||||
if ("provider_type" in selectedProvider) {
|
||||
// This is a CloudEmbeddingModel
|
||||
newModel = {
|
||||
...selectedProvider,
|
||||
model_name: selectedProvider.model_name,
|
||||
cloud_provider_name: selectedProvider.cloud_provider_name,
|
||||
provider_type: selectedProvider.provider_type
|
||||
?.toLowerCase()
|
||||
.split(" ")[0],
|
||||
};
|
||||
} else {
|
||||
// This is an EmbeddingModelDescriptor
|
||||
@ -180,7 +182,7 @@ export default function EmbeddingForm() {
|
||||
...selectedProvider,
|
||||
model_name: selectedProvider.model_name!,
|
||||
description: "",
|
||||
cloud_provider_name: null,
|
||||
provider_type: null,
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -9,11 +9,15 @@ import {
|
||||
VoyageIcon,
|
||||
} from "@/components/icons/icons";
|
||||
|
||||
// Cloud Provider (not needed for hosted ones)
|
||||
export enum EmbeddingProvider {
|
||||
OPENAI = "OpenAI",
|
||||
COHERE = "Cohere",
|
||||
VOYAGE = "Voyage",
|
||||
GOOGLE = "Google",
|
||||
}
|
||||
|
||||
export interface CloudEmbeddingProvider {
|
||||
id: number;
|
||||
name: string;
|
||||
provider_type: EmbeddingProvider;
|
||||
api_key?: string;
|
||||
custom_config?: Record<string, string>;
|
||||
docsLink?: string;
|
||||
@ -37,12 +41,11 @@ export interface EmbeddingModelDescriptor {
|
||||
normalize: boolean;
|
||||
query_prefix: string;
|
||||
passage_prefix: string;
|
||||
cloud_provider_name?: string | null;
|
||||
provider_type?: string | null;
|
||||
description: string;
|
||||
}
|
||||
|
||||
export interface CloudEmbeddingModel extends EmbeddingModelDescriptor {
|
||||
cloud_provider_name: string | null;
|
||||
pricePerMillion: number;
|
||||
enabled?: boolean;
|
||||
mtebScore: number;
|
||||
@ -124,8 +127,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
||||
|
||||
export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
{
|
||||
id: 1,
|
||||
name: "Cohere",
|
||||
provider_type: EmbeddingProvider.COHERE,
|
||||
website: "https://cohere.ai",
|
||||
icon: CohereIcon,
|
||||
docsLink:
|
||||
@ -136,8 +138,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
costslink: "https://cohere.com/pricing",
|
||||
embedding_models: [
|
||||
{
|
||||
provider_type: EmbeddingProvider.COHERE,
|
||||
model_name: "embed-english-v3.0",
|
||||
cloud_provider_name: "Cohere",
|
||||
description:
|
||||
"Cohere's English embedding model. Good performance for English-language tasks.",
|
||||
pricePerMillion: 0.1,
|
||||
@ -151,7 +153,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
},
|
||||
{
|
||||
model_name: "embed-english-light-v3.0",
|
||||
cloud_provider_name: "Cohere",
|
||||
provider_type: EmbeddingProvider.COHERE,
|
||||
description:
|
||||
"Cohere's lightweight English embedding model. Faster and more efficient for simpler tasks.",
|
||||
pricePerMillion: 0.1,
|
||||
@ -166,8 +168,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
],
|
||||
},
|
||||
{
|
||||
id: 0,
|
||||
name: "OpenAI",
|
||||
provider_type: EmbeddingProvider.OPENAI,
|
||||
website: "https://openai.com",
|
||||
icon: OpenAIIcon,
|
||||
description: "AI industry leader known for ChatGPT and DALL-E",
|
||||
@ -177,8 +178,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
costslink: "https://openai.com/pricing",
|
||||
embedding_models: [
|
||||
{
|
||||
provider_type: EmbeddingProvider.OPENAI,
|
||||
model_name: "text-embedding-3-large",
|
||||
cloud_provider_name: "OpenAI",
|
||||
description:
|
||||
"OpenAI's large embedding model. Best performance, but more expensive.",
|
||||
pricePerMillion: 0.13,
|
||||
@ -191,8 +192,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
enabled: false,
|
||||
},
|
||||
{
|
||||
provider_type: EmbeddingProvider.OPENAI,
|
||||
model_name: "text-embedding-3-small",
|
||||
cloud_provider_name: "OpenAI",
|
||||
model_dim: 1536,
|
||||
normalize: false,
|
||||
query_prefix: "",
|
||||
@ -208,8 +209,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
},
|
||||
|
||||
{
|
||||
id: 2,
|
||||
name: "Google",
|
||||
provider_type: EmbeddingProvider.GOOGLE,
|
||||
website: "https://ai.google",
|
||||
icon: GoogleIcon,
|
||||
docsLink:
|
||||
@ -220,7 +220,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
costslink: "https://cloud.google.com/vertex-ai/pricing",
|
||||
embedding_models: [
|
||||
{
|
||||
cloud_provider_name: "Google",
|
||||
provider_type: EmbeddingProvider.GOOGLE,
|
||||
model_name: "text-embedding-004",
|
||||
description: "Google's most recent text embedding model.",
|
||||
pricePerMillion: 0.025,
|
||||
@ -233,7 +233,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
passage_prefix: "",
|
||||
},
|
||||
{
|
||||
cloud_provider_name: "Google",
|
||||
provider_type: EmbeddingProvider.GOOGLE,
|
||||
model_name: "textembedding-gecko@003",
|
||||
description: "Google's Gecko embedding model. Powerful and efficient.",
|
||||
pricePerMillion: 0.025,
|
||||
@ -248,8 +248,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
],
|
||||
},
|
||||
{
|
||||
id: 3,
|
||||
name: "Voyage",
|
||||
provider_type: EmbeddingProvider.VOYAGE,
|
||||
website: "https://www.voyageai.com",
|
||||
icon: VoyageIcon,
|
||||
description: "Advanced NLP research startup born from Stanford AI Labs",
|
||||
@ -259,7 +258,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
costslink: "https://www.voyageai.com/pricing",
|
||||
embedding_models: [
|
||||
{
|
||||
cloud_provider_name: "Voyage",
|
||||
provider_type: EmbeddingProvider.VOYAGE,
|
||||
model_name: "voyage-large-2-instruct",
|
||||
description:
|
||||
"Voyage's large embedding model. High performance with instruction fine-tuning.",
|
||||
@ -273,7 +272,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
passage_prefix: "",
|
||||
},
|
||||
{
|
||||
cloud_provider_name: "Voyage",
|
||||
provider_type: EmbeddingProvider.VOYAGE,
|
||||
model_name: "voyage-light-2-instruct",
|
||||
description:
|
||||
"Voyage's lightweight embedding model. Good balance of performance and efficiency.",
|
||||
|
Loading…
x
Reference in New Issue
Block a user