mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-20 13:01:34 +02:00
add third party embedding models (#1818)
This commit is contained in:
parent
b6bd818e60
commit
e7f81d1688
@ -0,0 +1,65 @@
|
|||||||
|
"""add cloud embedding model and update embedding_model
|
||||||
|
|
||||||
|
Revision ID: 44f856ae2a4a
|
||||||
|
Revises: d716b0791ddd
|
||||||
|
Create Date: 2024-06-28 20:01:05.927647
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "44f856ae2a4a"
|
||||||
|
down_revision = "d716b0791ddd"
|
||||||
|
branch_labels: None = None
|
||||||
|
depends_on: None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Create embedding_provider table
|
||||||
|
op.create_table(
|
||||||
|
"embedding_provider",
|
||||||
|
sa.Column("id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("name", sa.String(), nullable=False),
|
||||||
|
sa.Column("api_key", sa.LargeBinary(), nullable=True),
|
||||||
|
sa.Column("default_model_id", sa.Integer(), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint("name"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add cloud_provider_id to embedding_model table
|
||||||
|
op.add_column(
|
||||||
|
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add foreign key constraints
|
||||||
|
op.create_foreign_key(
|
||||||
|
"fk_embedding_model_cloud_provider",
|
||||||
|
"embedding_model",
|
||||||
|
"embedding_provider",
|
||||||
|
["cloud_provider_id"],
|
||||||
|
["id"],
|
||||||
|
)
|
||||||
|
op.create_foreign_key(
|
||||||
|
"fk_embedding_provider_default_model",
|
||||||
|
"embedding_provider",
|
||||||
|
"embedding_model",
|
||||||
|
["default_model_id"],
|
||||||
|
["id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# Remove foreign key constraints
|
||||||
|
op.drop_constraint(
|
||||||
|
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
|
||||||
|
)
|
||||||
|
op.drop_constraint(
|
||||||
|
"fk_embedding_provider_default_model", "embedding_provider", type_="foreignkey"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove cloud_provider_id column
|
||||||
|
op.drop_column("embedding_model", "cloud_provider_id")
|
||||||
|
|
||||||
|
# Drop embedding_provider table
|
||||||
|
op.drop_table("embedding_provider")
|
@ -10,8 +10,8 @@ from alembic import op
|
|||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision = "d716b0791ddd"
|
revision = "d716b0791ddd"
|
||||||
down_revision = "7aea705850d5"
|
down_revision = "7aea705850d5"
|
||||||
branch_labels = None
|
branch_labels: None = None
|
||||||
depends_on = None
|
depends_on: None = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
|
@ -98,7 +98,6 @@ def _run_indexing(
|
|||||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
db_embedding_model = index_attempt.embedding_model
|
db_embedding_model = index_attempt.embedding_model
|
||||||
index_name = db_embedding_model.index_name
|
index_name = db_embedding_model.index_name
|
||||||
|
|
||||||
@ -116,6 +115,8 @@ def _run_indexing(
|
|||||||
normalize=db_embedding_model.normalize,
|
normalize=db_embedding_model.normalize,
|
||||||
query_prefix=db_embedding_model.query_prefix,
|
query_prefix=db_embedding_model.query_prefix,
|
||||||
passage_prefix=db_embedding_model.passage_prefix,
|
passage_prefix=db_embedding_model.passage_prefix,
|
||||||
|
api_key=db_embedding_model.api_key,
|
||||||
|
provider_type=db_embedding_model.provider_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
indexing_pipeline = build_indexing_pipeline(
|
indexing_pipeline = build_indexing_pipeline(
|
||||||
@ -287,6 +288,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
|
|||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
index_attempt_id=index_attempt_id,
|
index_attempt_id=index_attempt_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if attempt is None:
|
if attempt is None:
|
||||||
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")
|
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")
|
||||||
|
|
||||||
|
@ -343,6 +343,8 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
|||||||
|
|
||||||
# So that the first time users aren't surprised by really slow speed of first
|
# So that the first time users aren't surprised by really slow speed of first
|
||||||
# batch of documents indexed
|
# batch of documents indexed
|
||||||
|
|
||||||
|
if db_embedding_model.cloud_provider_id is None:
|
||||||
logger.info("Running a first inference to warm up embedding model")
|
logger.info("Running a first inference to warm up embedding model")
|
||||||
warm_up_encoders(
|
warm_up_encoders(
|
||||||
model_name=db_embedding_model.model_name,
|
model_name=db_embedding_model.model_name,
|
||||||
|
@ -469,7 +469,7 @@ if __name__ == "__main__":
|
|||||||
# or the tokens have updated (set up for the first time)
|
# or the tokens have updated (set up for the first time)
|
||||||
with Session(get_sqlalchemy_engine()) as db_session:
|
with Session(get_sqlalchemy_engine()) as db_session:
|
||||||
embedding_model = get_current_db_embedding_model(db_session)
|
embedding_model = get_current_db_embedding_model(db_session)
|
||||||
|
if embedding_model.cloud_provider_id is None:
|
||||||
warm_up_encoders(
|
warm_up_encoders(
|
||||||
model_name=embedding_model.model_name,
|
model_name=embedding_model.model_name,
|
||||||
normalize=embedding_model.normalize,
|
normalize=embedding_model.normalize,
|
||||||
|
@ -10,10 +10,15 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
|||||||
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||||
|
from danswer.db.llm import fetch_embedding_provider
|
||||||
|
from danswer.db.models import CloudEmbeddingProvider
|
||||||
from danswer.db.models import EmbeddingModel
|
from danswer.db.models import EmbeddingModel
|
||||||
from danswer.db.models import IndexModelStatus
|
from danswer.db.models import IndexModelStatus
|
||||||
from danswer.indexing.models import EmbeddingModelDetail
|
from danswer.indexing.models import EmbeddingModelDetail
|
||||||
from danswer.search.search_nlp_models import clean_model_name
|
from danswer.search.search_nlp_models import clean_model_name
|
||||||
|
from danswer.server.manage.embedding.models import (
|
||||||
|
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
|
||||||
|
)
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
@ -31,6 +36,7 @@ def create_embedding_model(
|
|||||||
query_prefix=model_details.query_prefix,
|
query_prefix=model_details.query_prefix,
|
||||||
passage_prefix=model_details.passage_prefix,
|
passage_prefix=model_details.passage_prefix,
|
||||||
status=status,
|
status=status,
|
||||||
|
cloud_provider_id=model_details.cloud_provider_id,
|
||||||
# Every single embedding model except the initial one from migrations has this name
|
# Every single embedding model except the initial one from migrations has this name
|
||||||
# The initial one from migration is called "danswer_chunk"
|
# The initial one from migration is called "danswer_chunk"
|
||||||
index_name=f"danswer_chunk_{clean_model_name(model_details.model_name)}",
|
index_name=f"danswer_chunk_{clean_model_name(model_details.model_name)}",
|
||||||
@ -42,6 +48,42 @@ def create_embedding_model(
|
|||||||
return embedding_model
|
return embedding_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_id_from_name(
|
||||||
|
db_session: Session, embedding_provider_name: str
|
||||||
|
) -> int | None:
|
||||||
|
query = select(CloudEmbeddingProvider).where(
|
||||||
|
CloudEmbeddingProvider.name == embedding_provider_name
|
||||||
|
)
|
||||||
|
provider = db_session.execute(query).scalars().first()
|
||||||
|
return provider.id if provider else None
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_db_embedding_provider(
|
||||||
|
db_session: Session,
|
||||||
|
) -> ServerCloudEmbeddingProvider | None:
|
||||||
|
current_embedding_model = EmbeddingModelDetail.from_model(
|
||||||
|
get_current_db_embedding_model(db_session=db_session)
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
current_embedding_model is None
|
||||||
|
or current_embedding_model.cloud_provider_id is None
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
embedding_provider = fetch_embedding_provider(
|
||||||
|
db_session=db_session, provider_id=current_embedding_model.cloud_provider_id
|
||||||
|
)
|
||||||
|
if embedding_provider is None:
|
||||||
|
raise RuntimeError("No embedding provider exists for this model.")
|
||||||
|
|
||||||
|
current_embedding_provider = ServerCloudEmbeddingProvider.from_request(
|
||||||
|
cloud_provider_model=embedding_provider
|
||||||
|
)
|
||||||
|
|
||||||
|
return current_embedding_provider
|
||||||
|
|
||||||
|
|
||||||
def get_current_db_embedding_model(db_session: Session) -> EmbeddingModel:
|
def get_current_db_embedding_model(db_session: Session) -> EmbeddingModel:
|
||||||
query = (
|
query = (
|
||||||
select(EmbeddingModel)
|
select(EmbeddingModel)
|
||||||
|
@ -2,11 +2,34 @@ from sqlalchemy import delete
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||||
|
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||||
|
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||||
from danswer.server.manage.llm.models import FullLLMProvider
|
from danswer.server.manage.llm.models import FullLLMProvider
|
||||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_cloud_embedding_provider(
|
||||||
|
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
|
||||||
|
) -> CloudEmbeddingProvider:
|
||||||
|
existing_provider = (
|
||||||
|
db_session.query(CloudEmbeddingProviderModel)
|
||||||
|
.filter_by(name=provider.name)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if existing_provider:
|
||||||
|
for key, value in provider.dict().items():
|
||||||
|
setattr(existing_provider, key, value)
|
||||||
|
else:
|
||||||
|
new_provider = CloudEmbeddingProviderModel(**provider.dict())
|
||||||
|
db_session.add(new_provider)
|
||||||
|
existing_provider = new_provider
|
||||||
|
db_session.commit()
|
||||||
|
db_session.refresh(existing_provider)
|
||||||
|
return CloudEmbeddingProvider.from_request(existing_provider)
|
||||||
|
|
||||||
|
|
||||||
def upsert_llm_provider(
|
def upsert_llm_provider(
|
||||||
db_session: Session, llm_provider: LLMProviderUpsertRequest
|
db_session: Session, llm_provider: LLMProviderUpsertRequest
|
||||||
) -> FullLLMProvider:
|
) -> FullLLMProvider:
|
||||||
@ -26,7 +49,6 @@ def upsert_llm_provider(
|
|||||||
existing_llm_provider.model_names = llm_provider.model_names
|
existing_llm_provider.model_names = llm_provider.model_names
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
return FullLLMProvider.from_model(existing_llm_provider)
|
return FullLLMProvider.from_model(existing_llm_provider)
|
||||||
|
|
||||||
# if it does not exist, create a new entry
|
# if it does not exist, create a new entry
|
||||||
llm_provider_model = LLMProviderModel(
|
llm_provider_model = LLMProviderModel(
|
||||||
name=llm_provider.name,
|
name=llm_provider.name,
|
||||||
@ -46,10 +68,26 @@ def upsert_llm_provider(
|
|||||||
return FullLLMProvider.from_model(llm_provider_model)
|
return FullLLMProvider.from_model(llm_provider_model)
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_existing_embedding_providers(
|
||||||
|
db_session: Session,
|
||||||
|
) -> list[CloudEmbeddingProviderModel]:
|
||||||
|
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
|
||||||
|
|
||||||
|
|
||||||
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
|
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
|
||||||
return list(db_session.scalars(select(LLMProviderModel)).all())
|
return list(db_session.scalars(select(LLMProviderModel)).all())
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_embedding_provider(
|
||||||
|
db_session: Session, provider_id: int
|
||||||
|
) -> CloudEmbeddingProviderModel | None:
|
||||||
|
return db_session.scalar(
|
||||||
|
select(CloudEmbeddingProviderModel).where(
|
||||||
|
CloudEmbeddingProviderModel.id == provider_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
|
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
|
||||||
provider_model = db_session.scalar(
|
provider_model = db_session.scalar(
|
||||||
select(LLMProviderModel).where(
|
select(LLMProviderModel).where(
|
||||||
@ -70,6 +108,16 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
|
|||||||
return FullLLMProvider.from_model(provider_model)
|
return FullLLMProvider.from_model(provider_model)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_embedding_provider(
|
||||||
|
db_session: Session, embedding_provider_name: str
|
||||||
|
) -> None:
|
||||||
|
db_session.execute(
|
||||||
|
delete(CloudEmbeddingProviderModel).where(
|
||||||
|
CloudEmbeddingProviderModel.name == embedding_provider_name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
||||||
db_session.execute(
|
db_session.execute(
|
||||||
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
|
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
|
||||||
|
@ -130,6 +130,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
|||||||
chat_folders: Mapped[list["ChatFolder"]] = relationship(
|
chat_folders: Mapped[list["ChatFolder"]] = relationship(
|
||||||
"ChatFolder", back_populates="user"
|
"ChatFolder", back_populates="user"
|
||||||
)
|
)
|
||||||
|
|
||||||
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
|
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
|
||||||
# Personas owned by this user
|
# Personas owned by this user
|
||||||
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
|
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
|
||||||
@ -469,7 +470,7 @@ class Credential(Base):
|
|||||||
|
|
||||||
class EmbeddingModel(Base):
|
class EmbeddingModel(Base):
|
||||||
__tablename__ = "embedding_model"
|
__tablename__ = "embedding_model"
|
||||||
# ID is used also to indicate the order that the models are configured by the admin
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
model_name: Mapped[str] = mapped_column(String)
|
model_name: Mapped[str] = mapped_column(String)
|
||||||
model_dim: Mapped[int] = mapped_column(Integer)
|
model_dim: Mapped[int] = mapped_column(Integer)
|
||||||
@ -481,6 +482,16 @@ class EmbeddingModel(Base):
|
|||||||
)
|
)
|
||||||
index_name: Mapped[str] = mapped_column(String)
|
index_name: Mapped[str] = mapped_column(String)
|
||||||
|
|
||||||
|
# New field for cloud provider relationship
|
||||||
|
cloud_provider_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("embedding_provider.id")
|
||||||
|
)
|
||||||
|
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
|
||||||
|
"CloudEmbeddingProvider",
|
||||||
|
back_populates="embedding_models",
|
||||||
|
foreign_keys=[cloud_provider_id],
|
||||||
|
)
|
||||||
|
|
||||||
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
|
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
|
||||||
"IndexAttempt", back_populates="embedding_model"
|
"IndexAttempt", back_populates="embedding_model"
|
||||||
)
|
)
|
||||||
@ -500,6 +511,18 @@ class EmbeddingModel(Base):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
|
||||||
|
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def api_key(self) -> str | None:
|
||||||
|
return self.cloud_provider.api_key if self.cloud_provider else None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_type(self) -> str | None:
|
||||||
|
return self.cloud_provider.name if self.cloud_provider else None
|
||||||
|
|
||||||
|
|
||||||
class IndexAttempt(Base):
|
class IndexAttempt(Base):
|
||||||
"""
|
"""
|
||||||
@ -519,6 +542,7 @@ class IndexAttempt(Base):
|
|||||||
ForeignKey("credential.id"),
|
ForeignKey("credential.id"),
|
||||||
nullable=True,
|
nullable=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Some index attempts that run from beginning will still have this as False
|
# Some index attempts that run from beginning will still have this as False
|
||||||
# This is only for attempts that are explicitly marked as from the start via
|
# This is only for attempts that are explicitly marked as from the start via
|
||||||
# the run once API
|
# the run once API
|
||||||
@ -879,11 +903,6 @@ class ChatMessageFeedback(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
Structures, Organizational, Configurations Tables
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class LLMProvider(Base):
|
class LLMProvider(Base):
|
||||||
__tablename__ = "llm_provider"
|
__tablename__ = "llm_provider"
|
||||||
|
|
||||||
@ -912,6 +931,29 @@ class LLMProvider(Base):
|
|||||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||||
|
|
||||||
|
|
||||||
|
class CloudEmbeddingProvider(Base):
|
||||||
|
__tablename__ = "embedding_provider"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column(String, unique=True)
|
||||||
|
api_key: Mapped[str | None] = mapped_column(EncryptedString())
|
||||||
|
default_model_id: Mapped[int | None] = mapped_column(
|
||||||
|
Integer, ForeignKey("embedding_model.id"), nullable=True
|
||||||
|
)
|
||||||
|
|
||||||
|
embedding_models: Mapped[list["EmbeddingModel"]] = relationship(
|
||||||
|
"EmbeddingModel",
|
||||||
|
back_populates="cloud_provider",
|
||||||
|
foreign_keys="EmbeddingModel.cloud_provider_id",
|
||||||
|
)
|
||||||
|
default_model: Mapped["EmbeddingModel"] = relationship(
|
||||||
|
"EmbeddingModel", foreign_keys=[default_model_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<EmbeddingProvider(name='{self.name}')>"
|
||||||
|
|
||||||
|
|
||||||
class DocumentSet(Base):
|
class DocumentSet(Base):
|
||||||
__tablename__ = "document_set"
|
__tablename__ = "document_set"
|
||||||
|
|
||||||
@ -1194,6 +1236,7 @@ class SlackBotConfig(Base):
|
|||||||
response_type: Mapped[SlackBotResponseType] = mapped_column(
|
response_type: Mapped[SlackBotResponseType] = mapped_column(
|
||||||
Enum(SlackBotResponseType, native_enum=False), nullable=False
|
Enum(SlackBotResponseType, native_enum=False), nullable=False
|
||||||
)
|
)
|
||||||
|
|
||||||
enable_auto_filters: Mapped[bool] = mapped_column(
|
enable_auto_filters: Mapped[bool] = mapped_column(
|
||||||
Boolean, nullable=False, default=False
|
Boolean, nullable=False, default=False
|
||||||
)
|
)
|
||||||
|
@ -50,6 +50,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
|||||||
normalize: bool,
|
normalize: bool,
|
||||||
query_prefix: str | None,
|
query_prefix: str | None,
|
||||||
passage_prefix: str | None,
|
passage_prefix: str | None,
|
||||||
|
api_key: str | None = None,
|
||||||
|
provider_type: str | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(model_name, normalize, query_prefix, passage_prefix)
|
super().__init__(model_name, normalize, query_prefix, passage_prefix)
|
||||||
self.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE # Currently not customizable
|
self.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE # Currently not customizable
|
||||||
@ -59,6 +61,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
|||||||
query_prefix=query_prefix,
|
query_prefix=query_prefix,
|
||||||
passage_prefix=passage_prefix,
|
passage_prefix=passage_prefix,
|
||||||
normalize=normalize,
|
normalize=normalize,
|
||||||
|
api_key=api_key,
|
||||||
|
provider_type=provider_type,
|
||||||
# The below are globally set, this flow always uses the indexing one
|
# The below are globally set, this flow always uses the indexing one
|
||||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||||
|
@ -97,13 +97,19 @@ class EmbeddingModelDetail(BaseModel):
|
|||||||
normalize: bool
|
normalize: bool
|
||||||
query_prefix: str | None
|
query_prefix: str | None
|
||||||
passage_prefix: str | None
|
passage_prefix: str | None
|
||||||
|
cloud_provider_id: int | None = None
|
||||||
|
cloud_provider_name: str | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_model(cls, embedding_model: "EmbeddingModel") -> "EmbeddingModelDetail":
|
def from_model(
|
||||||
|
cls,
|
||||||
|
embedding_model: "EmbeddingModel",
|
||||||
|
) -> "EmbeddingModelDetail":
|
||||||
return cls(
|
return cls(
|
||||||
model_name=embedding_model.model_name,
|
model_name=embedding_model.model_name,
|
||||||
model_dim=embedding_model.model_dim,
|
model_dim=embedding_model.model_dim,
|
||||||
normalize=embedding_model.normalize,
|
normalize=embedding_model.normalize,
|
||||||
query_prefix=embedding_model.query_prefix,
|
query_prefix=embedding_model.query_prefix,
|
||||||
passage_prefix=embedding_model.passage_prefix,
|
passage_prefix=embedding_model.passage_prefix,
|
||||||
|
cloud_provider_id=embedding_model.cloud_provider_id,
|
||||||
)
|
)
|
||||||
|
@ -67,6 +67,8 @@ from danswer.server.features.tool.api import admin_router as admin_tool_router
|
|||||||
from danswer.server.features.tool.api import router as tool_router
|
from danswer.server.features.tool.api import router as tool_router
|
||||||
from danswer.server.gpts.api import router as gpts_router
|
from danswer.server.gpts.api import router as gpts_router
|
||||||
from danswer.server.manage.administrative import router as admin_router
|
from danswer.server.manage.administrative import router as admin_router
|
||||||
|
from danswer.server.manage.embedding.api import admin_router as embedding_admin_router
|
||||||
|
from danswer.server.manage.embedding.api import basic_router as embedding_router
|
||||||
from danswer.server.manage.get_state import router as state_router
|
from danswer.server.manage.get_state import router as state_router
|
||||||
from danswer.server.manage.llm.api import admin_router as llm_admin_router
|
from danswer.server.manage.llm.api import admin_router as llm_admin_router
|
||||||
from danswer.server.manage.llm.api import basic_router as llm_router
|
from danswer.server.manage.llm.api import basic_router as llm_router
|
||||||
@ -247,6 +249,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
|||||||
time.sleep(wait_time)
|
time.sleep(wait_time)
|
||||||
|
|
||||||
logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
|
logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
|
||||||
|
if db_embedding_model.cloud_provider_id is None:
|
||||||
warm_up_encoders(
|
warm_up_encoders(
|
||||||
model_name=db_embedding_model.model_name,
|
model_name=db_embedding_model.model_name,
|
||||||
normalize=db_embedding_model.normalize,
|
normalize=db_embedding_model.normalize,
|
||||||
@ -291,6 +294,8 @@ def get_application() -> FastAPI:
|
|||||||
include_router_with_global_prefix_prepended(application, settings_admin_router)
|
include_router_with_global_prefix_prepended(application, settings_admin_router)
|
||||||
include_router_with_global_prefix_prepended(application, llm_admin_router)
|
include_router_with_global_prefix_prepended(application, llm_admin_router)
|
||||||
include_router_with_global_prefix_prepended(application, llm_router)
|
include_router_with_global_prefix_prepended(application, llm_router)
|
||||||
|
include_router_with_global_prefix_prepended(application, embedding_admin_router)
|
||||||
|
include_router_with_global_prefix_prepended(application, embedding_router)
|
||||||
include_router_with_global_prefix_prepended(
|
include_router_with_global_prefix_prepended(
|
||||||
application, token_rate_limit_settings_router
|
application, token_rate_limit_settings_router
|
||||||
)
|
)
|
||||||
|
@ -168,6 +168,7 @@ def stream_answer_objects(
|
|||||||
max_tokens=max_document_tokens,
|
max_tokens=max_document_tokens,
|
||||||
use_sections=query_req.chunks_above > 0 or query_req.chunks_below > 0,
|
use_sections=query_req.chunks_above > 0 or query_req.chunks_below > 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
search_tool = SearchTool(
|
search_tool = SearchTool(
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
user=user,
|
user=user,
|
||||||
|
@ -131,6 +131,8 @@ def doc_index_retrieval(
|
|||||||
query_prefix=db_embedding_model.query_prefix,
|
query_prefix=db_embedding_model.query_prefix,
|
||||||
passage_prefix=db_embedding_model.passage_prefix,
|
passage_prefix=db_embedding_model.passage_prefix,
|
||||||
normalize=db_embedding_model.normalize,
|
normalize=db_embedding_model.normalize,
|
||||||
|
api_key=db_embedding_model.api_key,
|
||||||
|
provider_type=db_embedding_model.provider_type,
|
||||||
# The below are globally set, this flow always uses the indexing one
|
# The below are globally set, this flow always uses the indexing one
|
||||||
server_host=MODEL_SERVER_HOST,
|
server_host=MODEL_SERVER_HOST,
|
||||||
server_port=MODEL_SERVER_PORT,
|
server_port=MODEL_SERVER_PORT,
|
||||||
|
@ -84,20 +84,24 @@ def build_model_server_url(
|
|||||||
class EmbeddingModel:
|
class EmbeddingModel:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name: str,
|
|
||||||
query_prefix: str | None,
|
|
||||||
passage_prefix: str | None,
|
|
||||||
normalize: bool,
|
|
||||||
server_host: str, # Changes depending on indexing or inference
|
server_host: str, # Changes depending on indexing or inference
|
||||||
server_port: int,
|
server_port: int,
|
||||||
|
model_name: str | None,
|
||||||
|
normalize: bool,
|
||||||
|
query_prefix: str | None,
|
||||||
|
passage_prefix: str | None,
|
||||||
|
api_key: str | None,
|
||||||
|
provider_type: str | None,
|
||||||
# The following are globals are currently not configurable
|
# The following are globals are currently not configurable
|
||||||
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model_name = model_name
|
self.api_key = api_key
|
||||||
|
self.provider_type = provider_type
|
||||||
self.max_seq_length = max_seq_length
|
self.max_seq_length = max_seq_length
|
||||||
self.query_prefix = query_prefix
|
self.query_prefix = query_prefix
|
||||||
self.passage_prefix = passage_prefix
|
self.passage_prefix = passage_prefix
|
||||||
self.normalize = normalize
|
self.normalize = normalize
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
model_server_url = build_model_server_url(server_host, server_port)
|
model_server_url = build_model_server_url(server_host, server_port)
|
||||||
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
|
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
|
||||||
@ -111,10 +115,13 @@ class EmbeddingModel:
|
|||||||
prefixed_texts = texts
|
prefixed_texts = texts
|
||||||
|
|
||||||
embed_request = EmbedRequest(
|
embed_request = EmbedRequest(
|
||||||
texts=prefixed_texts,
|
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
|
texts=prefixed_texts,
|
||||||
max_context_length=self.max_seq_length,
|
max_context_length=self.max_seq_length,
|
||||||
normalize_embeddings=self.normalize,
|
normalize_embeddings=self.normalize,
|
||||||
|
api_key=self.api_key,
|
||||||
|
provider_type=self.provider_type,
|
||||||
|
text_type=text_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
|
response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
|
||||||
@ -187,6 +194,8 @@ def warm_up_encoders(
|
|||||||
passage_prefix=None,
|
passage_prefix=None,
|
||||||
server_host=model_server_host,
|
server_host=model_server_host,
|
||||||
server_port=model_server_port,
|
server_port=model_server_port,
|
||||||
|
api_key=None,
|
||||||
|
provider_type=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# First time downloading the models it may take even longer, but just in case,
|
# First time downloading the models it may take even longer, but just in case,
|
||||||
|
93
backend/danswer/server/manage/embedding/api.py
Normal file
93
backend/danswer/server/manage/embedding/api.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
from fastapi import Depends
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from danswer.auth.users import current_admin_user
|
||||||
|
from danswer.db.embedding_model import get_current_db_embedding_provider
|
||||||
|
from danswer.db.engine import get_session
|
||||||
|
from danswer.db.llm import fetch_existing_embedding_providers
|
||||||
|
from danswer.db.llm import remove_embedding_provider
|
||||||
|
from danswer.db.llm import upsert_cloud_embedding_provider
|
||||||
|
from danswer.db.models import User
|
||||||
|
from danswer.search.enums import EmbedTextType
|
||||||
|
from danswer.search.search_nlp_models import EmbeddingModel
|
||||||
|
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||||
|
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||||
|
from danswer.server.manage.embedding.models import TestEmbeddingRequest
|
||||||
|
from danswer.utils.logger import setup_logger
|
||||||
|
from shared_configs.configs import MODEL_SERVER_HOST
|
||||||
|
from shared_configs.configs import MODEL_SERVER_PORT
|
||||||
|
|
||||||
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
|
admin_router = APIRouter(prefix="/admin/embedding")
|
||||||
|
basic_router = APIRouter(prefix="/embedding")
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.post("/test-embedding")
|
||||||
|
def test_embedding_configuration(
|
||||||
|
test_llm_request: TestEmbeddingRequest,
|
||||||
|
_: User | None = Depends(current_admin_user),
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
test_model = EmbeddingModel(
|
||||||
|
server_host=MODEL_SERVER_HOST,
|
||||||
|
server_port=MODEL_SERVER_PORT,
|
||||||
|
api_key=test_llm_request.api_key,
|
||||||
|
provider_type=test_llm_request.provider,
|
||||||
|
normalize=False,
|
||||||
|
query_prefix=None,
|
||||||
|
passage_prefix=None,
|
||||||
|
model_name=None,
|
||||||
|
)
|
||||||
|
test_model.encode(["Test String"], text_type=EmbedTextType.QUERY)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
error_msg = f"Not a valid embedding model. Exception thrown: {e}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = "An error occurred while testing your embedding model. Please check your configuration."
|
||||||
|
logger.error(f"{error_msg} Error message: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=400, detail=error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.get("/embedding-provider")
|
||||||
|
def list_embedding_providers(
|
||||||
|
_: User | None = Depends(current_admin_user),
|
||||||
|
db_session: Session = Depends(get_session),
|
||||||
|
) -> list[CloudEmbeddingProvider]:
|
||||||
|
return [
|
||||||
|
CloudEmbeddingProvider.from_request(embedding_provider_model)
|
||||||
|
for embedding_provider_model in fetch_existing_embedding_providers(db_session)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.delete("/embedding-provider/{embedding_provider_name}")
|
||||||
|
def delete_embedding_provider(
|
||||||
|
embedding_provider_name: str,
|
||||||
|
_: User | None = Depends(current_admin_user),
|
||||||
|
db_session: Session = Depends(get_session),
|
||||||
|
) -> None:
|
||||||
|
embedding_provider = get_current_db_embedding_provider(db_session=db_session)
|
||||||
|
if (
|
||||||
|
embedding_provider is not None
|
||||||
|
and embedding_provider_name == embedding_provider.name
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="You can't delete a currently active model"
|
||||||
|
)
|
||||||
|
|
||||||
|
remove_embedding_provider(db_session, embedding_provider_name)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.put("/embedding-provider")
|
||||||
|
def put_cloud_embedding_provider(
|
||||||
|
provider: CloudEmbeddingProviderCreationRequest,
|
||||||
|
_: User = Depends(current_admin_user),
|
||||||
|
db_session: Session = Depends(get_session),
|
||||||
|
) -> CloudEmbeddingProvider:
|
||||||
|
return upsert_cloud_embedding_provider(db_session, provider)
|
35
backend/danswer/server/manage/embedding/models.py
Normal file
35
backend/danswer/server/manage/embedding/models.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||||
|
|
||||||
|
|
||||||
|
class TestEmbeddingRequest(BaseModel):
|
||||||
|
provider: str
|
||||||
|
api_key: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CloudEmbeddingProvider(BaseModel):
|
||||||
|
name: str
|
||||||
|
api_key: str | None = None
|
||||||
|
default_model_id: int | None = None
|
||||||
|
id: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_request(
|
||||||
|
cls, cloud_provider_model: "CloudEmbeddingProviderModel"
|
||||||
|
) -> "CloudEmbeddingProvider":
|
||||||
|
return cls(
|
||||||
|
id=cloud_provider_model.id,
|
||||||
|
name=cloud_provider_model.name,
|
||||||
|
api_key=cloud_provider_model.api_key,
|
||||||
|
default_model_id=cloud_provider_model.default_model_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CloudEmbeddingProviderCreationRequest(BaseModel):
|
||||||
|
name: str
|
||||||
|
api_key: str | None = None
|
||||||
|
default_model_id: int | None = None
|
@ -4,6 +4,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from danswer.llm.llm_provider_options import fetch_models_for_provider
|
from danswer.llm.llm_provider_options import fetch_models_for_provider
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
|||||||
from danswer.db.connector_credential_pair import resync_cc_pair
|
from danswer.db.connector_credential_pair import resync_cc_pair
|
||||||
from danswer.db.embedding_model import create_embedding_model
|
from danswer.db.embedding_model import create_embedding_model
|
||||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||||
|
from danswer.db.embedding_model import get_model_id_from_name
|
||||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||||
from danswer.db.embedding_model import update_embedding_model_status
|
from danswer.db.embedding_model import update_embedding_model_status
|
||||||
from danswer.db.engine import get_session
|
from danswer.db.engine import get_session
|
||||||
@ -38,6 +39,19 @@ def set_new_embedding_model(
|
|||||||
"""
|
"""
|
||||||
current_model = get_current_db_embedding_model(db_session)
|
current_model = get_current_db_embedding_model(db_session)
|
||||||
|
|
||||||
|
if embed_model_details.cloud_provider_name is not None:
|
||||||
|
cloud_id = get_model_id_from_name(
|
||||||
|
db_session, embed_model_details.cloud_provider_name
|
||||||
|
)
|
||||||
|
|
||||||
|
if cloud_id is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="No ID exists for given provider name",
|
||||||
|
)
|
||||||
|
|
||||||
|
embed_model_details.cloud_provider_id = cloud_id
|
||||||
|
|
||||||
if embed_model_details.model_name == current_model.model_name:
|
if embed_model_details.model_name == current_model.model_name:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
@ -1 +1,38 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from danswer.search.enums import EmbedTextType
|
||||||
|
|
||||||
|
|
||||||
MODEL_WARM_UP_STRING = "hi " * 512
|
MODEL_WARM_UP_STRING = "hi " * 512
|
||||||
|
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
|
||||||
|
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
|
||||||
|
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
|
||||||
|
DEFAULT_VERTEX_MODEL = "text-embedding-004"
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingProvider(Enum):
|
||||||
|
OPENAI = "openai"
|
||||||
|
COHERE = "cohere"
|
||||||
|
VOYAGE = "voyage"
|
||||||
|
GOOGLE = "google"
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingModelTextType:
|
||||||
|
PROVIDER_TEXT_TYPE_MAP = {
|
||||||
|
EmbeddingProvider.COHERE: {
|
||||||
|
EmbedTextType.QUERY: "search_query",
|
||||||
|
EmbedTextType.PASSAGE: "search_document",
|
||||||
|
},
|
||||||
|
EmbeddingProvider.VOYAGE: {
|
||||||
|
EmbedTextType.QUERY: "query",
|
||||||
|
EmbedTextType.PASSAGE: "document",
|
||||||
|
},
|
||||||
|
EmbeddingProvider.GOOGLE: {
|
||||||
|
EmbedTextType.QUERY: "RETRIEVAL_QUERY",
|
||||||
|
EmbedTextType.PASSAGE: "RETRIEVAL_DOCUMENT",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
|
||||||
|
return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
|
||||||
|
@ -1,12 +1,28 @@
|
|||||||
import gc
|
import gc
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import vertexai # type: ignore
|
||||||
|
import voyageai # type: ignore
|
||||||
|
from cohere import Client as CohereClient
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
from google.oauth2 import service_account
|
||||||
from sentence_transformers import CrossEncoder # type: ignore
|
from sentence_transformers import CrossEncoder # type: ignore
|
||||||
from sentence_transformers import SentenceTransformer # type: ignore
|
from sentence_transformers import SentenceTransformer # type: ignore
|
||||||
|
from vertexai.language_models import TextEmbeddingInput # type: ignore
|
||||||
|
from vertexai.language_models import TextEmbeddingModel # type: ignore
|
||||||
|
|
||||||
|
from danswer.search.enums import EmbedTextType
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
|
from model_server.constants import DEFAULT_COHERE_MODEL
|
||||||
|
from model_server.constants import DEFAULT_OPENAI_MODEL
|
||||||
|
from model_server.constants import DEFAULT_VERTEX_MODEL
|
||||||
|
from model_server.constants import DEFAULT_VOYAGE_MODEL
|
||||||
|
from model_server.constants import EmbeddingModelTextType
|
||||||
|
from model_server.constants import EmbeddingProvider
|
||||||
from model_server.constants import MODEL_WARM_UP_STRING
|
from model_server.constants import MODEL_WARM_UP_STRING
|
||||||
from model_server.utils import simple_log_function_time
|
from model_server.utils import simple_log_function_time
|
||||||
from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE
|
from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE
|
||||||
@ -17,6 +33,7 @@ from shared_configs.model_server_models import EmbedResponse
|
|||||||
from shared_configs.model_server_models import RerankRequest
|
from shared_configs.model_server_models import RerankRequest
|
||||||
from shared_configs.model_server_models import RerankResponse
|
from shared_configs.model_server_models import RerankResponse
|
||||||
|
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
router = APIRouter(prefix="/encoder")
|
router = APIRouter(prefix="/encoder")
|
||||||
@ -25,6 +42,117 @@ _GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
|
|||||||
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
|
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CloudEmbedding:
|
||||||
|
def __init__(self, api_key: str, provider: str, model: str | None = None):
|
||||||
|
self.api_key = api_key
|
||||||
|
|
||||||
|
# Only for Google as is needed on client setup
|
||||||
|
self.model = model
|
||||||
|
try:
|
||||||
|
self.provider = EmbeddingProvider(provider.lower())
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f"Unsupported provider: {provider}")
|
||||||
|
self.client = self._initialize_client()
|
||||||
|
|
||||||
|
def _initialize_client(self) -> Any:
|
||||||
|
if self.provider == EmbeddingProvider.OPENAI:
|
||||||
|
return openai.OpenAI(api_key=self.api_key)
|
||||||
|
elif self.provider == EmbeddingProvider.COHERE:
|
||||||
|
return CohereClient(api_key=self.api_key)
|
||||||
|
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||||
|
return voyageai.Client(api_key=self.api_key)
|
||||||
|
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||||
|
credentials = service_account.Credentials.from_service_account_info(
|
||||||
|
json.loads(self.api_key)
|
||||||
|
)
|
||||||
|
project_id = json.loads(self.api_key)["project_id"]
|
||||||
|
vertexai.init(project=project_id, credentials=credentials)
|
||||||
|
return TextEmbeddingModel.from_pretrained(
|
||||||
|
self.model or DEFAULT_VERTEX_MODEL
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self, texts: list[str], model_name: str | None, text_type: EmbedTextType
|
||||||
|
) -> list[list[float]]:
|
||||||
|
return [
|
||||||
|
self.embed(text=text, text_type=text_type, model=model_name)
|
||||||
|
for text in texts
|
||||||
|
]
|
||||||
|
|
||||||
|
def embed(
|
||||||
|
self, *, text: str, text_type: EmbedTextType, model: str | None = None
|
||||||
|
) -> list[float]:
|
||||||
|
logger.debug(f"Embedding text with provider: {self.provider}")
|
||||||
|
if self.provider == EmbeddingProvider.OPENAI:
|
||||||
|
return self._embed_openai(text, model)
|
||||||
|
|
||||||
|
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||||
|
|
||||||
|
if self.provider == EmbeddingProvider.COHERE:
|
||||||
|
return self._embed_cohere(text, model, embedding_type)
|
||||||
|
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||||
|
return self._embed_voyage(text, model, embedding_type)
|
||||||
|
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||||
|
return self._embed_vertex(text, model, embedding_type)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||||
|
|
||||||
|
def _embed_openai(self, text: str, model: str | None) -> list[float]:
|
||||||
|
if model is None:
|
||||||
|
model = DEFAULT_OPENAI_MODEL
|
||||||
|
|
||||||
|
response = self.client.embeddings.create(input=text, model=model)
|
||||||
|
return response.data[0].embedding
|
||||||
|
|
||||||
|
def _embed_cohere(
|
||||||
|
self, text: str, model: str | None, embedding_type: str
|
||||||
|
) -> list[float]:
|
||||||
|
if model is None:
|
||||||
|
model = DEFAULT_COHERE_MODEL
|
||||||
|
|
||||||
|
response = self.client.embed(
|
||||||
|
texts=[text],
|
||||||
|
model=model,
|
||||||
|
input_type=embedding_type,
|
||||||
|
)
|
||||||
|
return response.embeddings[0]
|
||||||
|
|
||||||
|
def _embed_voyage(
|
||||||
|
self, text: str, model: str | None, embedding_type: str
|
||||||
|
) -> list[float]:
|
||||||
|
if model is None:
|
||||||
|
model = DEFAULT_VOYAGE_MODEL
|
||||||
|
|
||||||
|
response = self.client.embed(text, model=model, input_type=embedding_type)
|
||||||
|
return response.embeddings[0]
|
||||||
|
|
||||||
|
def _embed_vertex(
|
||||||
|
self, text: str, model: str | None, embedding_type: str
|
||||||
|
) -> list[float]:
|
||||||
|
if model is None:
|
||||||
|
model = DEFAULT_VERTEX_MODEL
|
||||||
|
|
||||||
|
embedding = self.client.get_embeddings(
|
||||||
|
[
|
||||||
|
TextEmbeddingInput(
|
||||||
|
text,
|
||||||
|
embedding_type,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return embedding[0].values
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create(
|
||||||
|
api_key: str, provider: str, model: str | None = None
|
||||||
|
) -> "CloudEmbedding":
|
||||||
|
logger.debug(f"Creating Embedding instance for provider: {provider}")
|
||||||
|
return CloudEmbedding(api_key, provider, model)
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_model(
|
def get_embedding_model(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
max_context_length: int,
|
max_context_length: int,
|
||||||
@ -78,18 +206,35 @@ def warm_up_cross_encoders() -> None:
|
|||||||
@simple_log_function_time()
|
@simple_log_function_time()
|
||||||
def embed_text(
|
def embed_text(
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
model_name: str,
|
text_type: EmbedTextType,
|
||||||
|
model_name: str | None,
|
||||||
max_context_length: int,
|
max_context_length: int,
|
||||||
normalize_embeddings: bool,
|
normalize_embeddings: bool,
|
||||||
|
api_key: str | None,
|
||||||
|
provider_type: str | None,
|
||||||
) -> list[list[float]]:
|
) -> list[list[float]]:
|
||||||
model = get_embedding_model(
|
if provider_type is not None:
|
||||||
|
if api_key is None:
|
||||||
|
raise RuntimeError("API key not provided for cloud model")
|
||||||
|
|
||||||
|
cloud_model = CloudEmbedding(
|
||||||
|
api_key=api_key, provider=provider_type, model=model_name
|
||||||
|
)
|
||||||
|
embeddings = cloud_model.encode(texts, model_name, text_type)
|
||||||
|
|
||||||
|
elif model_name is not None:
|
||||||
|
hosted_model = get_embedding_model(
|
||||||
model_name=model_name, max_context_length=max_context_length
|
model_name=model_name, max_context_length=max_context_length
|
||||||
)
|
)
|
||||||
embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)
|
embeddings = hosted_model.encode(
|
||||||
|
texts, normalize_embeddings=normalize_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
if embeddings is None:
|
||||||
|
raise RuntimeError("Embeddings were not created")
|
||||||
|
|
||||||
if not isinstance(embeddings, list):
|
if not isinstance(embeddings, list):
|
||||||
embeddings = embeddings.tolist()
|
embeddings = embeddings.tolist()
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
@ -113,6 +258,9 @@ async def process_embed_request(
|
|||||||
model_name=embed_request.model_name,
|
model_name=embed_request.model_name,
|
||||||
max_context_length=embed_request.max_context_length,
|
max_context_length=embed_request.max_context_length,
|
||||||
normalize_embeddings=embed_request.normalize_embeddings,
|
normalize_embeddings=embed_request.normalize_embeddings,
|
||||||
|
api_key=embed_request.api_key,
|
||||||
|
provider_type=embed_request.provider_type,
|
||||||
|
text_type=embed_request.text_type,
|
||||||
)
|
)
|
||||||
return EmbedResponse(embeddings=embeddings)
|
return EmbedResponse(embeddings=embeddings)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -7,3 +7,7 @@ tensorflow==2.15.0
|
|||||||
torch==2.0.1
|
torch==2.0.1
|
||||||
transformers==4.39.2
|
transformers==4.39.2
|
||||||
uvicorn==0.21.1
|
uvicorn==0.21.1
|
||||||
|
voyageai==0.2.3
|
||||||
|
openai==1.14.3
|
||||||
|
cohere==5.5.8
|
||||||
|
google-cloud-aiplatform==1.58.0
|
@ -1,12 +1,19 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from danswer.search.enums import EmbedTextType
|
||||||
|
|
||||||
|
|
||||||
class EmbedRequest(BaseModel):
|
class EmbedRequest(BaseModel):
|
||||||
# This already includes any prefixes, the text is just passed directly to the model
|
# This already includes any prefixes, the text is just passed directly to the model
|
||||||
texts: list[str]
|
texts: list[str]
|
||||||
model_name: str
|
|
||||||
|
# Can be none for cloud embedding model requests, error handling logic exists for other cases
|
||||||
|
model_name: str | None
|
||||||
max_context_length: int
|
max_context_length: int
|
||||||
normalize_embeddings: bool
|
normalize_embeddings: bool
|
||||||
|
api_key: str | None
|
||||||
|
provider_type: str | None
|
||||||
|
text_type: EmbedTextType
|
||||||
|
|
||||||
|
|
||||||
class EmbedResponse(BaseModel):
|
class EmbedResponse(BaseModel):
|
||||||
|
30
web/public/Cohere.svg
Normal file
30
web/public/Cohere.svg
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
<svg version="1.1" id="Layer_1" xmlns:x="ns_extend;" xmlns:i="ns_ai;" xmlns:graph="ns_graphs;" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px" viewBox="0 0 75 75" style="enable-background:new 0 0 75 75;" xml:space="preserve">
|
||||||
|
<style type="text/css">
|
||||||
|
.st0{fill-rule:evenodd;clip-rule:evenodd;fill:#39594D;}
|
||||||
|
.st1{fill-rule:evenodd;clip-rule:evenodd;fill:#D18EE2;}
|
||||||
|
.st2{fill:#FF7759;}
|
||||||
|
</style>
|
||||||
|
<metadata>
|
||||||
|
<sfw xmlns="ns_sfw;">
|
||||||
|
<slices>
|
||||||
|
</slices>
|
||||||
|
<sliceSourceBounds bottomLeftOrigin="true" height="75" width="75" x="-347.6" y="0.5">
|
||||||
|
</sliceSourceBounds>
|
||||||
|
</sfw>
|
||||||
|
</metadata>
|
||||||
|
<g>
|
||||||
|
<g>
|
||||||
|
<g>
|
||||||
|
<path class="st0" d="M24.3,44.7c2,0,6-0.1,11.6-2.4c6.5-2.7,19.3-7.5,28.6-12.5c6.5-3.5,9.3-8.1,9.3-14.3C73.8,7,66.9,0,58.3,0
|
||||||
|
h-36C10,0,0,10,0,22.3S9.4,44.7,24.3,44.7z">
|
||||||
|
</path>
|
||||||
|
<path class="st1" d="M30.4,60c0-6,3.6-11.5,9.2-13.8l11.3-4.7C62.4,36.8,75,45.2,75,57.6C75,67.2,67.2,75,57.6,75l-12.3,0
|
||||||
|
C37.1,75,30.4,68.3,30.4,60z">
|
||||||
|
</path>
|
||||||
|
<path class="st2" d="M12.9,47.6L12.9,47.6C5.8,47.6,0,53.4,0,60.5v1.7C0,69.2,5.8,75,12.9,75h0c7.1,0,12.9-5.8,12.9-12.9v-1.7
|
||||||
|
C25.7,53.4,20,47.6,12.9,47.6z">
|
||||||
|
</path>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
After Width: | Height: | Size: 1.2 KiB |
BIN
web/public/Google.webp
Normal file
BIN
web/public/Google.webp
Normal file
Binary file not shown.
After Width: | Height: | Size: 6.4 KiB |
BIN
web/public/Voyage.png
Normal file
BIN
web/public/Voyage.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 14 KiB |
163
web/src/app/admin/models/embedding/CloudEmbeddingPage.tsx
Normal file
163
web/src/app/admin/models/embedding/CloudEmbeddingPage.tsx
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { Text, Title } from "@tremor/react";
|
||||||
|
|
||||||
|
import {
|
||||||
|
CloudEmbeddingProvider,
|
||||||
|
CloudEmbeddingModel,
|
||||||
|
AVAILABLE_CLOUD_PROVIDERS,
|
||||||
|
CloudEmbeddingProviderFull,
|
||||||
|
EmbeddingModelDescriptor,
|
||||||
|
} from "./components/types";
|
||||||
|
import { EmbeddingDetails } from "./page";
|
||||||
|
import { FiInfo } from "react-icons/fi";
|
||||||
|
import { HoverPopup } from "@/components/HoverPopup";
|
||||||
|
import { Dispatch, SetStateAction } from "react";
|
||||||
|
|
||||||
|
export default function CloudEmbeddingPage({
|
||||||
|
currentModel,
|
||||||
|
embeddingProviderDetails,
|
||||||
|
newEnabledProviders,
|
||||||
|
newUnenabledProviders,
|
||||||
|
setShowTentativeProvider,
|
||||||
|
setChangeCredentialsProvider,
|
||||||
|
setAlreadySelectedModel,
|
||||||
|
setShowTentativeModel,
|
||||||
|
setShowModelInQueue,
|
||||||
|
}: {
|
||||||
|
setShowModelInQueue: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
|
||||||
|
setShowTentativeModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
|
||||||
|
currentModel: EmbeddingModelDescriptor | CloudEmbeddingModel;
|
||||||
|
setAlreadySelectedModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
|
||||||
|
newUnenabledProviders: string[];
|
||||||
|
embeddingProviderDetails?: EmbeddingDetails[];
|
||||||
|
newEnabledProviders: string[];
|
||||||
|
selectedModel: CloudEmbeddingProvider;
|
||||||
|
|
||||||
|
// create modal functions
|
||||||
|
|
||||||
|
setShowTentativeProvider: React.Dispatch<
|
||||||
|
React.SetStateAction<CloudEmbeddingProvider | null>
|
||||||
|
>;
|
||||||
|
setChangeCredentialsProvider: React.Dispatch<
|
||||||
|
React.SetStateAction<CloudEmbeddingProvider | null>
|
||||||
|
>;
|
||||||
|
}) {
|
||||||
|
function hasNameInArray(
|
||||||
|
arr: Array<{ name: string }>,
|
||||||
|
searchName: string
|
||||||
|
): boolean {
|
||||||
|
return arr.some(
|
||||||
|
(item) => item.name.toLowerCase() === searchName.toLowerCase()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let providers: CloudEmbeddingProviderFull[] = [];
|
||||||
|
AVAILABLE_CLOUD_PROVIDERS.forEach((model, ind) => {
|
||||||
|
let temporary_model: CloudEmbeddingProviderFull = {
|
||||||
|
...model,
|
||||||
|
configured:
|
||||||
|
!newUnenabledProviders.includes(model.name) &&
|
||||||
|
(newEnabledProviders.includes(model.name) ||
|
||||||
|
(embeddingProviderDetails &&
|
||||||
|
hasNameInArray(embeddingProviderDetails, model.name))!),
|
||||||
|
};
|
||||||
|
providers.push(temporary_model);
|
||||||
|
});
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<Title className="mt-8">
|
||||||
|
Here are some cloud-based models to choose from.
|
||||||
|
</Title>
|
||||||
|
<Text className="mb-4">
|
||||||
|
They require API keys and run in the clouds of the respective providers.
|
||||||
|
</Text>
|
||||||
|
|
||||||
|
<div className="gap-4 mt-2 pb-10 flex content-start flex-wrap">
|
||||||
|
{providers.map((provider, ind) => (
|
||||||
|
<div
|
||||||
|
key={ind}
|
||||||
|
className="p-4 border border-border rounded-lg shadow-md bg-hover-light w-96 flex flex-col"
|
||||||
|
>
|
||||||
|
<div className="font-bold text-neutral-900 text-lg items-center py-1 gap-x-2 flex">
|
||||||
|
{provider.icon({ size: 40 })}
|
||||||
|
<p className="my-auto">{provider.name}</p>
|
||||||
|
<button
|
||||||
|
onClick={() => {
|
||||||
|
setShowTentativeProvider(provider);
|
||||||
|
}}
|
||||||
|
className="cursor-pointer ml-auto"
|
||||||
|
>
|
||||||
|
<a className="my-auto hover:underline cursor-pointer">
|
||||||
|
<HoverPopup
|
||||||
|
mainContent={
|
||||||
|
<FiInfo className="cusror-pointer" size={20} />
|
||||||
|
}
|
||||||
|
popupContent={
|
||||||
|
<div className="text-sm text-neutral-800 w-52 flex">
|
||||||
|
<div className="flex mx-auto">
|
||||||
|
<div className="my-auto">{provider.description}</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
direction="left-top"
|
||||||
|
style="dark"
|
||||||
|
/>
|
||||||
|
</a>
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
{provider.embedding_models.map((model, index) => {
|
||||||
|
const enabled = model.model_name == currentModel.model_name;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
key={index}
|
||||||
|
className={`p-3 my-2 border-2 border-neutral-300 border-opacity-40 rounded-md rounded cursor-pointer
|
||||||
|
${!provider.configured ? "opacity-80 hover:opacity-100" : enabled ? "bg-background-stronger" : "hover:bg-background-strong"}`}
|
||||||
|
onClick={() => {
|
||||||
|
if (enabled) {
|
||||||
|
setAlreadySelectedModel(model);
|
||||||
|
} else if (provider.configured) {
|
||||||
|
setShowTentativeModel(model);
|
||||||
|
} else {
|
||||||
|
setShowModelInQueue(model);
|
||||||
|
setShowTentativeProvider(provider);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<div className="flex justify-between">
|
||||||
|
<div className="font-medium text-sm">
|
||||||
|
{model.model_name}
|
||||||
|
</div>
|
||||||
|
<p className="text-sm flex-none">
|
||||||
|
${model.pricePerMillion}/M tokens
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div className="text-sm text-gray-600">
|
||||||
|
{model.description}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
onClick={() => {
|
||||||
|
if (!provider.configured) {
|
||||||
|
setShowTentativeProvider(provider);
|
||||||
|
} else {
|
||||||
|
setChangeCredentialsProvider(provider);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
className="hover:underline mb-1 text-sm mr-auto cursor-pointer"
|
||||||
|
>
|
||||||
|
{provider.configured && "Modify credentials"}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
@ -1,74 +0,0 @@
|
|||||||
import { Modal } from "@/components/Modal";
|
|
||||||
import { Button, Text, Callout } from "@tremor/react";
|
|
||||||
import { EmbeddingModelDescriptor } from "./embeddingModels";
|
|
||||||
|
|
||||||
export function ModelSelectionConfirmaion({
|
|
||||||
selectedModel,
|
|
||||||
isCustom,
|
|
||||||
onConfirm,
|
|
||||||
}: {
|
|
||||||
selectedModel: EmbeddingModelDescriptor;
|
|
||||||
isCustom: boolean;
|
|
||||||
onConfirm: () => void;
|
|
||||||
}) {
|
|
||||||
return (
|
|
||||||
<div className="mb-4">
|
|
||||||
<Text className="text-lg mb-4">
|
|
||||||
You have selected: <b>{selectedModel.model_name}</b>. Are you sure you
|
|
||||||
want to update to this new embedding model?
|
|
||||||
</Text>
|
|
||||||
<Text className="text-lg mb-2">
|
|
||||||
We will re-index all your documents in the background so you will be
|
|
||||||
able to continue to use Danswer as normal with the old model in the
|
|
||||||
meantime. Depending on how many documents you have indexed, this may
|
|
||||||
take a while.
|
|
||||||
</Text>
|
|
||||||
<Text className="text-lg mb-2">
|
|
||||||
<i>NOTE:</i> this re-indexing process will consume more resources than
|
|
||||||
normal. If you are self-hosting, we recommend that you allocate at least
|
|
||||||
16GB of RAM to Danswer during this process.
|
|
||||||
</Text>
|
|
||||||
|
|
||||||
{isCustom && (
|
|
||||||
<Callout title="IMPORTANT" color="yellow" className="mt-4">
|
|
||||||
We've detected that this is a custom-specified embedding model.
|
|
||||||
Since we have to download the model files before verifying the
|
|
||||||
configuration's correctness, we won't be able to let you
|
|
||||||
know if the configuration is valid until <b>after</b> we start
|
|
||||||
re-indexing your documents. If there is an issue, it will show up on
|
|
||||||
this page as an indexing error on this page after clicking Confirm.
|
|
||||||
</Callout>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<div className="flex mt-8">
|
|
||||||
<Button className="mx-auto" color="green" onClick={onConfirm}>
|
|
||||||
Confirm
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
export function ModelSelectionConfirmaionModal({
|
|
||||||
selectedModel,
|
|
||||||
isCustom,
|
|
||||||
onConfirm,
|
|
||||||
onCancel,
|
|
||||||
}: {
|
|
||||||
selectedModel: EmbeddingModelDescriptor;
|
|
||||||
isCustom: boolean;
|
|
||||||
onConfirm: () => void;
|
|
||||||
onCancel: () => void;
|
|
||||||
}) {
|
|
||||||
return (
|
|
||||||
<Modal title="Update Embedding Model" onOutsideClick={onCancel}>
|
|
||||||
<div>
|
|
||||||
<ModelSelectionConfirmaion
|
|
||||||
selectedModel={selectedModel}
|
|
||||||
isCustom={isCustom}
|
|
||||||
onConfirm={onConfirm}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</Modal>
|
|
||||||
);
|
|
||||||
}
|
|
55
web/src/app/admin/models/embedding/OpenEmbeddingPage.tsx
Normal file
55
web/src/app/admin/models/embedding/OpenEmbeddingPage.tsx
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
"use client";
|
||||||
|
import { Card, Text, Title } from "@tremor/react";
|
||||||
|
import { ModelSelector } from "./components/ModelSelector";
|
||||||
|
import {
|
||||||
|
AVAILABLE_MODELS,
|
||||||
|
EmbeddingModelDescriptor,
|
||||||
|
HostedEmbeddingModel,
|
||||||
|
} from "./components/types";
|
||||||
|
import { CustomModelForm } from "./components/CustomModelForm";
|
||||||
|
|
||||||
|
export default function OpenEmbeddingPage({
|
||||||
|
onSelectOpenSource,
|
||||||
|
currentModelName,
|
||||||
|
}: {
|
||||||
|
currentModelName: string;
|
||||||
|
onSelectOpenSource: (model: HostedEmbeddingModel) => Promise<void>;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<ModelSelector
|
||||||
|
modelOptions={AVAILABLE_MODELS.filter(
|
||||||
|
(modelOption) => modelOption.model_name !== currentModelName
|
||||||
|
)}
|
||||||
|
setSelectedModel={onSelectOpenSource}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Text className="mt-6">
|
||||||
|
Alternatively, (if you know what you're doing) you can specify a{" "}
|
||||||
|
<a target="_blank" href="https://www.sbert.net/" className="text-link">
|
||||||
|
SentenceTransformers
|
||||||
|
</a>
|
||||||
|
-compatible model of your choice below. The rough list of supported
|
||||||
|
models can be found{" "}
|
||||||
|
<a
|
||||||
|
target="_blank"
|
||||||
|
href="https://huggingface.co/models?library=sentence-transformers&sort=trending"
|
||||||
|
className="text-link"
|
||||||
|
>
|
||||||
|
here
|
||||||
|
</a>
|
||||||
|
.
|
||||||
|
<br />
|
||||||
|
<b>NOTE:</b> not all models listed will work with Danswer, since some
|
||||||
|
have unique interfaces or special requirements. If in doubt, reach out
|
||||||
|
to the Danswer team.
|
||||||
|
</Text>
|
||||||
|
|
||||||
|
<div className="w-full flex">
|
||||||
|
<Card className="mt-4 2xl:w-4/6 mx-auto">
|
||||||
|
<CustomModelForm onSubmit={onSelectOpenSource} />
|
||||||
|
</Card>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
@ -2,16 +2,15 @@ import {
|
|||||||
BooleanFormField,
|
BooleanFormField,
|
||||||
TextFormField,
|
TextFormField,
|
||||||
} from "@/components/admin/connectors/Field";
|
} from "@/components/admin/connectors/Field";
|
||||||
import { Button, Divider, Text } from "@tremor/react";
|
import { Button } from "@tremor/react";
|
||||||
import { Form, Formik } from "formik";
|
import { Form, Formik } from "formik";
|
||||||
|
|
||||||
import * as Yup from "yup";
|
import * as Yup from "yup";
|
||||||
import { EmbeddingModelDescriptor } from "./embeddingModels";
|
import { EmbeddingModelDescriptor, HostedEmbeddingModel } from "./types";
|
||||||
|
|
||||||
export function CustomModelForm({
|
export function CustomModelForm({
|
||||||
onSubmit,
|
onSubmit,
|
||||||
}: {
|
}: {
|
||||||
onSubmit: (model: EmbeddingModelDescriptor) => void;
|
onSubmit: (model: HostedEmbeddingModel) => void;
|
||||||
}) {
|
}) {
|
||||||
return (
|
return (
|
||||||
<div>
|
<div>
|
||||||
@ -21,6 +20,7 @@ export function CustomModelForm({
|
|||||||
model_dim: "",
|
model_dim: "",
|
||||||
query_prefix: "",
|
query_prefix: "",
|
||||||
passage_prefix: "",
|
passage_prefix: "",
|
||||||
|
description: "",
|
||||||
normalize: true,
|
normalize: true,
|
||||||
}}
|
}}
|
||||||
validationSchema={Yup.object().shape({
|
validationSchema={Yup.object().shape({
|
||||||
@ -62,6 +62,13 @@ export function CustomModelForm({
|
|||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
<TextFormField
|
||||||
|
name="description"
|
||||||
|
label="Description:"
|
||||||
|
subtext="Description of your model"
|
||||||
|
placeholder=""
|
||||||
|
autoCompleteDisabled={true}
|
||||||
|
/>
|
||||||
|
|
||||||
<TextFormField
|
<TextFormField
|
||||||
name="query_prefix"
|
name="query_prefix"
|
||||||
@ -77,7 +84,6 @@ export function CustomModelForm({
|
|||||||
placeholder="E.g. 'query: '"
|
placeholder="E.g. 'query: '"
|
||||||
autoCompleteDisabled={true}
|
autoCompleteDisabled={true}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<TextFormField
|
<TextFormField
|
||||||
name="passage_prefix"
|
name="passage_prefix"
|
||||||
label="[Optional] Passage Prefix:"
|
label="[Optional] Passage Prefix:"
|
@ -1,18 +1,29 @@
|
|||||||
import { DefaultDropdown, StringOrNumberOption } from "@/components/Dropdown";
|
import { EmbeddingModelDescriptor, HostedEmbeddingModel } from "./types";
|
||||||
import { Title, Text, Divider, Card } from "@tremor/react";
|
|
||||||
import {
|
|
||||||
EmbeddingModelDescriptor,
|
|
||||||
FullEmbeddingModelDescriptor,
|
|
||||||
} from "./embeddingModels";
|
|
||||||
import { FiStar } from "react-icons/fi";
|
import { FiStar } from "react-icons/fi";
|
||||||
import { CustomModelForm } from "./CustomModelForm";
|
|
||||||
|
export function ModelPreview({ model }: { model: EmbeddingModelDescriptor }) {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={
|
||||||
|
"p-2 border border-border rounded shadow-md bg-hover-light w-96 flex flex-col"
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<div className="font-bold text-lg flex">{model.model_name}</div>
|
||||||
|
<div className="text-sm mt-1 mx-1">
|
||||||
|
{model.description
|
||||||
|
? model.description
|
||||||
|
: "Custom model—no description is available."}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
export function ModelOption({
|
export function ModelOption({
|
||||||
model,
|
model,
|
||||||
onSelect,
|
onSelect,
|
||||||
}: {
|
}: {
|
||||||
model: FullEmbeddingModelDescriptor;
|
model: HostedEmbeddingModel;
|
||||||
onSelect?: (model: EmbeddingModelDescriptor) => void;
|
onSelect?: (model: HostedEmbeddingModel) => void;
|
||||||
}) {
|
}) {
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
@ -68,8 +79,8 @@ export function ModelSelector({
|
|||||||
modelOptions,
|
modelOptions,
|
||||||
setSelectedModel,
|
setSelectedModel,
|
||||||
}: {
|
}: {
|
||||||
modelOptions: FullEmbeddingModelDescriptor[];
|
modelOptions: HostedEmbeddingModel[];
|
||||||
setSelectedModel: (model: EmbeddingModelDescriptor) => void;
|
setSelectedModel: (model: HostedEmbeddingModel) => void;
|
||||||
}) {
|
}) {
|
||||||
return (
|
return (
|
||||||
<div>
|
<div>
|
286
web/src/app/admin/models/embedding/components/types.ts
Normal file
286
web/src/app/admin/models/embedding/components/types.ts
Normal file
@ -0,0 +1,286 @@
|
|||||||
|
import {
|
||||||
|
CohereIcon,
|
||||||
|
GoogleIcon,
|
||||||
|
IconProps,
|
||||||
|
OpenAIIcon,
|
||||||
|
VoyageIcon,
|
||||||
|
} from "@/components/icons/icons";
|
||||||
|
|
||||||
|
// Cloud Provider (not needed for hosted ones)
|
||||||
|
|
||||||
|
export interface CloudEmbeddingProvider {
|
||||||
|
id: number;
|
||||||
|
name: string;
|
||||||
|
api_key?: string;
|
||||||
|
custom_config?: Record<string, string>;
|
||||||
|
docsLink?: string;
|
||||||
|
|
||||||
|
// Frontend-specific properties
|
||||||
|
website: string;
|
||||||
|
icon: ({ size, className }: IconProps) => JSX.Element;
|
||||||
|
description: string;
|
||||||
|
apiLink: string;
|
||||||
|
costslink?: string;
|
||||||
|
|
||||||
|
// Relationships
|
||||||
|
embedding_models: CloudEmbeddingModel[];
|
||||||
|
default_model?: CloudEmbeddingModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Embedding Models
|
||||||
|
export interface EmbeddingModelDescriptor {
|
||||||
|
model_name: string;
|
||||||
|
model_dim: number;
|
||||||
|
normalize: boolean;
|
||||||
|
query_prefix: string;
|
||||||
|
passage_prefix: string;
|
||||||
|
cloud_provider_name?: string | null;
|
||||||
|
description: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface CloudEmbeddingModel extends EmbeddingModelDescriptor {
|
||||||
|
cloud_provider_name: string | null;
|
||||||
|
pricePerMillion: number;
|
||||||
|
enabled?: boolean;
|
||||||
|
mtebScore: number;
|
||||||
|
maxContext: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface HostedEmbeddingModel extends EmbeddingModelDescriptor {
|
||||||
|
link?: string;
|
||||||
|
model_dim: number;
|
||||||
|
normalize: boolean;
|
||||||
|
query_prefix: string;
|
||||||
|
passage_prefix: string;
|
||||||
|
isDefault?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Responses
|
||||||
|
export interface FullEmbeddingModelResponse {
|
||||||
|
current_model_name: string;
|
||||||
|
secondary_model_name: string | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface CloudEmbeddingProviderFull extends CloudEmbeddingProvider {
|
||||||
|
configured: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
||||||
|
{
|
||||||
|
model_name: "intfloat/e5-base-v2",
|
||||||
|
model_dim: 768,
|
||||||
|
normalize: true,
|
||||||
|
description:
|
||||||
|
"The recommended default for most situations. If you aren't sure which model to use, this is probably the one.",
|
||||||
|
isDefault: true,
|
||||||
|
link: "https://huggingface.co/intfloat/e5-base-v2",
|
||||||
|
query_prefix: "query: ",
|
||||||
|
passage_prefix: "passage: ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
model_name: "intfloat/e5-small-v2",
|
||||||
|
model_dim: 384,
|
||||||
|
normalize: true,
|
||||||
|
description:
|
||||||
|
"A smaller / faster version of the default model. If you're running Danswer on a resource constrained system, then this is a good choice.",
|
||||||
|
link: "https://huggingface.co/intfloat/e5-small-v2",
|
||||||
|
query_prefix: "query: ",
|
||||||
|
passage_prefix: "passage: ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
model_name: "intfloat/multilingual-e5-base",
|
||||||
|
model_dim: 768,
|
||||||
|
normalize: true,
|
||||||
|
description:
|
||||||
|
"If you have many documents in other languages besides English, this is the one to go for.",
|
||||||
|
link: "https://huggingface.co/intfloat/multilingual-e5-base",
|
||||||
|
query_prefix: "query: ",
|
||||||
|
passage_prefix: "passage: ",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
model_name: "intfloat/multilingual-e5-small",
|
||||||
|
model_dim: 384,
|
||||||
|
normalize: true,
|
||||||
|
description:
|
||||||
|
"If you have many documents in other languages besides English, and you're running on a resource constrained system, then this is the one to go for.",
|
||||||
|
link: "https://huggingface.co/intfloat/multilingual-e5-base",
|
||||||
|
query_prefix: "query: ",
|
||||||
|
passage_prefix: "passage: ",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||||
|
{
|
||||||
|
id: 0,
|
||||||
|
name: "OpenAI",
|
||||||
|
website: "https://openai.com",
|
||||||
|
icon: OpenAIIcon,
|
||||||
|
description: "AI industry leader known for ChatGPT and DALL-E",
|
||||||
|
apiLink: "https://platform.openai.com/api-keys",
|
||||||
|
docsLink:
|
||||||
|
"https://docs.danswer.dev/guides/embedding_providers#openai-models",
|
||||||
|
costslink: "https://openai.com/pricing",
|
||||||
|
embedding_models: [
|
||||||
|
{
|
||||||
|
model_name: "text-embedding-3-large",
|
||||||
|
cloud_provider_name: "OpenAI",
|
||||||
|
description:
|
||||||
|
"OpenAI's large embedding model. Best performance, but more expensive.",
|
||||||
|
pricePerMillion: 0.13,
|
||||||
|
model_dim: 3072,
|
||||||
|
normalize: false,
|
||||||
|
query_prefix: "",
|
||||||
|
passage_prefix: "",
|
||||||
|
mtebScore: 64.6,
|
||||||
|
maxContext: 8191,
|
||||||
|
enabled: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
model_name: "text-embedding-3-small",
|
||||||
|
cloud_provider_name: "OpenAI",
|
||||||
|
model_dim: 1536,
|
||||||
|
normalize: false,
|
||||||
|
query_prefix: "",
|
||||||
|
passage_prefix: "",
|
||||||
|
description:
|
||||||
|
"OpenAI's newer, more efficient embedding model. Good balance of performance and cost.",
|
||||||
|
pricePerMillion: 0.02,
|
||||||
|
enabled: false,
|
||||||
|
mtebScore: 62.3,
|
||||||
|
maxContext: 8191,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 1,
|
||||||
|
name: "Cohere",
|
||||||
|
website: "https://cohere.ai",
|
||||||
|
icon: CohereIcon,
|
||||||
|
docsLink:
|
||||||
|
"https://docs.danswer.dev/guides/embedding_providers#cohere-models",
|
||||||
|
description:
|
||||||
|
"AI company specializing in NLP models for various text-based tasks",
|
||||||
|
apiLink: "https://dashboard.cohere.ai/api-keys",
|
||||||
|
costslink: "https://cohere.com/pricing",
|
||||||
|
embedding_models: [
|
||||||
|
{
|
||||||
|
model_name: "embed-english-v3.0",
|
||||||
|
cloud_provider_name: "Cohere",
|
||||||
|
description:
|
||||||
|
"Cohere's English embedding model. Good performance for English-language tasks.",
|
||||||
|
pricePerMillion: 0.1,
|
||||||
|
mtebScore: 64.5,
|
||||||
|
maxContext: 512,
|
||||||
|
enabled: false,
|
||||||
|
model_dim: 1024,
|
||||||
|
normalize: false,
|
||||||
|
query_prefix: "",
|
||||||
|
passage_prefix: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
model_name: "embed-english-light-v3.0",
|
||||||
|
cloud_provider_name: "Cohere",
|
||||||
|
description:
|
||||||
|
"Cohere's lightweight English embedding model. Faster and more efficient for simpler tasks.",
|
||||||
|
pricePerMillion: 0.1,
|
||||||
|
mtebScore: 62,
|
||||||
|
maxContext: 512,
|
||||||
|
enabled: false,
|
||||||
|
model_dim: 384,
|
||||||
|
normalize: false,
|
||||||
|
query_prefix: "",
|
||||||
|
passage_prefix: "",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
id: 2,
|
||||||
|
name: "Google",
|
||||||
|
website: "https://ai.google",
|
||||||
|
icon: GoogleIcon,
|
||||||
|
docsLink:
|
||||||
|
"https://docs.danswer.dev/guides/embedding_providers#vertex-ai-google-model",
|
||||||
|
description:
|
||||||
|
"Offers a wide range of AI services including language and vision models",
|
||||||
|
apiLink: "https://console.cloud.google.com/apis/credentials",
|
||||||
|
costslink: "https://cloud.google.com/vertex-ai/pricing",
|
||||||
|
embedding_models: [
|
||||||
|
{
|
||||||
|
cloud_provider_name: "Google",
|
||||||
|
model_name: "text-embedding-004",
|
||||||
|
description: "Google's most recent text embedding model.",
|
||||||
|
pricePerMillion: 0.025,
|
||||||
|
mtebScore: 66.31,
|
||||||
|
maxContext: 2048,
|
||||||
|
enabled: false,
|
||||||
|
model_dim: 768,
|
||||||
|
normalize: false,
|
||||||
|
query_prefix: "",
|
||||||
|
passage_prefix: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
cloud_provider_name: "Google",
|
||||||
|
model_name: "textembedding-gecko@003",
|
||||||
|
description: "Google's Gecko embedding model. Powerful and efficient.",
|
||||||
|
pricePerMillion: 0.025,
|
||||||
|
mtebScore: 66.31,
|
||||||
|
maxContext: 2048,
|
||||||
|
enabled: false,
|
||||||
|
model_dim: 768,
|
||||||
|
normalize: false,
|
||||||
|
query_prefix: "",
|
||||||
|
passage_prefix: "",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 3,
|
||||||
|
name: "Voyage",
|
||||||
|
website: "https://www.voyageai.com",
|
||||||
|
icon: VoyageIcon,
|
||||||
|
description: "Advanced NLP research startup born from Stanford AI Labs",
|
||||||
|
docsLink:
|
||||||
|
"https://docs.danswer.dev/guides/embedding_providers#voyage-models",
|
||||||
|
apiLink: "https://www.voyageai.com/dashboard",
|
||||||
|
costslink: "https://www.voyageai.com/pricing",
|
||||||
|
embedding_models: [
|
||||||
|
{
|
||||||
|
cloud_provider_name: "Voyage",
|
||||||
|
model_name: "voyage-large-2-instruct",
|
||||||
|
description:
|
||||||
|
"Voyage's large embedding model. High performance with instruction fine-tuning.",
|
||||||
|
pricePerMillion: 0.12,
|
||||||
|
mtebScore: 68.28,
|
||||||
|
maxContext: 4000,
|
||||||
|
enabled: false,
|
||||||
|
model_dim: 1024,
|
||||||
|
normalize: false,
|
||||||
|
query_prefix: "",
|
||||||
|
passage_prefix: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
cloud_provider_name: "Voyage",
|
||||||
|
model_name: "voyage-light-2-instruct",
|
||||||
|
description:
|
||||||
|
"Voyage's lightweight embedding model. Good balance of performance and efficiency.",
|
||||||
|
pricePerMillion: 0.12,
|
||||||
|
mtebScore: 67.13,
|
||||||
|
maxContext: 16000,
|
||||||
|
enabled: false,
|
||||||
|
model_dim: 1024,
|
||||||
|
normalize: false,
|
||||||
|
query_prefix: "",
|
||||||
|
passage_prefix: "",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
export const INVALID_OLD_MODEL = "thenlper/gte-small";
|
||||||
|
|
||||||
|
export function checkModelNameIsValid(
|
||||||
|
modelName: string | undefined | null
|
||||||
|
): boolean {
|
||||||
|
return !!modelName && modelName !== INVALID_OLD_MODEL;
|
||||||
|
}
|
@ -1,87 +0,0 @@
|
|||||||
export interface EmbeddingModelResponse {
|
|
||||||
model_name: string | null;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface FullEmbeddingModelResponse {
|
|
||||||
current_model_name: string;
|
|
||||||
secondary_model_name: string | null;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface EmbeddingModelDescriptor {
|
|
||||||
model_name: string;
|
|
||||||
model_dim: number;
|
|
||||||
normalize: boolean;
|
|
||||||
query_prefix?: string;
|
|
||||||
passage_prefix?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface FullEmbeddingModelDescriptor extends EmbeddingModelDescriptor {
|
|
||||||
description: string;
|
|
||||||
isDefault?: boolean;
|
|
||||||
link?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export const AVAILABLE_MODELS: FullEmbeddingModelDescriptor[] = [
|
|
||||||
{
|
|
||||||
model_name: "intfloat/e5-base-v2",
|
|
||||||
model_dim: 768,
|
|
||||||
normalize: true,
|
|
||||||
description:
|
|
||||||
"The recommended default for most situations. If you aren't sure which model to use, this is probably the one.",
|
|
||||||
isDefault: true,
|
|
||||||
link: "https://huggingface.co/intfloat/e5-base-v2",
|
|
||||||
query_prefix: "query: ",
|
|
||||||
passage_prefix: "passage: ",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
model_name: "intfloat/e5-small-v2",
|
|
||||||
model_dim: 384,
|
|
||||||
normalize: true,
|
|
||||||
description:
|
|
||||||
"A smaller / faster version of the default model. If you're running Danswer on a resource constrained system, then this is a good choice.",
|
|
||||||
link: "https://huggingface.co/intfloat/e5-small-v2",
|
|
||||||
query_prefix: "query: ",
|
|
||||||
passage_prefix: "passage: ",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
model_name: "intfloat/multilingual-e5-base",
|
|
||||||
model_dim: 768,
|
|
||||||
normalize: true,
|
|
||||||
description:
|
|
||||||
"If you have many documents in other languages besides English, this is the one to go for.",
|
|
||||||
link: "https://huggingface.co/intfloat/multilingual-e5-base",
|
|
||||||
query_prefix: "query: ",
|
|
||||||
passage_prefix: "passage: ",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
model_name: "intfloat/multilingual-e5-small",
|
|
||||||
model_dim: 384,
|
|
||||||
normalize: true,
|
|
||||||
description:
|
|
||||||
"If you have many documents in other languages besides English, and you're running on a resource constrained system, then this is the one to go for.",
|
|
||||||
link: "https://huggingface.co/intfloat/multilingual-e5-base",
|
|
||||||
query_prefix: "query: ",
|
|
||||||
passage_prefix: "passage: ",
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
export const INVALID_OLD_MODEL = "thenlper/gte-small";
|
|
||||||
|
|
||||||
export function checkModelNameIsValid(modelName: string | undefined | null) {
|
|
||||||
if (!modelName) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if (modelName === INVALID_OLD_MODEL) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function fillOutEmeddingModelDescriptor(
|
|
||||||
embeddingModel: EmbeddingModelDescriptor | FullEmbeddingModelDescriptor
|
|
||||||
): FullEmbeddingModelDescriptor {
|
|
||||||
return {
|
|
||||||
...embeddingModel,
|
|
||||||
description: "",
|
|
||||||
};
|
|
||||||
}
|
|
@ -0,0 +1,31 @@
|
|||||||
|
import React from "react";
|
||||||
|
import { Modal } from "@/components/Modal";
|
||||||
|
import { Button, Text } from "@tremor/react";
|
||||||
|
|
||||||
|
import { CloudEmbeddingModel } from "../components/types";
|
||||||
|
|
||||||
|
export function AlreadyPickedModal({
|
||||||
|
model,
|
||||||
|
onClose,
|
||||||
|
}: {
|
||||||
|
model: CloudEmbeddingModel;
|
||||||
|
onClose: () => void;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<Modal
|
||||||
|
title={`${model.model_name} already chosen`}
|
||||||
|
onOutsideClick={onClose}
|
||||||
|
>
|
||||||
|
<div className="mb-4">
|
||||||
|
<Text className="text-sm mb-2">
|
||||||
|
You can select a different one if you want!
|
||||||
|
</Text>
|
||||||
|
<div className="flex mt-8 justify-between">
|
||||||
|
<Button color="blue" onClick={onClose}>
|
||||||
|
Close
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Modal>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,244 @@
|
|||||||
|
import React, { useRef, useState } from "react";
|
||||||
|
import { Modal } from "@/components/Modal";
|
||||||
|
import { Button, Text, Callout, Subtitle, Divider } from "@tremor/react";
|
||||||
|
import { Label, TextFormField } from "@/components/admin/connectors/Field";
|
||||||
|
import { CloudEmbeddingProvider } from "../components/types";
|
||||||
|
import {
|
||||||
|
EMBEDDING_PROVIDERS_ADMIN_URL,
|
||||||
|
LLM_PROVIDERS_ADMIN_URL,
|
||||||
|
} from "../../llm/constants";
|
||||||
|
import { mutate } from "swr";
|
||||||
|
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
|
||||||
|
import { Field } from "formik";
|
||||||
|
|
||||||
|
export function ChangeCredentialsModal({
|
||||||
|
provider,
|
||||||
|
onConfirm,
|
||||||
|
onCancel,
|
||||||
|
onDeleted,
|
||||||
|
useFileUpload,
|
||||||
|
}: {
|
||||||
|
provider: CloudEmbeddingProvider;
|
||||||
|
onConfirm: () => void;
|
||||||
|
onCancel: () => void;
|
||||||
|
onDeleted: () => void;
|
||||||
|
useFileUpload: boolean;
|
||||||
|
}) {
|
||||||
|
const [apiKey, setApiKey] = useState("");
|
||||||
|
const [testError, setTestError] = useState<string>("");
|
||||||
|
const [fileName, setFileName] = useState<string>("");
|
||||||
|
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||||
|
const [isProcessing, setIsProcessing] = useState(false);
|
||||||
|
const [deletionError, setDeletionError] = useState<string>("");
|
||||||
|
|
||||||
|
const clearFileInput = () => {
|
||||||
|
setFileName("");
|
||||||
|
if (fileInputRef.current) {
|
||||||
|
fileInputRef.current.value = "";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleFileUpload = async (
|
||||||
|
event: React.ChangeEvent<HTMLInputElement>
|
||||||
|
) => {
|
||||||
|
const file = event.target.files?.[0];
|
||||||
|
setFileName("");
|
||||||
|
|
||||||
|
if (file) {
|
||||||
|
setFileName(file.name);
|
||||||
|
try {
|
||||||
|
setDeletionError("");
|
||||||
|
const fileContent = await file.text();
|
||||||
|
let jsonContent;
|
||||||
|
try {
|
||||||
|
jsonContent = JSON.parse(fileContent);
|
||||||
|
setApiKey(JSON.stringify(jsonContent));
|
||||||
|
} catch (parseError) {
|
||||||
|
throw new Error(
|
||||||
|
"Failed to parse JSON file. Please ensure it's a valid JSON."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
setTestError(
|
||||||
|
error instanceof Error
|
||||||
|
? error.message
|
||||||
|
: "An unknown error occurred while processing the file."
|
||||||
|
);
|
||||||
|
setApiKey("");
|
||||||
|
clearFileInput();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleDelete = async () => {
|
||||||
|
setDeletionError("");
|
||||||
|
setIsProcessing(true);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch(
|
||||||
|
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.name}`,
|
||||||
|
{
|
||||||
|
method: "DELETE",
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const errorData = await response.json();
|
||||||
|
setDeletionError(errorData.detail);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||||
|
onDeleted();
|
||||||
|
} catch (error) {
|
||||||
|
setDeletionError(
|
||||||
|
error instanceof Error ? error.message : "An unknown error occurred"
|
||||||
|
);
|
||||||
|
} finally {
|
||||||
|
setIsProcessing(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleSubmit = async () => {
|
||||||
|
setTestError("");
|
||||||
|
|
||||||
|
try {
|
||||||
|
const body = JSON.stringify({
|
||||||
|
api_key: apiKey,
|
||||||
|
provider: provider.name.toLowerCase().split(" ")[0],
|
||||||
|
default_model_id: provider.name,
|
||||||
|
});
|
||||||
|
|
||||||
|
const testResponse = await fetch("/api/admin/embedding/test-embedding", {
|
||||||
|
method: "POST",
|
||||||
|
headers: { "Content-Type": "application/json" },
|
||||||
|
body: JSON.stringify({
|
||||||
|
provider: provider.name.toLowerCase().split(" ")[0],
|
||||||
|
api_key: apiKey,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!testResponse.ok) {
|
||||||
|
const errorMsg = (await testResponse.json()).detail;
|
||||||
|
throw new Error(errorMsg);
|
||||||
|
}
|
||||||
|
|
||||||
|
const updateResponse = await fetch(EMBEDDING_PROVIDERS_ADMIN_URL, {
|
||||||
|
method: "PUT",
|
||||||
|
headers: { "Content-Type": "application/json" },
|
||||||
|
body: JSON.stringify({
|
||||||
|
name: provider.name,
|
||||||
|
api_key: apiKey,
|
||||||
|
is_default_provider: false,
|
||||||
|
is_configured: true,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!updateResponse.ok) {
|
||||||
|
const errorData = await updateResponse.json();
|
||||||
|
throw new Error(
|
||||||
|
errorData.detail || "Failed to update provider- check your API key"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
onConfirm();
|
||||||
|
} catch (error) {
|
||||||
|
setTestError(
|
||||||
|
error instanceof Error ? error.message : "An unknown error occurred"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Modal
|
||||||
|
icon={provider.icon}
|
||||||
|
title={`Modify your ${provider.name} key`}
|
||||||
|
onOutsideClick={onCancel}
|
||||||
|
>
|
||||||
|
<div className="mb-4">
|
||||||
|
<Subtitle className="mt-4 font-bold text-lg mb-2">
|
||||||
|
Want to swap out your key?
|
||||||
|
</Subtitle>
|
||||||
|
|
||||||
|
<div className="flex flex-col gap-y-2">
|
||||||
|
{useFileUpload ? (
|
||||||
|
<>
|
||||||
|
<Label>Upload JSON File</Label>
|
||||||
|
<input
|
||||||
|
ref={fileInputRef}
|
||||||
|
type="file"
|
||||||
|
accept=".json"
|
||||||
|
onChange={handleFileUpload}
|
||||||
|
className="text-lg w-full p-1"
|
||||||
|
/>
|
||||||
|
{fileName && <p>Uploaded file: {fileName}</p>}
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
|
<div className="flex gap-x-2 items-center">
|
||||||
|
<Label>New API Key</Label>
|
||||||
|
</div>
|
||||||
|
<input
|
||||||
|
className={`
|
||||||
|
border
|
||||||
|
border-border
|
||||||
|
rounded
|
||||||
|
w-full
|
||||||
|
py-2
|
||||||
|
px-3
|
||||||
|
mt-1
|
||||||
|
bg-background-emphasis
|
||||||
|
`}
|
||||||
|
value={apiKey}
|
||||||
|
onChange={(e: any) => setApiKey(e.target.value)}
|
||||||
|
placeholder="Paste your API key here"
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
<a
|
||||||
|
href={provider.apiLink}
|
||||||
|
target="_blank"
|
||||||
|
rel="noopener noreferrer"
|
||||||
|
className="underline cursor-pointer"
|
||||||
|
>
|
||||||
|
Visit API
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{testError && (
|
||||||
|
<Callout title="Error" color="red" className="mt-4">
|
||||||
|
{testError}
|
||||||
|
</Callout>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className="flex mt-8 justify-between">
|
||||||
|
<Button
|
||||||
|
color="blue"
|
||||||
|
onClick={() => handleSubmit()}
|
||||||
|
disabled={!apiKey}
|
||||||
|
>
|
||||||
|
Execute Key Swap
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
<Divider />
|
||||||
|
|
||||||
|
<Subtitle className="mt-4 font-bold text-lg mb-2">
|
||||||
|
You can also delete your key.
|
||||||
|
</Subtitle>
|
||||||
|
<Text className="mb-2">
|
||||||
|
This is only possible if you have already switched to a different
|
||||||
|
embedding type!
|
||||||
|
</Text>
|
||||||
|
|
||||||
|
<Button onClick={handleDelete} color="red">
|
||||||
|
Delete key
|
||||||
|
</Button>
|
||||||
|
{deletionError && (
|
||||||
|
<Callout title="Error" color="red" className="mt-4">
|
||||||
|
{deletionError}
|
||||||
|
</Callout>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</Modal>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,41 @@
|
|||||||
|
import React from "react";
|
||||||
|
import { Modal } from "@/components/Modal";
|
||||||
|
import { Button, Text, Callout } from "@tremor/react";
|
||||||
|
import { CloudEmbeddingProvider } from "../components/types";
|
||||||
|
|
||||||
|
export function DeleteCredentialsModal({
|
||||||
|
modelProvider,
|
||||||
|
onConfirm,
|
||||||
|
onCancel,
|
||||||
|
}: {
|
||||||
|
modelProvider: CloudEmbeddingProvider;
|
||||||
|
onConfirm: () => void;
|
||||||
|
onCancel: () => void;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<Modal
|
||||||
|
title={`Nuke ${modelProvider.name} Credentials?`}
|
||||||
|
onOutsideClick={onCancel}
|
||||||
|
>
|
||||||
|
<div className="mb-4">
|
||||||
|
<Text className="text-lg mb-2">
|
||||||
|
You're about to delete your {modelProvider.name} credentials. Are
|
||||||
|
you sure?
|
||||||
|
</Text>
|
||||||
|
<Callout
|
||||||
|
title="Point of No Return"
|
||||||
|
color="red"
|
||||||
|
className="mt-4"
|
||||||
|
></Callout>
|
||||||
|
<div className="flex mt-8 justify-between">
|
||||||
|
<Button color="gray" onClick={onCancel}>
|
||||||
|
Keep Credentaisl
|
||||||
|
</Button>
|
||||||
|
<Button color="red" onClick={onConfirm}>
|
||||||
|
Delete Credentials
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Modal>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,61 @@
|
|||||||
|
import { Modal } from "@/components/Modal";
|
||||||
|
import { Button, Text, Callout } from "@tremor/react";
|
||||||
|
import {
|
||||||
|
EmbeddingModelDescriptor,
|
||||||
|
HostedEmbeddingModel,
|
||||||
|
} from "../components/types";
|
||||||
|
|
||||||
|
export function ModelSelectionConfirmationModal({
|
||||||
|
selectedModel,
|
||||||
|
isCustom,
|
||||||
|
onConfirm,
|
||||||
|
onCancel,
|
||||||
|
}: {
|
||||||
|
selectedModel: HostedEmbeddingModel;
|
||||||
|
isCustom: boolean;
|
||||||
|
onConfirm: () => void;
|
||||||
|
onCancel: () => void;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<Modal title="Update Embedding Model" onOutsideClick={onCancel}>
|
||||||
|
<div>
|
||||||
|
<div className="mb-4">
|
||||||
|
<Text className="text-lg mb-4">
|
||||||
|
You have selected: <b>{selectedModel.model_name}</b>. Are you sure
|
||||||
|
you want to update to this new embedding model?
|
||||||
|
</Text>
|
||||||
|
<Text className="text-lg mb-2">
|
||||||
|
We will re-index all your documents in the background so you will be
|
||||||
|
able to continue to use Danswer as normal with the old model in the
|
||||||
|
meantime. Depending on how many documents you have indexed, this may
|
||||||
|
take a while.
|
||||||
|
</Text>
|
||||||
|
<Text className="text-lg mb-2">
|
||||||
|
<i>NOTE:</i> this re-indexing process will consume more resources
|
||||||
|
than normal. If you are self-hosting, we recommend that you allocate
|
||||||
|
at least 16GB of RAM to Danswer during this process.
|
||||||
|
</Text>
|
||||||
|
|
||||||
|
{/* TODO Change this back- ensure functional */}
|
||||||
|
{!isCustom && (
|
||||||
|
<Callout title="IMPORTANT" color="yellow" className="mt-4">
|
||||||
|
We've detected that this is a custom-specified embedding
|
||||||
|
model. Since we have to download the model files before verifying
|
||||||
|
the configuration's correctness, we won't be able to let
|
||||||
|
you know if the configuration is valid until <b>after</b> we start
|
||||||
|
re-indexing your documents. If there is an issue, it will show up
|
||||||
|
on this page as an indexing error on this page after clicking
|
||||||
|
Confirm.
|
||||||
|
</Callout>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<div className="flex mt-8">
|
||||||
|
<Button className="mx-auto" color="green" onClick={onConfirm}>
|
||||||
|
Confirm
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Modal>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,232 @@
|
|||||||
|
import React, { useRef, useState } from "react";
|
||||||
|
import { Text, Button, Callout } from "@tremor/react";
|
||||||
|
import { Formik, Form, Field } from "formik";
|
||||||
|
import * as Yup from "yup";
|
||||||
|
import { Label, TextFormField } from "@/components/admin/connectors/Field";
|
||||||
|
import { LoadingAnimation } from "@/components/Loading";
|
||||||
|
import { CloudEmbeddingProvider } from "../components/types";
|
||||||
|
import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../../llm/constants";
|
||||||
|
import { Modal } from "@/components/Modal";
|
||||||
|
|
||||||
|
export function ProviderCreationModal({
|
||||||
|
selectedProvider,
|
||||||
|
onConfirm,
|
||||||
|
onCancel,
|
||||||
|
existingProvider,
|
||||||
|
}: {
|
||||||
|
selectedProvider: CloudEmbeddingProvider;
|
||||||
|
onConfirm: () => void;
|
||||||
|
onCancel: () => void;
|
||||||
|
existingProvider?: CloudEmbeddingProvider;
|
||||||
|
}) {
|
||||||
|
const useFileUpload = selectedProvider.name == "Google";
|
||||||
|
|
||||||
|
const [isProcessing, setIsProcessing] = useState(false);
|
||||||
|
const [errorMsg, setErrorMsg] = useState<string>("");
|
||||||
|
const [fileName, setFileName] = useState<string>("");
|
||||||
|
|
||||||
|
const initialValues = {
|
||||||
|
name: existingProvider?.name || selectedProvider.name,
|
||||||
|
api_key: existingProvider?.api_key || "",
|
||||||
|
custom_config: existingProvider?.custom_config
|
||||||
|
? Object.entries(existingProvider.custom_config)
|
||||||
|
: [],
|
||||||
|
default_model_name: "",
|
||||||
|
model_id: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
const validationSchema = Yup.object({
|
||||||
|
name: Yup.string().required("Name is required"),
|
||||||
|
api_key: useFileUpload
|
||||||
|
? Yup.string()
|
||||||
|
: Yup.string().required("API Key is required"),
|
||||||
|
custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)),
|
||||||
|
});
|
||||||
|
|
||||||
|
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||||
|
|
||||||
|
const handleFileUpload = async (
|
||||||
|
event: React.ChangeEvent<HTMLInputElement>,
|
||||||
|
setFieldValue: (field: string, value: any) => void
|
||||||
|
) => {
|
||||||
|
const file = event.target.files?.[0];
|
||||||
|
setFileName("");
|
||||||
|
if (file) {
|
||||||
|
setFileName(file.name);
|
||||||
|
try {
|
||||||
|
const fileContent = await file.text();
|
||||||
|
let jsonContent;
|
||||||
|
try {
|
||||||
|
jsonContent = JSON.parse(fileContent);
|
||||||
|
} catch (parseError) {
|
||||||
|
throw new Error(
|
||||||
|
"Failed to parse JSON file. Please ensure it's a valid JSON."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
setFieldValue("api_key", JSON.stringify(jsonContent));
|
||||||
|
} catch (error) {
|
||||||
|
setFieldValue("api_key", "");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleSubmit = async (
|
||||||
|
values: any,
|
||||||
|
{ setSubmitting }: { setSubmitting: (isSubmitting: boolean) => void }
|
||||||
|
) => {
|
||||||
|
setIsProcessing(true);
|
||||||
|
setErrorMsg("");
|
||||||
|
|
||||||
|
try {
|
||||||
|
const customConfig = Object.fromEntries(values.custom_config);
|
||||||
|
|
||||||
|
const initialResponse = await fetch(
|
||||||
|
"/api/admin/embedding/test-embedding",
|
||||||
|
{
|
||||||
|
method: "POST",
|
||||||
|
headers: { "Content-Type": "application/json" },
|
||||||
|
body: JSON.stringify({
|
||||||
|
provider: values.name.toLowerCase().split(" ")[0],
|
||||||
|
api_key: values.api_key,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!initialResponse.ok) {
|
||||||
|
const errorMsg = (await initialResponse.json()).detail;
|
||||||
|
setErrorMsg(errorMsg);
|
||||||
|
setIsProcessing(false);
|
||||||
|
setSubmitting(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const response = await fetch(EMBEDDING_PROVIDERS_ADMIN_URL, {
|
||||||
|
method: "PUT",
|
||||||
|
headers: { "Content-Type": "application/json" },
|
||||||
|
body: JSON.stringify({
|
||||||
|
...values,
|
||||||
|
custom_config: customConfig,
|
||||||
|
is_default_provider: false,
|
||||||
|
is_configured: true,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
const errorData = await response.json();
|
||||||
|
throw new Error(
|
||||||
|
errorData.detail || "Failed to update provider- check your API key"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
onConfirm();
|
||||||
|
} catch (error: unknown) {
|
||||||
|
if (error instanceof Error) {
|
||||||
|
setErrorMsg(error.message);
|
||||||
|
} else {
|
||||||
|
setErrorMsg("An unknown error occurred");
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
setIsProcessing(false);
|
||||||
|
setSubmitting(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Modal
|
||||||
|
title={`Configure ${selectedProvider.name}`}
|
||||||
|
onOutsideClick={onCancel}
|
||||||
|
icon={selectedProvider.icon}
|
||||||
|
>
|
||||||
|
<div>
|
||||||
|
<Formik
|
||||||
|
initialValues={initialValues}
|
||||||
|
validationSchema={validationSchema}
|
||||||
|
onSubmit={handleSubmit}
|
||||||
|
>
|
||||||
|
{({
|
||||||
|
values,
|
||||||
|
errors,
|
||||||
|
touched,
|
||||||
|
isSubmitting,
|
||||||
|
handleSubmit,
|
||||||
|
setFieldValue,
|
||||||
|
}) => (
|
||||||
|
<Form onSubmit={handleSubmit} className="space-y-4">
|
||||||
|
<Text className="text-lg mb-2">
|
||||||
|
You are setting the credentials for this provider. To access
|
||||||
|
this information, follow the instructions{" "}
|
||||||
|
<a
|
||||||
|
className="cursor-pointer underline"
|
||||||
|
target="_blank"
|
||||||
|
href={selectedProvider.docsLink}
|
||||||
|
>
|
||||||
|
here
|
||||||
|
</a>{" "}
|
||||||
|
and gather your{" "}
|
||||||
|
<a
|
||||||
|
className="cursor-pointer underline"
|
||||||
|
target="_blank"
|
||||||
|
href={selectedProvider.apiLink}
|
||||||
|
>
|
||||||
|
API KEY
|
||||||
|
</a>
|
||||||
|
</Text>
|
||||||
|
|
||||||
|
<div className="flex flex-col gap-y-2">
|
||||||
|
{useFileUpload ? (
|
||||||
|
<>
|
||||||
|
<Label>Upload JSON File</Label>
|
||||||
|
<input
|
||||||
|
ref={fileInputRef}
|
||||||
|
type="file"
|
||||||
|
accept=".json"
|
||||||
|
onChange={(e) => handleFileUpload(e, setFieldValue)}
|
||||||
|
className="text-lg w-full p-1"
|
||||||
|
/>
|
||||||
|
{fileName && <p>Uploaded file: {fileName}</p>}
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<TextFormField
|
||||||
|
name="api_key"
|
||||||
|
label="API Key"
|
||||||
|
placeholder="API Key"
|
||||||
|
type="password"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<a
|
||||||
|
href={selectedProvider.apiLink}
|
||||||
|
target="_blank"
|
||||||
|
className="underline cursor-pointer"
|
||||||
|
>
|
||||||
|
Learn more here
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{errorMsg && (
|
||||||
|
<Callout title="Error" color="red">
|
||||||
|
{errorMsg}
|
||||||
|
</Callout>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<Button
|
||||||
|
type="submit"
|
||||||
|
color="blue"
|
||||||
|
className="w-full"
|
||||||
|
disabled={isSubmitting}
|
||||||
|
>
|
||||||
|
{isProcessing ? (
|
||||||
|
<LoadingAnimation />
|
||||||
|
) : existingProvider ? (
|
||||||
|
"Update"
|
||||||
|
) : (
|
||||||
|
"Create"
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
</Form>
|
||||||
|
)}
|
||||||
|
</Formik>
|
||||||
|
</div>
|
||||||
|
</Modal>
|
||||||
|
);
|
||||||
|
}
|
@ -0,0 +1,37 @@
|
|||||||
|
import React from "react";
|
||||||
|
import { Modal } from "@/components/Modal";
|
||||||
|
import { Button, Text, Callout } from "@tremor/react";
|
||||||
|
import { CloudEmbeddingModel } from "../components/types";
|
||||||
|
|
||||||
|
export function SelectModelModal({
|
||||||
|
model,
|
||||||
|
onConfirm,
|
||||||
|
onCancel,
|
||||||
|
}: {
|
||||||
|
model: CloudEmbeddingModel;
|
||||||
|
onConfirm: () => void;
|
||||||
|
onCancel: () => void;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<Modal
|
||||||
|
title={`Elevate Your Game with ${model.model_name}`}
|
||||||
|
onOutsideClick={onCancel}
|
||||||
|
>
|
||||||
|
<div className="mb-4">
|
||||||
|
<Text className="text-lg mb-2">
|
||||||
|
You're about to set your embedding model to {model.model_name}.
|
||||||
|
<br />
|
||||||
|
Are you sure?
|
||||||
|
</Text>
|
||||||
|
<div className="flex mt-8 justify-between">
|
||||||
|
<Button color="gray" onClick={onCancel}>
|
||||||
|
Exit
|
||||||
|
</Button>
|
||||||
|
<Button color="green" onClick={onConfirm}>
|
||||||
|
Continue
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Modal>
|
||||||
|
);
|
||||||
|
}
|
@ -3,28 +3,75 @@
|
|||||||
import { ThreeDotsLoader } from "@/components/Loading";
|
import { ThreeDotsLoader } from "@/components/Loading";
|
||||||
import { AdminPageTitle } from "@/components/admin/Title";
|
import { AdminPageTitle } from "@/components/admin/Title";
|
||||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||||
import { Button, Card, Text, Title } from "@tremor/react";
|
import { Button, Text, Title } from "@tremor/react";
|
||||||
import { FiPackage } from "react-icons/fi";
|
import { FiPackage } from "react-icons/fi";
|
||||||
import useSWR, { mutate } from "swr";
|
import useSWR, { mutate } from "swr";
|
||||||
import { ModelOption, ModelSelector } from "./ModelSelector";
|
import { ModelOption, ModelPreview } from "./components/ModelSelector";
|
||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
import { ModelSelectionConfirmaionModal } from "./ModelSelectionConfirmation";
|
import { ReindexingProgressTable } from "./components/ReindexingProgressTable";
|
||||||
import { ReindexingProgressTable } from "./ReindexingProgressTable";
|
|
||||||
import { Modal } from "@/components/Modal";
|
import { Modal } from "@/components/Modal";
|
||||||
import {
|
import {
|
||||||
|
CloudEmbeddingProvider,
|
||||||
|
CloudEmbeddingModel,
|
||||||
|
AVAILABLE_CLOUD_PROVIDERS,
|
||||||
AVAILABLE_MODELS,
|
AVAILABLE_MODELS,
|
||||||
EmbeddingModelDescriptor,
|
|
||||||
INVALID_OLD_MODEL,
|
INVALID_OLD_MODEL,
|
||||||
fillOutEmeddingModelDescriptor,
|
HostedEmbeddingModel,
|
||||||
} from "./embeddingModels";
|
EmbeddingModelDescriptor,
|
||||||
|
} from "./components/types";
|
||||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||||
import { Connector, ConnectorIndexingStatus } from "@/lib/types";
|
import { Connector, ConnectorIndexingStatus } from "@/lib/types";
|
||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
import { CustomModelForm } from "./CustomModelForm";
|
import OpenEmbeddingPage from "./OpenEmbeddingPage";
|
||||||
|
import CloudEmbeddingPage from "./CloudEmbeddingPage";
|
||||||
|
import { ProviderCreationModal } from "./modals/ProviderCreationModal";
|
||||||
|
|
||||||
|
import { DeleteCredentialsModal } from "./modals/DeleteCredentialsModal";
|
||||||
|
import { SelectModelModal } from "./modals/SelectModelModal";
|
||||||
|
import { ChangeCredentialsModal } from "./modals/ChangeCredentialsModal";
|
||||||
|
import { ModelSelectionConfirmationModal } from "./modals/ModelSelectionModal";
|
||||||
|
import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../llm/constants";
|
||||||
|
import { AlreadyPickedModal } from "./modals/AlreadyPickedModal";
|
||||||
|
|
||||||
|
export interface EmbeddingDetails {
|
||||||
|
api_key: string;
|
||||||
|
custom_config: any;
|
||||||
|
default_model_id?: number;
|
||||||
|
name: string;
|
||||||
|
}
|
||||||
|
|
||||||
function Main() {
|
function Main() {
|
||||||
const [tentativeNewEmbeddingModel, setTentativeNewEmbeddingModel] =
|
const [openToggle, setOpenToggle] = useState(true);
|
||||||
useState<EmbeddingModelDescriptor | null>(null);
|
|
||||||
|
// Cloud Provider based modals
|
||||||
|
const [showTentativeProvider, setShowTentativeProvider] =
|
||||||
|
useState<CloudEmbeddingProvider | null>(null);
|
||||||
|
const [showUnconfiguredProvider, setShowUnconfiguredProvider] =
|
||||||
|
useState<CloudEmbeddingProvider | null>(null);
|
||||||
|
const [changeCredentialsProvider, setChangeCredentialsProvider] =
|
||||||
|
useState<CloudEmbeddingProvider | null>(null);
|
||||||
|
|
||||||
|
// Cloud Model based modals
|
||||||
|
const [alreadySelectedModel, setAlreadySelectedModel] =
|
||||||
|
useState<CloudEmbeddingModel | null>(null);
|
||||||
|
const [showTentativeModel, setShowTentativeModel] =
|
||||||
|
useState<CloudEmbeddingModel | null>(null);
|
||||||
|
|
||||||
|
const [showModelInQueue, setShowModelInQueue] =
|
||||||
|
useState<CloudEmbeddingModel | null>(null);
|
||||||
|
|
||||||
|
// Open Model based modals
|
||||||
|
const [showTentativeOpenProvider, setShowTentativeOpenProvider] =
|
||||||
|
useState<HostedEmbeddingModel | null>(null);
|
||||||
|
|
||||||
|
// Enabled / unenabled providers
|
||||||
|
const [newEnabledProviders, setNewEnabledProviders] = useState<string[]>([]);
|
||||||
|
const [newUnenabledProviders, setNewUnenabledProviders] = useState<string[]>(
|
||||||
|
[]
|
||||||
|
);
|
||||||
|
|
||||||
|
const [showDeleteCredentialsModal, setShowDeleteCredentialsModal] =
|
||||||
|
useState<boolean>(false);
|
||||||
const [isCancelling, setIsCancelling] = useState<boolean>(false);
|
const [isCancelling, setIsCancelling] = useState<boolean>(false);
|
||||||
const [showAddConnectorPopup, setShowAddConnectorPopup] =
|
const [showAddConnectorPopup, setShowAddConnectorPopup] =
|
||||||
useState<boolean>(false);
|
useState<boolean>(false);
|
||||||
@ -33,16 +80,22 @@ function Main() {
|
|||||||
data: currentEmeddingModel,
|
data: currentEmeddingModel,
|
||||||
isLoading: isLoadingCurrentModel,
|
isLoading: isLoadingCurrentModel,
|
||||||
error: currentEmeddingModelError,
|
error: currentEmeddingModelError,
|
||||||
} = useSWR<EmbeddingModelDescriptor>(
|
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
|
||||||
"/api/secondary-index/get-current-embedding-model",
|
"/api/secondary-index/get-current-embedding-model",
|
||||||
errorHandlingFetcher,
|
errorHandlingFetcher,
|
||||||
{ refreshInterval: 5000 } // 5 seconds
|
{ refreshInterval: 5000 } // 5 seconds
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const { data: embeddingProviderDetails } = useSWR<EmbeddingDetails[]>(
|
||||||
|
EMBEDDING_PROVIDERS_ADMIN_URL,
|
||||||
|
errorHandlingFetcher
|
||||||
|
);
|
||||||
|
|
||||||
const {
|
const {
|
||||||
data: futureEmbeddingModel,
|
data: futureEmbeddingModel,
|
||||||
isLoading: isLoadingFutureModel,
|
isLoading: isLoadingFutureModel,
|
||||||
error: futureEmeddingModelError,
|
error: futureEmeddingModelError,
|
||||||
} = useSWR<EmbeddingModelDescriptor | null>(
|
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
|
||||||
"/api/secondary-index/get-secondary-embedding-model",
|
"/api/secondary-index/get-secondary-embedding-model",
|
||||||
errorHandlingFetcher,
|
errorHandlingFetcher,
|
||||||
{ refreshInterval: 5000 } // 5 seconds
|
{ refreshInterval: 5000 } // 5 seconds
|
||||||
@ -61,27 +114,41 @@ function Main() {
|
|||||||
{ refreshInterval: 5000 } // 5 seconds
|
{ refreshInterval: 5000 } // 5 seconds
|
||||||
);
|
);
|
||||||
|
|
||||||
const onSelect = async (model: EmbeddingModelDescriptor) => {
|
const onConfirm = async (
|
||||||
if (currentEmeddingModel?.model_name === INVALID_OLD_MODEL) {
|
model: CloudEmbeddingModel | HostedEmbeddingModel
|
||||||
await onConfirm(model);
|
) => {
|
||||||
} else {
|
let newModel: EmbeddingModelDescriptor;
|
||||||
setTentativeNewEmbeddingModel(model);
|
|
||||||
}
|
if ("cloud_provider_name" in model) {
|
||||||
};
|
// This is a CloudEmbeddingModel
|
||||||
|
newModel = {
|
||||||
|
...model,
|
||||||
|
model_name: model.model_name,
|
||||||
|
cloud_provider_name: model.cloud_provider_name,
|
||||||
|
// cloud_provider_id: model.cloud_provider_id || 0,
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
// This is an EmbeddingModelDescriptor
|
||||||
|
newModel = {
|
||||||
|
...model,
|
||||||
|
model_name: model.model_name!,
|
||||||
|
description: "",
|
||||||
|
cloud_provider_name: null,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
const onConfirm = async (model: EmbeddingModelDescriptor) => {
|
|
||||||
const response = await fetch(
|
const response = await fetch(
|
||||||
"/api/secondary-index/set-new-embedding-model",
|
"/api/secondary-index/set-new-embedding-model",
|
||||||
{
|
{
|
||||||
method: "POST",
|
method: "POST",
|
||||||
body: JSON.stringify(model),
|
body: JSON.stringify(newModel),
|
||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
if (response.ok) {
|
if (response.ok) {
|
||||||
setTentativeNewEmbeddingModel(null);
|
setShowTentativeModel(null);
|
||||||
mutate("/api/secondary-index/get-secondary-embedding-model");
|
mutate("/api/secondary-index/get-secondary-embedding-model");
|
||||||
if (!connectors || !connectors.length) {
|
if (!connectors || !connectors.length) {
|
||||||
setShowAddConnectorPopup(true);
|
setShowAddConnectorPopup(true);
|
||||||
@ -96,14 +163,13 @@ function Main() {
|
|||||||
method: "POST",
|
method: "POST",
|
||||||
});
|
});
|
||||||
if (response.ok) {
|
if (response.ok) {
|
||||||
setTentativeNewEmbeddingModel(null);
|
setShowTentativeModel(null);
|
||||||
mutate("/api/secondary-index/get-secondary-embedding-model");
|
mutate("/api/secondary-index/get-secondary-embedding-model");
|
||||||
} else {
|
} else {
|
||||||
alert(
|
alert(
|
||||||
`Failed to cancel embedding model update - ${await response.text()}`
|
`Failed to cancel embedding model update - ${await response.text()}`
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
setIsCancelling(false);
|
setIsCancelling(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -119,38 +185,235 @@ function Main() {
|
|||||||
return <ErrorCallout errorTitle="Failed to fetch embedding model status" />;
|
return <ErrorCallout errorTitle="Failed to fetch embedding model status" />;
|
||||||
}
|
}
|
||||||
|
|
||||||
const currentModelName = currentEmeddingModel.model_name;
|
const onConfirmSelection = async (model: EmbeddingModelDescriptor) => {
|
||||||
const currentModel =
|
const response = await fetch(
|
||||||
AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) ||
|
"/api/secondary-index/set-new-embedding-model",
|
||||||
fillOutEmeddingModelDescriptor(currentEmeddingModel);
|
{
|
||||||
|
method: "POST",
|
||||||
|
body: JSON.stringify(model),
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
);
|
||||||
|
if (response.ok) {
|
||||||
|
setShowTentativeModel(null);
|
||||||
|
mutate("/api/secondary-index/get-secondary-embedding-model");
|
||||||
|
if (!connectors || !connectors.length) {
|
||||||
|
setShowAddConnectorPopup(true);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
alert(`Failed to update embedding model - ${await response.text()}`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const newModelSelection = futureEmbeddingModel
|
const currentModelName = currentEmeddingModel?.model_name;
|
||||||
? AVAILABLE_MODELS.find(
|
const AVAILABLE_CLOUD_PROVIDERS_FLATTENED = AVAILABLE_CLOUD_PROVIDERS.flatMap(
|
||||||
(model) => model.model_name === futureEmbeddingModel.model_name
|
(provider) =>
|
||||||
) || fillOutEmeddingModelDescriptor(futureEmbeddingModel)
|
provider.embedding_models.map((model) => ({
|
||||||
: null;
|
...model,
|
||||||
|
cloud_provider_id: provider.id,
|
||||||
|
model_name: model.model_name, // Ensure model_name is set for consistency
|
||||||
|
}))
|
||||||
|
);
|
||||||
|
|
||||||
|
const currentModel: CloudEmbeddingModel | HostedEmbeddingModel =
|
||||||
|
AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) ||
|
||||||
|
AVAILABLE_CLOUD_PROVIDERS_FLATTENED.find(
|
||||||
|
(model) => model.model_name === currentEmeddingModel.model_name
|
||||||
|
)!;
|
||||||
|
// ||
|
||||||
|
// fillOutEmeddingModelDescriptor(currentEmeddingModel);
|
||||||
|
|
||||||
|
const onSelectOpenSource = async (model: HostedEmbeddingModel) => {
|
||||||
|
if (currentEmeddingModel?.model_name === INVALID_OLD_MODEL) {
|
||||||
|
await onConfirmSelection(model);
|
||||||
|
} else {
|
||||||
|
setShowTentativeOpenProvider(model);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const selectedModel = AVAILABLE_CLOUD_PROVIDERS[0];
|
||||||
|
const clientsideAddProvider = (provider: CloudEmbeddingProvider) => {
|
||||||
|
const providerName = provider.name;
|
||||||
|
setNewEnabledProviders((newEnabledProviders) => [
|
||||||
|
...newEnabledProviders,
|
||||||
|
providerName,
|
||||||
|
]);
|
||||||
|
setNewUnenabledProviders((newUnenabledProviders) =>
|
||||||
|
newUnenabledProviders.filter(
|
||||||
|
(givenProvidername) => givenProvidername != providerName
|
||||||
|
)
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
const clientsideRemoveProvider = (provider: CloudEmbeddingProvider) => {
|
||||||
|
const providerName = provider.name;
|
||||||
|
setNewEnabledProviders((newEnabledProviders) =>
|
||||||
|
newEnabledProviders.filter(
|
||||||
|
(givenProvidername) => givenProvidername != providerName
|
||||||
|
)
|
||||||
|
);
|
||||||
|
setNewUnenabledProviders((newUnenabledProviders) => [
|
||||||
|
...newUnenabledProviders,
|
||||||
|
providerName,
|
||||||
|
]);
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div>
|
<div className="h-screen">
|
||||||
{tentativeNewEmbeddingModel && (
|
<Text>
|
||||||
<ModelSelectionConfirmaionModal
|
Embedding models are used to generate embeddings for your documents,
|
||||||
selectedModel={tentativeNewEmbeddingModel}
|
which then power Danswer's search.
|
||||||
|
</Text>
|
||||||
|
|
||||||
|
{alreadySelectedModel && (
|
||||||
|
<AlreadyPickedModal
|
||||||
|
model={alreadySelectedModel}
|
||||||
|
onClose={() => setAlreadySelectedModel(null)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{showTentativeOpenProvider && (
|
||||||
|
<ModelSelectionConfirmationModal
|
||||||
|
selectedModel={showTentativeOpenProvider}
|
||||||
isCustom={
|
isCustom={
|
||||||
AVAILABLE_MODELS.find(
|
AVAILABLE_MODELS.find(
|
||||||
(model) =>
|
(model) =>
|
||||||
model.model_name === tentativeNewEmbeddingModel.model_name
|
model.model_name === showTentativeOpenProvider.model_name
|
||||||
) === undefined
|
) === undefined
|
||||||
}
|
}
|
||||||
onConfirm={() => onConfirm(tentativeNewEmbeddingModel)}
|
onConfirm={() => onConfirm(showTentativeOpenProvider)}
|
||||||
onCancel={() => setTentativeNewEmbeddingModel(null)}
|
onCancel={() => setShowTentativeOpenProvider(null)}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
{showTentativeProvider && (
|
||||||
|
<ProviderCreationModal
|
||||||
|
selectedProvider={showTentativeProvider}
|
||||||
|
onConfirm={() => {
|
||||||
|
setShowTentativeProvider(showUnconfiguredProvider);
|
||||||
|
clientsideAddProvider(showTentativeProvider);
|
||||||
|
if (showModelInQueue) {
|
||||||
|
setShowTentativeModel(showModelInQueue);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
onCancel={() => {
|
||||||
|
setShowModelInQueue(null);
|
||||||
|
setShowTentativeProvider(null);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{changeCredentialsProvider && (
|
||||||
|
<ChangeCredentialsModal
|
||||||
|
// setPopup={setPopup}
|
||||||
|
useFileUpload={changeCredentialsProvider.name == "Google"}
|
||||||
|
onDeleted={() => {
|
||||||
|
clientsideRemoveProvider(changeCredentialsProvider);
|
||||||
|
setChangeCredentialsProvider(null);
|
||||||
|
}}
|
||||||
|
provider={changeCredentialsProvider}
|
||||||
|
onConfirm={() => setChangeCredentialsProvider(null)}
|
||||||
|
onCancel={() => setChangeCredentialsProvider(null)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{showTentativeModel && (
|
||||||
|
<SelectModelModal
|
||||||
|
model={showTentativeModel}
|
||||||
|
onConfirm={() => {
|
||||||
|
setShowModelInQueue(null);
|
||||||
|
onConfirm(showTentativeModel);
|
||||||
|
}}
|
||||||
|
onCancel={() => {
|
||||||
|
setShowModelInQueue(null);
|
||||||
|
setShowTentativeModel(null);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{showDeleteCredentialsModal && (
|
||||||
|
<DeleteCredentialsModal
|
||||||
|
modelProvider={showTentativeProvider!}
|
||||||
|
onConfirm={() => {
|
||||||
|
setShowDeleteCredentialsModal(false);
|
||||||
|
}}
|
||||||
|
onCancel={() => setShowDeleteCredentialsModal(false)}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{currentModel ? (
|
||||||
|
<>
|
||||||
|
<Title className="mt-8 mb-2">Current Embedding Model</Title>
|
||||||
|
<Text>
|
||||||
|
<ModelPreview model={currentModel} />
|
||||||
|
</Text>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<Title className="mt-8 mb-4">Choose your Embedding Model</Title>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{!(futureEmbeddingModel && connectors && connectors.length > 0) && (
|
||||||
|
<>
|
||||||
|
<Title className="mt-8">Switch your Embedding Model</Title>
|
||||||
|
<Text className="mb-4">
|
||||||
|
If the current model is not working for you, you can update your
|
||||||
|
model choice below. Note that this will require a complete
|
||||||
|
re-indexing of all your documents across every connected source. We
|
||||||
|
will take care of this in the background, but depending on the size
|
||||||
|
of your corpus, this could take hours, day, or even weeks. You can
|
||||||
|
monitor the progress of the re-indexing on this page.
|
||||||
|
</Text>
|
||||||
|
|
||||||
|
<div className="mt-8 text-sm mr-auto mb-12 divide-x-2 flex ">
|
||||||
|
<button
|
||||||
|
onClick={() => setOpenToggle(true)}
|
||||||
|
className={` mx-2 p-2 font-bold ${openToggle ? "rounded bg-neutral-900 text-neutral-100 underline" : "hover:underline"}`}
|
||||||
|
>
|
||||||
|
Self-hosted
|
||||||
|
</button>
|
||||||
|
<div className="px-2 ">
|
||||||
|
<button
|
||||||
|
onClick={() => setOpenToggle(false)}
|
||||||
|
className={`mx-2 p-2 font-bold ${!openToggle ? "rounded bg-neutral-900 text-neutral-100 underline" : " hover:underline"}`}
|
||||||
|
>
|
||||||
|
Cloud-based
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{!showAddConnectorPopup &&
|
||||||
|
!futureEmbeddingModel &&
|
||||||
|
(openToggle ? (
|
||||||
|
<OpenEmbeddingPage
|
||||||
|
onSelectOpenSource={onSelectOpenSource}
|
||||||
|
currentModelName={currentModelName!}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<CloudEmbeddingPage
|
||||||
|
setShowModelInQueue={setShowModelInQueue}
|
||||||
|
setShowTentativeModel={setShowTentativeModel}
|
||||||
|
currentModel={currentModel}
|
||||||
|
setAlreadySelectedModel={setAlreadySelectedModel}
|
||||||
|
embeddingProviderDetails={embeddingProviderDetails}
|
||||||
|
newEnabledProviders={newEnabledProviders}
|
||||||
|
newUnenabledProviders={newUnenabledProviders}
|
||||||
|
setShowTentativeProvider={setShowTentativeProvider}
|
||||||
|
selectedModel={selectedModel}
|
||||||
|
setChangeCredentialsProvider={setChangeCredentialsProvider}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
|
||||||
|
{openToggle && (
|
||||||
|
<>
|
||||||
{showAddConnectorPopup && (
|
{showAddConnectorPopup && (
|
||||||
<Modal>
|
<Modal>
|
||||||
<div>
|
<div>
|
||||||
<div>
|
<div>
|
||||||
<b className="text-base">Embeding model successfully selected</b>{" "}
|
<b className="text-base">
|
||||||
|
Embedding model successfully selected
|
||||||
|
</b>{" "}
|
||||||
🙌
|
🙌
|
||||||
<br />
|
<br />
|
||||||
<br />
|
<br />
|
||||||
@ -158,12 +421,16 @@ function Main() {
|
|||||||
<br />
|
<br />
|
||||||
<br />
|
<br />
|
||||||
Connectors are the way that Danswer gets data from your
|
Connectors are the way that Danswer gets data from your
|
||||||
organization's various data sources. Once setup, we'll
|
organization's various data sources. Once setup,
|
||||||
automatically sync data from your apps and docs into Danswer, so
|
we'll automatically sync data from your apps and docs
|
||||||
you can search all through all of them in one place.
|
into Danswer, so you can search all through all of them in one
|
||||||
|
place.
|
||||||
</div>
|
</div>
|
||||||
<div className="flex">
|
<div className="flex">
|
||||||
<Link className="mx-auto mt-2 w-fit" href="/admin/add-connector">
|
<Link
|
||||||
|
className="mx-auto mt-2 w-fit"
|
||||||
|
href="/admin/add-connector"
|
||||||
|
>
|
||||||
<Button className="mt-3 mx-auto" size="xs">
|
<Button className="mt-3 mx-auto" size="xs">
|
||||||
Add Connector
|
Add Connector
|
||||||
</Button>
|
</Button>
|
||||||
@ -183,121 +450,33 @@ function Main() {
|
|||||||
Are you sure you want to cancel?
|
Are you sure you want to cancel?
|
||||||
<br />
|
<br />
|
||||||
<br />
|
<br />
|
||||||
Cancelling will revert to the previous model and all progress will
|
Cancelling will revert to the previous model and all progress
|
||||||
be lost.
|
will be lost.
|
||||||
</div>
|
</div>
|
||||||
<div className="flex">
|
<div className="flex">
|
||||||
<Button onClick={onCancel} className="mt-3 mx-auto" color="green">
|
<Button
|
||||||
|
onClick={onCancel}
|
||||||
|
className="mt-3 mx-auto"
|
||||||
|
color="green"
|
||||||
|
>
|
||||||
Confirm
|
Confirm
|
||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</Modal>
|
</Modal>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
<Text>
|
|
||||||
Embedding models are used to generate embeddings for your documents,
|
|
||||||
which then power Danswer's search.
|
|
||||||
</Text>
|
|
||||||
|
|
||||||
{currentModel ? (
|
|
||||||
<>
|
|
||||||
<Title className="mt-8 mb-2">Current Embedding Model</Title>
|
|
||||||
|
|
||||||
<Text>
|
|
||||||
<ModelOption model={currentModel} />
|
|
||||||
</Text>
|
|
||||||
</>
|
|
||||||
) : (
|
|
||||||
newModelSelection &&
|
|
||||||
(!connectors || !connectors.length) && (
|
|
||||||
<>
|
|
||||||
<Title className="mt-8 mb-2">Current Embedding Model</Title>
|
|
||||||
|
|
||||||
<Text>
|
|
||||||
<ModelOption model={newModelSelection} />
|
|
||||||
</Text>
|
|
||||||
</>
|
|
||||||
)
|
|
||||||
)}
|
|
||||||
|
|
||||||
{!showAddConnectorPopup &&
|
|
||||||
(!newModelSelection ? (
|
|
||||||
<div>
|
|
||||||
{currentModel ? (
|
|
||||||
<>
|
|
||||||
<Title className="mt-8">Switch your Embedding Model</Title>
|
|
||||||
|
|
||||||
<Text className="mb-4">
|
|
||||||
If the current model is not working for you, you can update
|
|
||||||
your model choice below. Note that this will require a
|
|
||||||
complete re-indexing of all your documents across every
|
|
||||||
connected source. We will take care of this in the background,
|
|
||||||
but depending on the size of your corpus, this could take
|
|
||||||
hours, day, or even weeks. You can monitor the progress of the
|
|
||||||
re-indexing on this page.
|
|
||||||
</Text>
|
|
||||||
</>
|
|
||||||
) : (
|
|
||||||
<>
|
|
||||||
<Title className="mt-8 mb-4">Choose your Embedding Model</Title>
|
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
<Text className="mb-4">
|
{futureEmbeddingModel && connectors && connectors.length > 0 && (
|
||||||
Below are a curated selection of quality models that we recommend
|
|
||||||
you choose from.
|
|
||||||
</Text>
|
|
||||||
|
|
||||||
<ModelSelector
|
|
||||||
modelOptions={AVAILABLE_MODELS.filter(
|
|
||||||
(modelOption) => modelOption.model_name !== currentModelName
|
|
||||||
)}
|
|
||||||
setSelectedModel={onSelect}
|
|
||||||
/>
|
|
||||||
|
|
||||||
<Text className="mt-6">
|
|
||||||
Alternatively, (if you know what you're doing) you can
|
|
||||||
specify a{" "}
|
|
||||||
<a
|
|
||||||
target="_blank"
|
|
||||||
href="https://www.sbert.net/"
|
|
||||||
className="text-link"
|
|
||||||
>
|
|
||||||
SentenceTransformers
|
|
||||||
</a>
|
|
||||||
-compatible model of your choice below. The rough list of
|
|
||||||
supported models can be found{" "}
|
|
||||||
<a
|
|
||||||
target="_blank"
|
|
||||||
href="https://huggingface.co/models?library=sentence-transformers&sort=trending"
|
|
||||||
className="text-link"
|
|
||||||
>
|
|
||||||
here
|
|
||||||
</a>
|
|
||||||
.
|
|
||||||
<br />
|
|
||||||
<b>NOTE:</b> not all models listed will work with Danswer, since
|
|
||||||
some have unique interfaces or special requirements. If in doubt,
|
|
||||||
reach out to the Danswer team.
|
|
||||||
</Text>
|
|
||||||
|
|
||||||
<div className="w-full flex">
|
|
||||||
<Card className="mt-4 2xl:w-4/6 mx-auto">
|
|
||||||
<CustomModelForm onSubmit={onSelect} />
|
|
||||||
</Card>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
) : (
|
|
||||||
connectors &&
|
|
||||||
connectors.length > 0 && (
|
|
||||||
<div>
|
<div>
|
||||||
<Title className="mt-8">Current Upgrade Status</Title>
|
<Title className="mt-8">Current Upgrade Status</Title>
|
||||||
<div className="mt-4">
|
<div className="mt-4">
|
||||||
<div className="italic text-sm mb-2">
|
<div className="italic text-lg mb-2">
|
||||||
Currently in the process of switching to:
|
Currently in the process of switching to:{" "}
|
||||||
|
{futureEmbeddingModel.model_name}
|
||||||
</div>
|
</div>
|
||||||
<ModelOption model={newModelSelection} />
|
{/* <ModelOption model={futureEmbeddingModel} /> */}
|
||||||
|
|
||||||
<Button
|
<Button
|
||||||
color="red"
|
color="red"
|
||||||
@ -310,10 +489,10 @@ function Main() {
|
|||||||
|
|
||||||
<Text className="my-4">
|
<Text className="my-4">
|
||||||
The table below shows the re-indexing progress of all existing
|
The table below shows the re-indexing progress of all existing
|
||||||
connectors. Once all connectors have been re-indexed
|
connectors. Once all connectors have been re-indexed successfully,
|
||||||
successfully, the new model will be used for all search
|
the new model will be used for all search queries. Until then, we
|
||||||
queries. Until then, we will use the old model so that no
|
will use the old model so that no downtime is necessary during
|
||||||
downtime is necessary during this transition.
|
this transition.
|
||||||
</Text>
|
</Text>
|
||||||
|
|
||||||
{isLoadingOngoingReIndexingStatus ? (
|
{isLoadingOngoingReIndexingStatus ? (
|
||||||
@ -327,8 +506,7 @@ function Main() {
|
|||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
)
|
)}
|
||||||
))}
|
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -1 +1,4 @@
|
|||||||
export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
|
export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
|
||||||
|
|
||||||
|
export const EMBEDDING_PROVIDERS_ADMIN_URL =
|
||||||
|
"/api/admin/embedding/embedding-provider";
|
||||||
|
@ -20,7 +20,7 @@ import {
|
|||||||
import { unstable_noStore as noStore } from "next/cache";
|
import { unstable_noStore as noStore } from "next/cache";
|
||||||
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
||||||
import { personaComparator } from "../admin/assistants/lib";
|
import { personaComparator } from "../admin/assistants/lib";
|
||||||
import { FullEmbeddingModelResponse } from "../admin/models/embedding/embeddingModels";
|
import { FullEmbeddingModelResponse } from "../admin/models/embedding/components/types";
|
||||||
import { NoSourcesModal } from "@/components/initialSetup/search/NoSourcesModal";
|
import { NoSourcesModal } from "@/components/initialSetup/search/NoSourcesModal";
|
||||||
import { NoCompleteSourcesModal } from "@/components/initialSetup/search/NoCompleteSourceModal";
|
import { NoCompleteSourcesModal } from "@/components/initialSetup/search/NoCompleteSourceModal";
|
||||||
import { ChatPopup } from "../chat/ChatPopup";
|
import { ChatPopup } from "../chat/ChatPopup";
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
import { Divider } from "@tremor/react";
|
import { Divider } from "@tremor/react";
|
||||||
import { FiX } from "react-icons/fi";
|
import { FiX } from "react-icons/fi";
|
||||||
|
import { IconProps } from "./icons/icons";
|
||||||
|
|
||||||
interface ModalProps {
|
interface ModalProps {
|
||||||
|
icon?: ({ size, className }: IconProps) => JSX.Element;
|
||||||
children: JSX.Element | string;
|
children: JSX.Element | string;
|
||||||
title?: JSX.Element | string;
|
title?: JSX.Element | string;
|
||||||
onOutsideClick?: () => void;
|
onOutsideClick?: () => void;
|
||||||
@ -13,6 +15,7 @@ interface ModalProps {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function Modal({
|
export function Modal({
|
||||||
|
icon,
|
||||||
children,
|
children,
|
||||||
title,
|
title,
|
||||||
onOutsideClick,
|
onOutsideClick,
|
||||||
@ -44,10 +47,15 @@ export function Modal({
|
|||||||
<>
|
<>
|
||||||
<div className="flex mb-4">
|
<div className="flex mb-4">
|
||||||
<h2
|
<h2
|
||||||
className={"my-auto font-bold " + (titleSize || "text-2xl")}
|
className={
|
||||||
|
"my-auto flex content-start gap-x-4 font-bold " +
|
||||||
|
(titleSize || "text-2xl")
|
||||||
|
}
|
||||||
>
|
>
|
||||||
{title}
|
{title}
|
||||||
|
{icon && icon({ size: 30 })}
|
||||||
</h2>
|
</h2>
|
||||||
|
|
||||||
{onOutsideClick && (
|
{onOutsideClick && (
|
||||||
<div
|
<div
|
||||||
onClick={onOutsideClick}
|
onClick={onOutsideClick}
|
||||||
|
@ -72,11 +72,16 @@ import sharepointIcon from "../../../public/Sharepoint.png";
|
|||||||
import teamsIcon from "../../../public/Teams.png";
|
import teamsIcon from "../../../public/Teams.png";
|
||||||
import mediawikiIcon from "../../../public/MediaWiki.svg";
|
import mediawikiIcon from "../../../public/MediaWiki.svg";
|
||||||
import wikipediaIcon from "../../../public/Wikipedia.svg";
|
import wikipediaIcon from "../../../public/Wikipedia.svg";
|
||||||
|
|
||||||
import discourseIcon from "../../../public/Discourse.png";
|
import discourseIcon from "../../../public/Discourse.png";
|
||||||
import clickupIcon from "../../../public/Clickup.svg";
|
import clickupIcon from "../../../public/Clickup.svg";
|
||||||
|
import cohereIcon from "../../../public/Cohere.svg";
|
||||||
|
import voyageIcon from "../../../public/Voyage.png";
|
||||||
|
import googleIcon from "../../../public/Google.webp";
|
||||||
|
|
||||||
import { FaRobot } from "react-icons/fa";
|
import { FaRobot } from "react-icons/fa";
|
||||||
|
|
||||||
interface IconProps {
|
export interface IconProps {
|
||||||
size?: number;
|
size?: number;
|
||||||
className?: string;
|
className?: string;
|
||||||
}
|
}
|
||||||
@ -84,20 +89,6 @@ interface IconProps {
|
|||||||
export const defaultTailwindCSS = "my-auto flex flex-shrink-0 text-default";
|
export const defaultTailwindCSS = "my-auto flex flex-shrink-0 text-default";
|
||||||
export const defaultTailwindCSSBlue = "my-auto flex flex-shrink-0 text-link";
|
export const defaultTailwindCSSBlue = "my-auto flex flex-shrink-0 text-link";
|
||||||
|
|
||||||
export const OpenAIIcon = ({
|
|
||||||
size = 16,
|
|
||||||
className = defaultTailwindCSS,
|
|
||||||
}: IconProps) => {
|
|
||||||
return (
|
|
||||||
<div
|
|
||||||
style={{ width: `${size + 4}px`, height: `${size + 4}px` }}
|
|
||||||
className={`w-[${size + 4}px] h-[${size + 4}px] -m-0.5 ` + className}
|
|
||||||
>
|
|
||||||
<Image src={openAISVG} alt="Logo" width="96" height="96" />
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export const OpenSourceIcon = ({
|
export const OpenSourceIcon = ({
|
||||||
size = 16,
|
size = 16,
|
||||||
className = defaultTailwindCSS,
|
className = defaultTailwindCSS,
|
||||||
@ -528,6 +519,62 @@ export const ZulipIcon = ({
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const OpenAIIcon = ({
|
||||||
|
size = 16,
|
||||||
|
className = defaultTailwindCSS,
|
||||||
|
}: IconProps) => {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
style={{ width: `${size}px`, height: `${size}px` }}
|
||||||
|
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||||
|
>
|
||||||
|
<Image src={openAISVG} alt="Logo" width="96" height="96" />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const VoyageIcon = ({
|
||||||
|
size = 16,
|
||||||
|
className = defaultTailwindCSS,
|
||||||
|
}: IconProps) => {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
style={{ width: `${size}px`, height: `${size}px` }}
|
||||||
|
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||||
|
>
|
||||||
|
<Image src={voyageIcon} alt="Logo" width="96" height="96" />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const GoogleIcon = ({
|
||||||
|
size = 16,
|
||||||
|
className = defaultTailwindCSS,
|
||||||
|
}: IconProps) => {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
style={{ width: `${size}px`, height: `${size}px` }}
|
||||||
|
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||||
|
>
|
||||||
|
<Image src={googleIcon} alt="Logo" width="96" height="96" />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const CohereIcon = ({
|
||||||
|
size = 16,
|
||||||
|
className = defaultTailwindCSS,
|
||||||
|
}: IconProps) => {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
style={{ width: `${size}px`, height: `${size}px` }}
|
||||||
|
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||||
|
>
|
||||||
|
<Image src={cohereIcon} alt="Logo" width="96" height="96" />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
export const GoogleStorageIcon = ({
|
export const GoogleStorageIcon = ({
|
||||||
size = 16,
|
size = 16,
|
||||||
className = defaultTailwindCSS,
|
className = defaultTailwindCSS,
|
||||||
|
@ -13,7 +13,7 @@ import {
|
|||||||
} from "@/lib/types";
|
} from "@/lib/types";
|
||||||
import { ChatSession } from "@/app/chat/interfaces";
|
import { ChatSession } from "@/app/chat/interfaces";
|
||||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||||
import { FullEmbeddingModelResponse } from "@/app/admin/models/embedding/embeddingModels";
|
import { FullEmbeddingModelResponse } from "@/app/admin/models/embedding/components/types";
|
||||||
import { Settings } from "@/app/admin/settings/interfaces";
|
import { Settings } from "@/app/admin/settings/interfaces";
|
||||||
import { fetchLLMProvidersSS } from "@/lib/llm/fetchLLMs";
|
import { fetchLLMProvidersSS } from "@/lib/llm/fetchLLMs";
|
||||||
import { LLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";
|
import { LLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";
|
||||||
|
Loading…
x
Reference in New Issue
Block a user