mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-26 17:51:54 +01:00
add third party embedding models (#1818)
This commit is contained in:
parent
b6bd818e60
commit
e7f81d1688
@ -0,0 +1,65 @@
|
||||
"""add cloud embedding model and update embedding_model
|
||||
|
||||
Revision ID: 44f856ae2a4a
|
||||
Revises: d716b0791ddd
|
||||
Create Date: 2024-06-28 20:01:05.927647
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "44f856ae2a4a"
|
||||
down_revision = "d716b0791ddd"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create embedding_provider table
|
||||
op.create_table(
|
||||
"embedding_provider",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("api_key", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("default_model_id", sa.Integer(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
|
||||
# Add cloud_provider_id to embedding_model table
|
||||
op.add_column(
|
||||
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Add foreign key constraints
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_model_cloud_provider",
|
||||
"embedding_model",
|
||||
"embedding_provider",
|
||||
["cloud_provider_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_embedding_provider_default_model",
|
||||
"embedding_provider",
|
||||
"embedding_model",
|
||||
["default_model_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove foreign key constraints
|
||||
op.drop_constraint(
|
||||
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
|
||||
)
|
||||
op.drop_constraint(
|
||||
"fk_embedding_provider_default_model", "embedding_provider", type_="foreignkey"
|
||||
)
|
||||
|
||||
# Remove cloud_provider_id column
|
||||
op.drop_column("embedding_model", "cloud_provider_id")
|
||||
|
||||
# Drop embedding_provider table
|
||||
op.drop_table("embedding_provider")
|
@ -10,8 +10,8 @@ from alembic import op
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d716b0791ddd"
|
||||
down_revision = "7aea705850d5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -98,7 +98,6 @@ def _run_indexing(
|
||||
3. Updates Postgres to record the indexed documents + the outcome of this run
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
db_embedding_model = index_attempt.embedding_model
|
||||
index_name = db_embedding_model.index_name
|
||||
|
||||
@ -116,6 +115,8 @@ def _run_indexing(
|
||||
normalize=db_embedding_model.normalize,
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
api_key=db_embedding_model.api_key,
|
||||
provider_type=db_embedding_model.provider_type,
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
@ -287,6 +288,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
)
|
||||
|
||||
if attempt is None:
|
||||
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")
|
||||
|
||||
|
@ -343,13 +343,15 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
logger.info("Running a first inference to warm up embedding model")
|
||||
warm_up_encoders(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
model_server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
if db_embedding_model.cloud_provider_id is None:
|
||||
logger.info("Running a first inference to warm up embedding model")
|
||||
warm_up_encoders(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
model_server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
client_primary: Client | SimpleJobClient
|
||||
client_secondary: Client | SimpleJobClient
|
||||
|
@ -469,13 +469,13 @@ if __name__ == "__main__":
|
||||
# or the tokens have updated (set up for the first time)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
embedding_model = get_current_db_embedding_model(db_session)
|
||||
|
||||
warm_up_encoders(
|
||||
model_name=embedding_model.model_name,
|
||||
normalize=embedding_model.normalize,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
if embedding_model.cloud_provider_id is None:
|
||||
warm_up_encoders(
|
||||
model_name=embedding_model.model_name,
|
||||
normalize=embedding_model.normalize,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
slack_bot_tokens = latest_slack_bot_tokens
|
||||
# potentially may cause a message to be dropped, but it is complicated
|
||||
|
@ -10,10 +10,15 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
from danswer.db.llm import fetch_embedding_provider
|
||||
from danswer.db.models import CloudEmbeddingProvider
|
||||
from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.indexing.models import EmbeddingModelDetail
|
||||
from danswer.search.search_nlp_models import clean_model_name
|
||||
from danswer.server.manage.embedding.models import (
|
||||
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
|
||||
)
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@ -31,6 +36,7 @@ def create_embedding_model(
|
||||
query_prefix=model_details.query_prefix,
|
||||
passage_prefix=model_details.passage_prefix,
|
||||
status=status,
|
||||
cloud_provider_id=model_details.cloud_provider_id,
|
||||
# Every single embedding model except the initial one from migrations has this name
|
||||
# The initial one from migration is called "danswer_chunk"
|
||||
index_name=f"danswer_chunk_{clean_model_name(model_details.model_name)}",
|
||||
@ -42,6 +48,42 @@ def create_embedding_model(
|
||||
return embedding_model
|
||||
|
||||
|
||||
def get_model_id_from_name(
|
||||
db_session: Session, embedding_provider_name: str
|
||||
) -> int | None:
|
||||
query = select(CloudEmbeddingProvider).where(
|
||||
CloudEmbeddingProvider.name == embedding_provider_name
|
||||
)
|
||||
provider = db_session.execute(query).scalars().first()
|
||||
return provider.id if provider else None
|
||||
|
||||
|
||||
def get_current_db_embedding_provider(
|
||||
db_session: Session,
|
||||
) -> ServerCloudEmbeddingProvider | None:
|
||||
current_embedding_model = EmbeddingModelDetail.from_model(
|
||||
get_current_db_embedding_model(db_session=db_session)
|
||||
)
|
||||
|
||||
if (
|
||||
current_embedding_model is None
|
||||
or current_embedding_model.cloud_provider_id is None
|
||||
):
|
||||
return None
|
||||
|
||||
embedding_provider = fetch_embedding_provider(
|
||||
db_session=db_session, provider_id=current_embedding_model.cloud_provider_id
|
||||
)
|
||||
if embedding_provider is None:
|
||||
raise RuntimeError("No embedding provider exists for this model.")
|
||||
|
||||
current_embedding_provider = ServerCloudEmbeddingProvider.from_request(
|
||||
cloud_provider_model=embedding_provider
|
||||
)
|
||||
|
||||
return current_embedding_provider
|
||||
|
||||
|
||||
def get_current_db_embedding_model(db_session: Session) -> EmbeddingModel:
|
||||
query = (
|
||||
select(EmbeddingModel)
|
||||
|
@ -2,11 +2,34 @@ from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from danswer.server.manage.llm.models import FullLLMProvider
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
|
||||
|
||||
def upsert_cloud_embedding_provider(
|
||||
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
|
||||
) -> CloudEmbeddingProvider:
|
||||
existing_provider = (
|
||||
db_session.query(CloudEmbeddingProviderModel)
|
||||
.filter_by(name=provider.name)
|
||||
.first()
|
||||
)
|
||||
if existing_provider:
|
||||
for key, value in provider.dict().items():
|
||||
setattr(existing_provider, key, value)
|
||||
else:
|
||||
new_provider = CloudEmbeddingProviderModel(**provider.dict())
|
||||
db_session.add(new_provider)
|
||||
existing_provider = new_provider
|
||||
db_session.commit()
|
||||
db_session.refresh(existing_provider)
|
||||
return CloudEmbeddingProvider.from_request(existing_provider)
|
||||
|
||||
|
||||
def upsert_llm_provider(
|
||||
db_session: Session, llm_provider: LLMProviderUpsertRequest
|
||||
) -> FullLLMProvider:
|
||||
@ -26,7 +49,6 @@ def upsert_llm_provider(
|
||||
existing_llm_provider.model_names = llm_provider.model_names
|
||||
db_session.commit()
|
||||
return FullLLMProvider.from_model(existing_llm_provider)
|
||||
|
||||
# if it does not exist, create a new entry
|
||||
llm_provider_model = LLMProviderModel(
|
||||
name=llm_provider.name,
|
||||
@ -46,10 +68,26 @@ def upsert_llm_provider(
|
||||
return FullLLMProvider.from_model(llm_provider_model)
|
||||
|
||||
|
||||
def fetch_existing_embedding_providers(
|
||||
db_session: Session,
|
||||
) -> list[CloudEmbeddingProviderModel]:
|
||||
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
|
||||
|
||||
|
||||
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
|
||||
return list(db_session.scalars(select(LLMProviderModel)).all())
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
db_session: Session, provider_id: int
|
||||
) -> CloudEmbeddingProviderModel | None:
|
||||
return db_session.scalar(
|
||||
select(CloudEmbeddingProviderModel).where(
|
||||
CloudEmbeddingProviderModel.id == provider_id
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
@ -70,6 +108,16 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
|
||||
return FullLLMProvider.from_model(provider_model)
|
||||
|
||||
|
||||
def remove_embedding_provider(
|
||||
db_session: Session, embedding_provider_name: str
|
||||
) -> None:
|
||||
db_session.execute(
|
||||
delete(CloudEmbeddingProviderModel).where(
|
||||
CloudEmbeddingProviderModel.name == embedding_provider_name
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
||||
db_session.execute(
|
||||
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
|
||||
|
@ -130,6 +130,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
chat_folders: Mapped[list["ChatFolder"]] = relationship(
|
||||
"ChatFolder", back_populates="user"
|
||||
)
|
||||
|
||||
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
|
||||
# Personas owned by this user
|
||||
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
|
||||
@ -469,7 +470,7 @@ class Credential(Base):
|
||||
|
||||
class EmbeddingModel(Base):
|
||||
__tablename__ = "embedding_model"
|
||||
# ID is used also to indicate the order that the models are configured by the admin
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
model_name: Mapped[str] = mapped_column(String)
|
||||
model_dim: Mapped[int] = mapped_column(Integer)
|
||||
@ -481,6 +482,16 @@ class EmbeddingModel(Base):
|
||||
)
|
||||
index_name: Mapped[str] = mapped_column(String)
|
||||
|
||||
# New field for cloud provider relationship
|
||||
cloud_provider_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("embedding_provider.id")
|
||||
)
|
||||
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
|
||||
"CloudEmbeddingProvider",
|
||||
back_populates="embedding_models",
|
||||
foreign_keys=[cloud_provider_id],
|
||||
)
|
||||
|
||||
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
|
||||
"IndexAttempt", back_populates="embedding_model"
|
||||
)
|
||||
@ -500,6 +511,18 @@ class EmbeddingModel(Base):
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
|
||||
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
|
||||
|
||||
@property
|
||||
def api_key(self) -> str | None:
|
||||
return self.cloud_provider.api_key if self.cloud_provider else None
|
||||
|
||||
@property
|
||||
def provider_type(self) -> str | None:
|
||||
return self.cloud_provider.name if self.cloud_provider else None
|
||||
|
||||
|
||||
class IndexAttempt(Base):
|
||||
"""
|
||||
@ -519,6 +542,7 @@ class IndexAttempt(Base):
|
||||
ForeignKey("credential.id"),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Some index attempts that run from beginning will still have this as False
|
||||
# This is only for attempts that are explicitly marked as from the start via
|
||||
# the run once API
|
||||
@ -879,11 +903,6 @@ class ChatMessageFeedback(Base):
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
Structures, Organizational, Configurations Tables
|
||||
"""
|
||||
|
||||
|
||||
class LLMProvider(Base):
|
||||
__tablename__ = "llm_provider"
|
||||
|
||||
@ -912,6 +931,29 @@ class LLMProvider(Base):
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(Base):
|
||||
__tablename__ = "embedding_provider"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString())
|
||||
default_model_id: Mapped[int | None] = mapped_column(
|
||||
Integer, ForeignKey("embedding_model.id"), nullable=True
|
||||
)
|
||||
|
||||
embedding_models: Mapped[list["EmbeddingModel"]] = relationship(
|
||||
"EmbeddingModel",
|
||||
back_populates="cloud_provider",
|
||||
foreign_keys="EmbeddingModel.cloud_provider_id",
|
||||
)
|
||||
default_model: Mapped["EmbeddingModel"] = relationship(
|
||||
"EmbeddingModel", foreign_keys=[default_model_id]
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<EmbeddingProvider(name='{self.name}')>"
|
||||
|
||||
|
||||
class DocumentSet(Base):
|
||||
__tablename__ = "document_set"
|
||||
|
||||
@ -1194,6 +1236,7 @@ class SlackBotConfig(Base):
|
||||
response_type: Mapped[SlackBotResponseType] = mapped_column(
|
||||
Enum(SlackBotResponseType, native_enum=False), nullable=False
|
||||
)
|
||||
|
||||
enable_auto_filters: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
|
@ -50,6 +50,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
normalize: bool,
|
||||
query_prefix: str | None,
|
||||
passage_prefix: str | None,
|
||||
api_key: str | None = None,
|
||||
provider_type: str | None = None,
|
||||
):
|
||||
super().__init__(model_name, normalize, query_prefix, passage_prefix)
|
||||
self.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE # Currently not customizable
|
||||
@ -59,6 +61,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
query_prefix=query_prefix,
|
||||
passage_prefix=passage_prefix,
|
||||
normalize=normalize,
|
||||
api_key=api_key,
|
||||
provider_type=provider_type,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
|
@ -97,13 +97,19 @@ class EmbeddingModelDetail(BaseModel):
|
||||
normalize: bool
|
||||
query_prefix: str | None
|
||||
passage_prefix: str | None
|
||||
cloud_provider_id: int | None = None
|
||||
cloud_provider_name: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, embedding_model: "EmbeddingModel") -> "EmbeddingModelDetail":
|
||||
def from_model(
|
||||
cls,
|
||||
embedding_model: "EmbeddingModel",
|
||||
) -> "EmbeddingModelDetail":
|
||||
return cls(
|
||||
model_name=embedding_model.model_name,
|
||||
model_dim=embedding_model.model_dim,
|
||||
normalize=embedding_model.normalize,
|
||||
query_prefix=embedding_model.query_prefix,
|
||||
passage_prefix=embedding_model.passage_prefix,
|
||||
cloud_provider_id=embedding_model.cloud_provider_id,
|
||||
)
|
||||
|
@ -67,6 +67,8 @@ from danswer.server.features.tool.api import admin_router as admin_tool_router
|
||||
from danswer.server.features.tool.api import router as tool_router
|
||||
from danswer.server.gpts.api import router as gpts_router
|
||||
from danswer.server.manage.administrative import router as admin_router
|
||||
from danswer.server.manage.embedding.api import admin_router as embedding_admin_router
|
||||
from danswer.server.manage.embedding.api import basic_router as embedding_router
|
||||
from danswer.server.manage.get_state import router as state_router
|
||||
from danswer.server.manage.llm.api import admin_router as llm_admin_router
|
||||
from danswer.server.manage.llm.api import basic_router as llm_router
|
||||
@ -247,12 +249,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
time.sleep(wait_time)
|
||||
|
||||
logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
|
||||
warm_up_encoders(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
if db_embedding_model.cloud_provider_id is None:
|
||||
warm_up_encoders(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
|
||||
yield
|
||||
@ -291,6 +294,8 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, settings_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, llm_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, llm_router)
|
||||
include_router_with_global_prefix_prepended(application, embedding_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, embedding_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, token_rate_limit_settings_router
|
||||
)
|
||||
|
@ -168,6 +168,7 @@ def stream_answer_objects(
|
||||
max_tokens=max_document_tokens,
|
||||
use_sections=query_req.chunks_above > 0 or query_req.chunks_below > 0,
|
||||
)
|
||||
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
|
@ -131,6 +131,8 @@ def doc_index_retrieval(
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
normalize=db_embedding_model.normalize,
|
||||
api_key=db_embedding_model.api_key,
|
||||
provider_type=db_embedding_model.provider_type,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
|
@ -84,20 +84,24 @@ def build_model_server_url(
|
||||
class EmbeddingModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
query_prefix: str | None,
|
||||
passage_prefix: str | None,
|
||||
normalize: bool,
|
||||
server_host: str, # Changes depending on indexing or inference
|
||||
server_port: int,
|
||||
model_name: str | None,
|
||||
normalize: bool,
|
||||
query_prefix: str | None,
|
||||
passage_prefix: str | None,
|
||||
api_key: str | None,
|
||||
provider_type: str | None,
|
||||
# The following are globals are currently not configurable
|
||||
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self.api_key = api_key
|
||||
self.provider_type = provider_type
|
||||
self.max_seq_length = max_seq_length
|
||||
self.query_prefix = query_prefix
|
||||
self.passage_prefix = passage_prefix
|
||||
self.normalize = normalize
|
||||
self.model_name = model_name
|
||||
|
||||
model_server_url = build_model_server_url(server_host, server_port)
|
||||
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
|
||||
@ -111,10 +115,13 @@ class EmbeddingModel:
|
||||
prefixed_texts = texts
|
||||
|
||||
embed_request = EmbedRequest(
|
||||
texts=prefixed_texts,
|
||||
model_name=self.model_name,
|
||||
texts=prefixed_texts,
|
||||
max_context_length=self.max_seq_length,
|
||||
normalize_embeddings=self.normalize,
|
||||
api_key=self.api_key,
|
||||
provider_type=self.provider_type,
|
||||
text_type=text_type,
|
||||
)
|
||||
|
||||
response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
|
||||
@ -187,6 +194,8 @@ def warm_up_encoders(
|
||||
passage_prefix=None,
|
||||
server_host=model_server_host,
|
||||
server_port=model_server_port,
|
||||
api_key=None,
|
||||
provider_type=None,
|
||||
)
|
||||
|
||||
# First time downloading the models it may take even longer, but just in case,
|
||||
|
93
backend/danswer/server/manage/embedding/api.py
Normal file
93
backend/danswer/server/manage/embedding/api.py
Normal file
@ -0,0 +1,93 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.db.embedding_model import get_current_db_embedding_provider
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.llm import fetch_existing_embedding_providers
|
||||
from danswer.db.llm import remove_embedding_provider
|
||||
from danswer.db.llm import upsert_cloud_embedding_provider
|
||||
from danswer.db.models import User
|
||||
from danswer.search.enums import EmbedTextType
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from danswer.server.manage.embedding.models import TestEmbeddingRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/embedding")
|
||||
basic_router = APIRouter(prefix="/embedding")
|
||||
|
||||
|
||||
@admin_router.post("/test-embedding")
|
||||
def test_embedding_configuration(
|
||||
test_llm_request: TestEmbeddingRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
try:
|
||||
test_model = EmbeddingModel(
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
api_key=test_llm_request.api_key,
|
||||
provider_type=test_llm_request.provider,
|
||||
normalize=False,
|
||||
query_prefix=None,
|
||||
passage_prefix=None,
|
||||
model_name=None,
|
||||
)
|
||||
test_model.encode(["Test String"], text_type=EmbedTextType.QUERY)
|
||||
|
||||
except ValueError as e:
|
||||
error_msg = f"Not a valid embedding model. Exception thrown: {e}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = "An error occurred while testing your embedding model. Please check your configuration."
|
||||
logger.error(f"{error_msg} Error message: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
|
||||
@admin_router.get("/embedding-provider")
|
||||
def list_embedding_providers(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[CloudEmbeddingProvider]:
|
||||
return [
|
||||
CloudEmbeddingProvider.from_request(embedding_provider_model)
|
||||
for embedding_provider_model in fetch_existing_embedding_providers(db_session)
|
||||
]
|
||||
|
||||
|
||||
@admin_router.delete("/embedding-provider/{embedding_provider_name}")
|
||||
def delete_embedding_provider(
|
||||
embedding_provider_name: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
embedding_provider = get_current_db_embedding_provider(db_session=db_session)
|
||||
if (
|
||||
embedding_provider is not None
|
||||
and embedding_provider_name == embedding_provider.name
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="You can't delete a currently active model"
|
||||
)
|
||||
|
||||
remove_embedding_provider(db_session, embedding_provider_name)
|
||||
|
||||
|
||||
@admin_router.put("/embedding-provider")
|
||||
def put_cloud_embedding_provider(
|
||||
provider: CloudEmbeddingProviderCreationRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CloudEmbeddingProvider:
|
||||
return upsert_cloud_embedding_provider(db_session, provider)
|
35
backend/danswer/server/manage/embedding/models.py
Normal file
35
backend/danswer/server/manage/embedding/models.py
Normal file
@ -0,0 +1,35 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
|
||||
|
||||
class TestEmbeddingRequest(BaseModel):
|
||||
provider: str
|
||||
api_key: str | None = None
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(BaseModel):
|
||||
name: str
|
||||
api_key: str | None = None
|
||||
default_model_id: int | None = None
|
||||
id: int
|
||||
|
||||
@classmethod
|
||||
def from_request(
|
||||
cls, cloud_provider_model: "CloudEmbeddingProviderModel"
|
||||
) -> "CloudEmbeddingProvider":
|
||||
return cls(
|
||||
id=cloud_provider_model.id,
|
||||
name=cloud_provider_model.name,
|
||||
api_key=cloud_provider_model.api_key,
|
||||
default_model_id=cloud_provider_model.default_model_id,
|
||||
)
|
||||
|
||||
|
||||
class CloudEmbeddingProviderCreationRequest(BaseModel):
|
||||
name: str
|
||||
api_key: str | None = None
|
||||
default_model_id: int | None = None
|
@ -4,6 +4,7 @@ from pydantic import BaseModel
|
||||
|
||||
from danswer.llm.llm_provider_options import fetch_models_for_provider
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||
|
||||
|
@ -11,6 +11,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import resync_cc_pair
|
||||
from danswer.db.embedding_model import create_embedding_model
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.embedding_model import get_model_id_from_name
|
||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||
from danswer.db.embedding_model import update_embedding_model_status
|
||||
from danswer.db.engine import get_session
|
||||
@ -38,6 +39,19 @@ def set_new_embedding_model(
|
||||
"""
|
||||
current_model = get_current_db_embedding_model(db_session)
|
||||
|
||||
if embed_model_details.cloud_provider_name is not None:
|
||||
cloud_id = get_model_id_from_name(
|
||||
db_session, embed_model_details.cloud_provider_name
|
||||
)
|
||||
|
||||
if cloud_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="No ID exists for given provider name",
|
||||
)
|
||||
|
||||
embed_model_details.cloud_provider_id = cloud_id
|
||||
|
||||
if embed_model_details.model_name == current_model.model_name:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
@ -1 +1,38 @@
|
||||
from enum import Enum
|
||||
|
||||
from danswer.search.enums import EmbedTextType
|
||||
|
||||
|
||||
MODEL_WARM_UP_STRING = "hi " * 512
|
||||
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
|
||||
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
|
||||
DEFAULT_VERTEX_MODEL = "text-embedding-004"
|
||||
|
||||
|
||||
class EmbeddingProvider(Enum):
|
||||
OPENAI = "openai"
|
||||
COHERE = "cohere"
|
||||
VOYAGE = "voyage"
|
||||
GOOGLE = "google"
|
||||
|
||||
|
||||
class EmbeddingModelTextType:
|
||||
PROVIDER_TEXT_TYPE_MAP = {
|
||||
EmbeddingProvider.COHERE: {
|
||||
EmbedTextType.QUERY: "search_query",
|
||||
EmbedTextType.PASSAGE: "search_document",
|
||||
},
|
||||
EmbeddingProvider.VOYAGE: {
|
||||
EmbedTextType.QUERY: "query",
|
||||
EmbedTextType.PASSAGE: "document",
|
||||
},
|
||||
EmbeddingProvider.GOOGLE: {
|
||||
EmbedTextType.QUERY: "RETRIEVAL_QUERY",
|
||||
EmbedTextType.PASSAGE: "RETRIEVAL_DOCUMENT",
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
|
||||
return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
|
||||
|
@ -1,12 +1,28 @@
|
||||
import gc
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
import openai
|
||||
import vertexai # type: ignore
|
||||
import voyageai # type: ignore
|
||||
from cohere import Client as CohereClient
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from google.oauth2 import service_account
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
from vertexai.language_models import TextEmbeddingInput # type: ignore
|
||||
from vertexai.language_models import TextEmbeddingModel # type: ignore
|
||||
|
||||
from danswer.search.enums import EmbedTextType
|
||||
from danswer.utils.logger import setup_logger
|
||||
from model_server.constants import DEFAULT_COHERE_MODEL
|
||||
from model_server.constants import DEFAULT_OPENAI_MODEL
|
||||
from model_server.constants import DEFAULT_VERTEX_MODEL
|
||||
from model_server.constants import DEFAULT_VOYAGE_MODEL
|
||||
from model_server.constants import EmbeddingModelTextType
|
||||
from model_server.constants import EmbeddingProvider
|
||||
from model_server.constants import MODEL_WARM_UP_STRING
|
||||
from model_server.utils import simple_log_function_time
|
||||
from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE
|
||||
@ -17,6 +33,7 @@ from shared_configs.model_server_models import EmbedResponse
|
||||
from shared_configs.model_server_models import RerankRequest
|
||||
from shared_configs.model_server_models import RerankResponse
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/encoder")
|
||||
@ -25,6 +42,117 @@ _GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
|
||||
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
|
||||
|
||||
|
||||
class CloudEmbedding:
|
||||
def __init__(self, api_key: str, provider: str, model: str | None = None):
|
||||
self.api_key = api_key
|
||||
|
||||
# Only for Google as is needed on client setup
|
||||
self.model = model
|
||||
try:
|
||||
self.provider = EmbeddingProvider(provider.lower())
|
||||
except ValueError:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
self.client = self._initialize_client()
|
||||
|
||||
def _initialize_client(self) -> Any:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return openai.OpenAI(api_key=self.api_key)
|
||||
elif self.provider == EmbeddingProvider.COHERE:
|
||||
return CohereClient(api_key=self.api_key)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return voyageai.Client(api_key=self.api_key)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
json.loads(self.api_key)
|
||||
)
|
||||
project_id = json.loads(self.api_key)["project_id"]
|
||||
vertexai.init(project=project_id, credentials=credentials)
|
||||
return TextEmbeddingModel.from_pretrained(
|
||||
self.model or DEFAULT_VERTEX_MODEL
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
def encode(
|
||||
self, texts: list[str], model_name: str | None, text_type: EmbedTextType
|
||||
) -> list[list[float]]:
|
||||
return [
|
||||
self.embed(text=text, text_type=text_type, model=model_name)
|
||||
for text in texts
|
||||
]
|
||||
|
||||
def embed(
|
||||
self, *, text: str, text_type: EmbedTextType, model: str | None = None
|
||||
) -> list[float]:
|
||||
logger.debug(f"Embedding text with provider: {self.provider}")
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return self._embed_openai(text, model)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return self._embed_cohere(text, model, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return self._embed_voyage(text, model, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return self._embed_vertex(text, model, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
def _embed_openai(self, text: str, model: str | None) -> list[float]:
|
||||
if model is None:
|
||||
model = DEFAULT_OPENAI_MODEL
|
||||
|
||||
response = self.client.embeddings.create(input=text, model=model)
|
||||
return response.data[0].embedding
|
||||
|
||||
def _embed_cohere(
|
||||
self, text: str, model: str | None, embedding_type: str
|
||||
) -> list[float]:
|
||||
if model is None:
|
||||
model = DEFAULT_COHERE_MODEL
|
||||
|
||||
response = self.client.embed(
|
||||
texts=[text],
|
||||
model=model,
|
||||
input_type=embedding_type,
|
||||
)
|
||||
return response.embeddings[0]
|
||||
|
||||
def _embed_voyage(
|
||||
self, text: str, model: str | None, embedding_type: str
|
||||
) -> list[float]:
|
||||
if model is None:
|
||||
model = DEFAULT_VOYAGE_MODEL
|
||||
|
||||
response = self.client.embed(text, model=model, input_type=embedding_type)
|
||||
return response.embeddings[0]
|
||||
|
||||
def _embed_vertex(
|
||||
self, text: str, model: str | None, embedding_type: str
|
||||
) -> list[float]:
|
||||
if model is None:
|
||||
model = DEFAULT_VERTEX_MODEL
|
||||
|
||||
embedding = self.client.get_embeddings(
|
||||
[
|
||||
TextEmbeddingInput(
|
||||
text,
|
||||
embedding_type,
|
||||
)
|
||||
]
|
||||
)
|
||||
return embedding[0].values
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
api_key: str, provider: str, model: str | None = None
|
||||
) -> "CloudEmbedding":
|
||||
logger.debug(f"Creating Embedding instance for provider: {provider}")
|
||||
return CloudEmbedding(api_key, provider, model)
|
||||
|
||||
|
||||
def get_embedding_model(
|
||||
model_name: str,
|
||||
max_context_length: int,
|
||||
@ -78,18 +206,35 @@ def warm_up_cross_encoders() -> None:
|
||||
@simple_log_function_time()
|
||||
def embed_text(
|
||||
texts: list[str],
|
||||
model_name: str,
|
||||
text_type: EmbedTextType,
|
||||
model_name: str | None,
|
||||
max_context_length: int,
|
||||
normalize_embeddings: bool,
|
||||
api_key: str | None,
|
||||
provider_type: str | None,
|
||||
) -> list[list[float]]:
|
||||
model = get_embedding_model(
|
||||
model_name=model_name, max_context_length=max_context_length
|
||||
)
|
||||
embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)
|
||||
if provider_type is not None:
|
||||
if api_key is None:
|
||||
raise RuntimeError("API key not provided for cloud model")
|
||||
|
||||
cloud_model = CloudEmbedding(
|
||||
api_key=api_key, provider=provider_type, model=model_name
|
||||
)
|
||||
embeddings = cloud_model.encode(texts, model_name, text_type)
|
||||
|
||||
elif model_name is not None:
|
||||
hosted_model = get_embedding_model(
|
||||
model_name=model_name, max_context_length=max_context_length
|
||||
)
|
||||
embeddings = hosted_model.encode(
|
||||
texts, normalize_embeddings=normalize_embeddings
|
||||
)
|
||||
|
||||
if embeddings is None:
|
||||
raise RuntimeError("Embeddings were not created")
|
||||
|
||||
if not isinstance(embeddings, list):
|
||||
embeddings = embeddings.tolist()
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
@ -113,6 +258,9 @@ async def process_embed_request(
|
||||
model_name=embed_request.model_name,
|
||||
max_context_length=embed_request.max_context_length,
|
||||
normalize_embeddings=embed_request.normalize_embeddings,
|
||||
api_key=embed_request.api_key,
|
||||
provider_type=embed_request.provider_type,
|
||||
text_type=embed_request.text_type,
|
||||
)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
except Exception as e:
|
||||
|
@ -7,3 +7,7 @@ tensorflow==2.15.0
|
||||
torch==2.0.1
|
||||
transformers==4.39.2
|
||||
uvicorn==0.21.1
|
||||
voyageai==0.2.3
|
||||
openai==1.14.3
|
||||
cohere==5.5.8
|
||||
google-cloud-aiplatform==1.58.0
|
@ -1,12 +1,19 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.search.enums import EmbedTextType
|
||||
|
||||
|
||||
class EmbedRequest(BaseModel):
|
||||
# This already includes any prefixes, the text is just passed directly to the model
|
||||
texts: list[str]
|
||||
model_name: str
|
||||
|
||||
# Can be none for cloud embedding model requests, error handling logic exists for other cases
|
||||
model_name: str | None
|
||||
max_context_length: int
|
||||
normalize_embeddings: bool
|
||||
api_key: str | None
|
||||
provider_type: str | None
|
||||
text_type: EmbedTextType
|
||||
|
||||
|
||||
class EmbedResponse(BaseModel):
|
||||
|
30
web/public/Cohere.svg
Normal file
30
web/public/Cohere.svg
Normal file
@ -0,0 +1,30 @@
|
||||
<svg version="1.1" id="Layer_1" xmlns:x="ns_extend;" xmlns:i="ns_ai;" xmlns:graph="ns_graphs;" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px" viewBox="0 0 75 75" style="enable-background:new 0 0 75 75;" xml:space="preserve">
|
||||
<style type="text/css">
|
||||
.st0{fill-rule:evenodd;clip-rule:evenodd;fill:#39594D;}
|
||||
.st1{fill-rule:evenodd;clip-rule:evenodd;fill:#D18EE2;}
|
||||
.st2{fill:#FF7759;}
|
||||
</style>
|
||||
<metadata>
|
||||
<sfw xmlns="ns_sfw;">
|
||||
<slices>
|
||||
</slices>
|
||||
<sliceSourceBounds bottomLeftOrigin="true" height="75" width="75" x="-347.6" y="0.5">
|
||||
</sliceSourceBounds>
|
||||
</sfw>
|
||||
</metadata>
|
||||
<g>
|
||||
<g>
|
||||
<g>
|
||||
<path class="st0" d="M24.3,44.7c2,0,6-0.1,11.6-2.4c6.5-2.7,19.3-7.5,28.6-12.5c6.5-3.5,9.3-8.1,9.3-14.3C73.8,7,66.9,0,58.3,0
|
||||
h-36C10,0,0,10,0,22.3S9.4,44.7,24.3,44.7z">
|
||||
</path>
|
||||
<path class="st1" d="M30.4,60c0-6,3.6-11.5,9.2-13.8l11.3-4.7C62.4,36.8,75,45.2,75,57.6C75,67.2,67.2,75,57.6,75l-12.3,0
|
||||
C37.1,75,30.4,68.3,30.4,60z">
|
||||
</path>
|
||||
<path class="st2" d="M12.9,47.6L12.9,47.6C5.8,47.6,0,53.4,0,60.5v1.7C0,69.2,5.8,75,12.9,75h0c7.1,0,12.9-5.8,12.9-12.9v-1.7
|
||||
C25.7,53.4,20,47.6,12.9,47.6z">
|
||||
</path>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
After Width: | Height: | Size: 1.2 KiB |
BIN
web/public/Google.webp
Normal file
BIN
web/public/Google.webp
Normal file
Binary file not shown.
After Width: | Height: | Size: 6.4 KiB |
@ -1 +1 @@
|
||||
<svg viewBox="0 0 320 320" xmlns="http://www.w3.org/2000/svg"><path d="m297.06 130.97c7.26-21.79 4.76-45.66-6.85-65.48-17.46-30.4-52.56-46.04-86.84-38.68-15.25-17.18-37.16-26.95-60.13-26.81-35.04-.08-66.13 22.48-76.91 55.82-22.51 4.61-41.94 18.7-53.31 38.67-17.59 30.32-13.58 68.54 9.92 94.54-7.26 21.79-4.76 45.66 6.85 65.48 17.46 30.4 52.56 46.04 86.84 38.68 15.24 17.18 37.16 26.95 60.13 26.8 35.06.09 66.16-22.49 76.94-55.86 22.51-4.61 41.94-18.7 53.31-38.67 17.57-30.32 13.55-68.51-9.94-94.51zm-120.28 168.11c-14.03.02-27.62-4.89-38.39-13.88.49-.26 1.34-.73 1.89-1.07l63.72-36.8c3.26-1.85 5.26-5.32 5.24-9.07v-89.83l26.93 15.55c.29.14.48.42.52.74v74.39c-.04 33.08-26.83 59.9-59.91 59.97zm-128.84-55.03c-7.03-12.14-9.56-26.37-7.15-40.18.47.28 1.3.79 1.89 1.13l63.72 36.8c3.23 1.89 7.23 1.89 10.47 0l77.79-44.92v31.1c.02.32-.13.63-.38.83l-64.41 37.19c-28.69 16.52-65.33 6.7-81.92-21.95zm-16.77-139.09c7-12.16 18.05-21.46 31.21-26.29 0 .55-.03 1.52-.03 2.2v73.61c-.02 3.74 1.98 7.21 5.23 9.06l77.79 44.91-26.93 15.55c-.27.18-.61.21-.91.08l-64.42-37.22c-28.63-16.58-38.45-53.21-21.95-81.89zm221.26 51.49-77.79-44.92 26.93-15.54c.27-.18.61-.21.91-.08l64.42 37.19c28.68 16.57 38.51 53.26 21.94 81.94-7.01 12.14-18.05 21.44-31.2 26.28v-75.81c.03-3.74-1.96-7.2-5.2-9.06zm26.8-40.34c-.47-.29-1.3-.79-1.89-1.13l-63.72-36.8c-3.23-1.89-7.23-1.89-10.47 0l-77.79 44.92v-31.1c-.02-.32.13-.63.38-.83l64.41-37.16c28.69-16.55 65.37-6.7 81.91 22 6.99 12.12 9.52 26.31 7.15 40.1zm-168.51 55.43-26.94-15.55c-.29-.14-.48-.42-.52-.74v-74.39c.02-33.12 26.89-59.96 60.01-59.94 14.01 0 27.57 4.92 38.34 13.88-.49.26-1.33.73-1.89 1.07l-63.72 36.8c-3.26 1.85-5.26 5.31-5.24 9.06l-.04 89.79zm14.63-31.54 34.65-20.01 34.65 20v40.01l-34.65 20-34.65-20z"/></svg>
|
||||
<svg viewBox="0 0 320 320" xmlns="http://www.w3.org/2000/svg"><path d="m297.06 130.97c7.26-21.79 4.76-45.66-6.85-65.48-17.46-30.4-52.56-46.04-86.84-38.68-15.25-17.18-37.16-26.95-60.13-26.81-35.04-.08-66.13 22.48-76.91 55.82-22.51 4.61-41.94 18.7-53.31 38.67-17.59 30.32-13.58 68.54 9.92 94.54-7.26 21.79-4.76 45.66 6.85 65.48 17.46 30.4 52.56 46.04 86.84 38.68 15.24 17.18 37.16 26.95 60.13 26.8 35.06.09 66.16-22.49 76.94-55.86 22.51-4.61 41.94-18.7 53.31-38.67 17.57-30.32 13.55-68.51-9.94-94.51zm-120.28 168.11c-14.03.02-27.62-4.89-38.39-13.88.49-.26 1.34-.73 1.89-1.07l63.72-36.8c3.26-1.85 5.26-5.32 5.24-9.07v-89.83l26.93 15.55c.29.14.48.42.52.74v74.39c-.04 33.08-26.83 59.9-59.91 59.97zm-128.84-55.03c-7.03-12.14-9.56-26.37-7.15-40.18.47.28 1.3.79 1.89 1.13l63.72 36.8c3.23 1.89 7.23 1.89 10.47 0l77.79-44.92v31.1c.02.32-.13.63-.38.83l-64.41 37.19c-28.69 16.52-65.33 6.7-81.92-21.95zm-16.77-139.09c7-12.16 18.05-21.46 31.21-26.29 0 .55-.03 1.52-.03 2.2v73.61c-.02 3.74 1.98 7.21 5.23 9.06l77.79 44.91-26.93 15.55c-.27.18-.61.21-.91.08l-64.42-37.22c-28.63-16.58-38.45-53.21-21.95-81.89zm221.26 51.49-77.79-44.92 26.93-15.54c.27-.18.61-.21.91-.08l64.42 37.19c28.68 16.57 38.51 53.26 21.94 81.94-7.01 12.14-18.05 21.44-31.2 26.28v-75.81c.03-3.74-1.96-7.2-5.2-9.06zm26.8-40.34c-.47-.29-1.3-.79-1.89-1.13l-63.72-36.8c-3.23-1.89-7.23-1.89-10.47 0l-77.79 44.92v-31.1c-.02-.32.13-.63.38-.83l64.41-37.16c28.69-16.55 65.37-6.7 81.91 22 6.99 12.12 9.52 26.31 7.15 40.1zm-168.51 55.43-26.94-15.55c-.29-.14-.48-.42-.52-.74v-74.39c.02-33.12 26.89-59.96 60.01-59.94 14.01 0 27.57 4.92 38.34 13.88-.49.26-1.33.73-1.89 1.07l-63.72 36.8c-3.26 1.85-5.26 5.31-5.24 9.06l-.04 89.79zm14.63-31.54 34.65-20.01 34.65 20v40.01l-34.65 20-34.65-20z"/></svg>
|
||||
|
Before Width: | Height: | Size: 1.7 KiB After Width: | Height: | Size: 1.7 KiB |
BIN
web/public/Voyage.png
Normal file
BIN
web/public/Voyage.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 14 KiB |
163
web/src/app/admin/models/embedding/CloudEmbeddingPage.tsx
Normal file
163
web/src/app/admin/models/embedding/CloudEmbeddingPage.tsx
Normal file
@ -0,0 +1,163 @@
|
||||
"use client";
|
||||
|
||||
import { Text, Title } from "@tremor/react";
|
||||
|
||||
import {
|
||||
CloudEmbeddingProvider,
|
||||
CloudEmbeddingModel,
|
||||
AVAILABLE_CLOUD_PROVIDERS,
|
||||
CloudEmbeddingProviderFull,
|
||||
EmbeddingModelDescriptor,
|
||||
} from "./components/types";
|
||||
import { EmbeddingDetails } from "./page";
|
||||
import { FiInfo } from "react-icons/fi";
|
||||
import { HoverPopup } from "@/components/HoverPopup";
|
||||
import { Dispatch, SetStateAction } from "react";
|
||||
|
||||
export default function CloudEmbeddingPage({
|
||||
currentModel,
|
||||
embeddingProviderDetails,
|
||||
newEnabledProviders,
|
||||
newUnenabledProviders,
|
||||
setShowTentativeProvider,
|
||||
setChangeCredentialsProvider,
|
||||
setAlreadySelectedModel,
|
||||
setShowTentativeModel,
|
||||
setShowModelInQueue,
|
||||
}: {
|
||||
setShowModelInQueue: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
|
||||
setShowTentativeModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
|
||||
currentModel: EmbeddingModelDescriptor | CloudEmbeddingModel;
|
||||
setAlreadySelectedModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
|
||||
newUnenabledProviders: string[];
|
||||
embeddingProviderDetails?: EmbeddingDetails[];
|
||||
newEnabledProviders: string[];
|
||||
selectedModel: CloudEmbeddingProvider;
|
||||
|
||||
// create modal functions
|
||||
|
||||
setShowTentativeProvider: React.Dispatch<
|
||||
React.SetStateAction<CloudEmbeddingProvider | null>
|
||||
>;
|
||||
setChangeCredentialsProvider: React.Dispatch<
|
||||
React.SetStateAction<CloudEmbeddingProvider | null>
|
||||
>;
|
||||
}) {
|
||||
function hasNameInArray(
|
||||
arr: Array<{ name: string }>,
|
||||
searchName: string
|
||||
): boolean {
|
||||
return arr.some(
|
||||
(item) => item.name.toLowerCase() === searchName.toLowerCase()
|
||||
);
|
||||
}
|
||||
|
||||
let providers: CloudEmbeddingProviderFull[] = [];
|
||||
AVAILABLE_CLOUD_PROVIDERS.forEach((model, ind) => {
|
||||
let temporary_model: CloudEmbeddingProviderFull = {
|
||||
...model,
|
||||
configured:
|
||||
!newUnenabledProviders.includes(model.name) &&
|
||||
(newEnabledProviders.includes(model.name) ||
|
||||
(embeddingProviderDetails &&
|
||||
hasNameInArray(embeddingProviderDetails, model.name))!),
|
||||
};
|
||||
providers.push(temporary_model);
|
||||
});
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Title className="mt-8">
|
||||
Here are some cloud-based models to choose from.
|
||||
</Title>
|
||||
<Text className="mb-4">
|
||||
They require API keys and run in the clouds of the respective providers.
|
||||
</Text>
|
||||
|
||||
<div className="gap-4 mt-2 pb-10 flex content-start flex-wrap">
|
||||
{providers.map((provider, ind) => (
|
||||
<div
|
||||
key={ind}
|
||||
className="p-4 border border-border rounded-lg shadow-md bg-hover-light w-96 flex flex-col"
|
||||
>
|
||||
<div className="font-bold text-neutral-900 text-lg items-center py-1 gap-x-2 flex">
|
||||
{provider.icon({ size: 40 })}
|
||||
<p className="my-auto">{provider.name}</p>
|
||||
<button
|
||||
onClick={() => {
|
||||
setShowTentativeProvider(provider);
|
||||
}}
|
||||
className="cursor-pointer ml-auto"
|
||||
>
|
||||
<a className="my-auto hover:underline cursor-pointer">
|
||||
<HoverPopup
|
||||
mainContent={
|
||||
<FiInfo className="cusror-pointer" size={20} />
|
||||
}
|
||||
popupContent={
|
||||
<div className="text-sm text-neutral-800 w-52 flex">
|
||||
<div className="flex mx-auto">
|
||||
<div className="my-auto">{provider.description}</div>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
direction="left-top"
|
||||
style="dark"
|
||||
/>
|
||||
</a>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
{provider.embedding_models.map((model, index) => {
|
||||
const enabled = model.model_name == currentModel.model_name;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={index}
|
||||
className={`p-3 my-2 border-2 border-neutral-300 border-opacity-40 rounded-md rounded cursor-pointer
|
||||
${!provider.configured ? "opacity-80 hover:opacity-100" : enabled ? "bg-background-stronger" : "hover:bg-background-strong"}`}
|
||||
onClick={() => {
|
||||
if (enabled) {
|
||||
setAlreadySelectedModel(model);
|
||||
} else if (provider.configured) {
|
||||
setShowTentativeModel(model);
|
||||
} else {
|
||||
setShowModelInQueue(model);
|
||||
setShowTentativeProvider(provider);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="flex justify-between">
|
||||
<div className="font-medium text-sm">
|
||||
{model.model_name}
|
||||
</div>
|
||||
<p className="text-sm flex-none">
|
||||
${model.pricePerMillion}/M tokens
|
||||
</p>
|
||||
</div>
|
||||
<div className="text-sm text-gray-600">
|
||||
{model.description}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
<button
|
||||
onClick={() => {
|
||||
if (!provider.configured) {
|
||||
setShowTentativeProvider(provider);
|
||||
} else {
|
||||
setChangeCredentialsProvider(provider);
|
||||
}
|
||||
}}
|
||||
className="hover:underline mb-1 text-sm mr-auto cursor-pointer"
|
||||
>
|
||||
{provider.configured && "Modify credentials"}
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
@ -1,74 +0,0 @@
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Button, Text, Callout } from "@tremor/react";
|
||||
import { EmbeddingModelDescriptor } from "./embeddingModels";
|
||||
|
||||
export function ModelSelectionConfirmaion({
|
||||
selectedModel,
|
||||
isCustom,
|
||||
onConfirm,
|
||||
}: {
|
||||
selectedModel: EmbeddingModelDescriptor;
|
||||
isCustom: boolean;
|
||||
onConfirm: () => void;
|
||||
}) {
|
||||
return (
|
||||
<div className="mb-4">
|
||||
<Text className="text-lg mb-4">
|
||||
You have selected: <b>{selectedModel.model_name}</b>. Are you sure you
|
||||
want to update to this new embedding model?
|
||||
</Text>
|
||||
<Text className="text-lg mb-2">
|
||||
We will re-index all your documents in the background so you will be
|
||||
able to continue to use Danswer as normal with the old model in the
|
||||
meantime. Depending on how many documents you have indexed, this may
|
||||
take a while.
|
||||
</Text>
|
||||
<Text className="text-lg mb-2">
|
||||
<i>NOTE:</i> this re-indexing process will consume more resources than
|
||||
normal. If you are self-hosting, we recommend that you allocate at least
|
||||
16GB of RAM to Danswer during this process.
|
||||
</Text>
|
||||
|
||||
{isCustom && (
|
||||
<Callout title="IMPORTANT" color="yellow" className="mt-4">
|
||||
We've detected that this is a custom-specified embedding model.
|
||||
Since we have to download the model files before verifying the
|
||||
configuration's correctness, we won't be able to let you
|
||||
know if the configuration is valid until <b>after</b> we start
|
||||
re-indexing your documents. If there is an issue, it will show up on
|
||||
this page as an indexing error on this page after clicking Confirm.
|
||||
</Callout>
|
||||
)}
|
||||
|
||||
<div className="flex mt-8">
|
||||
<Button className="mx-auto" color="green" onClick={onConfirm}>
|
||||
Confirm
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function ModelSelectionConfirmaionModal({
|
||||
selectedModel,
|
||||
isCustom,
|
||||
onConfirm,
|
||||
onCancel,
|
||||
}: {
|
||||
selectedModel: EmbeddingModelDescriptor;
|
||||
isCustom: boolean;
|
||||
onConfirm: () => void;
|
||||
onCancel: () => void;
|
||||
}) {
|
||||
return (
|
||||
<Modal title="Update Embedding Model" onOutsideClick={onCancel}>
|
||||
<div>
|
||||
<ModelSelectionConfirmaion
|
||||
selectedModel={selectedModel}
|
||||
isCustom={isCustom}
|
||||
onConfirm={onConfirm}
|
||||
/>
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
}
|
55
web/src/app/admin/models/embedding/OpenEmbeddingPage.tsx
Normal file
55
web/src/app/admin/models/embedding/OpenEmbeddingPage.tsx
Normal file
@ -0,0 +1,55 @@
|
||||
"use client";
|
||||
import { Card, Text, Title } from "@tremor/react";
|
||||
import { ModelSelector } from "./components/ModelSelector";
|
||||
import {
|
||||
AVAILABLE_MODELS,
|
||||
EmbeddingModelDescriptor,
|
||||
HostedEmbeddingModel,
|
||||
} from "./components/types";
|
||||
import { CustomModelForm } from "./components/CustomModelForm";
|
||||
|
||||
export default function OpenEmbeddingPage({
|
||||
onSelectOpenSource,
|
||||
currentModelName,
|
||||
}: {
|
||||
currentModelName: string;
|
||||
onSelectOpenSource: (model: HostedEmbeddingModel) => Promise<void>;
|
||||
}) {
|
||||
return (
|
||||
<div>
|
||||
<ModelSelector
|
||||
modelOptions={AVAILABLE_MODELS.filter(
|
||||
(modelOption) => modelOption.model_name !== currentModelName
|
||||
)}
|
||||
setSelectedModel={onSelectOpenSource}
|
||||
/>
|
||||
|
||||
<Text className="mt-6">
|
||||
Alternatively, (if you know what you're doing) you can specify a{" "}
|
||||
<a target="_blank" href="https://www.sbert.net/" className="text-link">
|
||||
SentenceTransformers
|
||||
</a>
|
||||
-compatible model of your choice below. The rough list of supported
|
||||
models can be found{" "}
|
||||
<a
|
||||
target="_blank"
|
||||
href="https://huggingface.co/models?library=sentence-transformers&sort=trending"
|
||||
className="text-link"
|
||||
>
|
||||
here
|
||||
</a>
|
||||
.
|
||||
<br />
|
||||
<b>NOTE:</b> not all models listed will work with Danswer, since some
|
||||
have unique interfaces or special requirements. If in doubt, reach out
|
||||
to the Danswer team.
|
||||
</Text>
|
||||
|
||||
<div className="w-full flex">
|
||||
<Card className="mt-4 2xl:w-4/6 mx-auto">
|
||||
<CustomModelForm onSubmit={onSelectOpenSource} />
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
@ -2,16 +2,15 @@ import {
|
||||
BooleanFormField,
|
||||
TextFormField,
|
||||
} from "@/components/admin/connectors/Field";
|
||||
import { Button, Divider, Text } from "@tremor/react";
|
||||
import { Button } from "@tremor/react";
|
||||
import { Form, Formik } from "formik";
|
||||
|
||||
import * as Yup from "yup";
|
||||
import { EmbeddingModelDescriptor } from "./embeddingModels";
|
||||
import { EmbeddingModelDescriptor, HostedEmbeddingModel } from "./types";
|
||||
|
||||
export function CustomModelForm({
|
||||
onSubmit,
|
||||
}: {
|
||||
onSubmit: (model: EmbeddingModelDescriptor) => void;
|
||||
onSubmit: (model: HostedEmbeddingModel) => void;
|
||||
}) {
|
||||
return (
|
||||
<div>
|
||||
@ -21,6 +20,7 @@ export function CustomModelForm({
|
||||
model_dim: "",
|
||||
query_prefix: "",
|
||||
passage_prefix: "",
|
||||
description: "",
|
||||
normalize: true,
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
@ -62,6 +62,13 @@ export function CustomModelForm({
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<TextFormField
|
||||
name="description"
|
||||
label="Description:"
|
||||
subtext="Description of your model"
|
||||
placeholder=""
|
||||
autoCompleteDisabled={true}
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
name="query_prefix"
|
||||
@ -77,7 +84,6 @@ export function CustomModelForm({
|
||||
placeholder="E.g. 'query: '"
|
||||
autoCompleteDisabled={true}
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
name="passage_prefix"
|
||||
label="[Optional] Passage Prefix:"
|
@ -1,18 +1,29 @@
|
||||
import { DefaultDropdown, StringOrNumberOption } from "@/components/Dropdown";
|
||||
import { Title, Text, Divider, Card } from "@tremor/react";
|
||||
import {
|
||||
EmbeddingModelDescriptor,
|
||||
FullEmbeddingModelDescriptor,
|
||||
} from "./embeddingModels";
|
||||
import { EmbeddingModelDescriptor, HostedEmbeddingModel } from "./types";
|
||||
import { FiStar } from "react-icons/fi";
|
||||
import { CustomModelForm } from "./CustomModelForm";
|
||||
|
||||
export function ModelPreview({ model }: { model: EmbeddingModelDescriptor }) {
|
||||
return (
|
||||
<div
|
||||
className={
|
||||
"p-2 border border-border rounded shadow-md bg-hover-light w-96 flex flex-col"
|
||||
}
|
||||
>
|
||||
<div className="font-bold text-lg flex">{model.model_name}</div>
|
||||
<div className="text-sm mt-1 mx-1">
|
||||
{model.description
|
||||
? model.description
|
||||
: "Custom model—no description is available."}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function ModelOption({
|
||||
model,
|
||||
onSelect,
|
||||
}: {
|
||||
model: FullEmbeddingModelDescriptor;
|
||||
onSelect?: (model: EmbeddingModelDescriptor) => void;
|
||||
model: HostedEmbeddingModel;
|
||||
onSelect?: (model: HostedEmbeddingModel) => void;
|
||||
}) {
|
||||
return (
|
||||
<div
|
||||
@ -68,8 +79,8 @@ export function ModelSelector({
|
||||
modelOptions,
|
||||
setSelectedModel,
|
||||
}: {
|
||||
modelOptions: FullEmbeddingModelDescriptor[];
|
||||
setSelectedModel: (model: EmbeddingModelDescriptor) => void;
|
||||
modelOptions: HostedEmbeddingModel[];
|
||||
setSelectedModel: (model: HostedEmbeddingModel) => void;
|
||||
}) {
|
||||
return (
|
||||
<div>
|
286
web/src/app/admin/models/embedding/components/types.ts
Normal file
286
web/src/app/admin/models/embedding/components/types.ts
Normal file
@ -0,0 +1,286 @@
|
||||
import {
|
||||
CohereIcon,
|
||||
GoogleIcon,
|
||||
IconProps,
|
||||
OpenAIIcon,
|
||||
VoyageIcon,
|
||||
} from "@/components/icons/icons";
|
||||
|
||||
// Cloud Provider (not needed for hosted ones)
|
||||
|
||||
export interface CloudEmbeddingProvider {
|
||||
id: number;
|
||||
name: string;
|
||||
api_key?: string;
|
||||
custom_config?: Record<string, string>;
|
||||
docsLink?: string;
|
||||
|
||||
// Frontend-specific properties
|
||||
website: string;
|
||||
icon: ({ size, className }: IconProps) => JSX.Element;
|
||||
description: string;
|
||||
apiLink: string;
|
||||
costslink?: string;
|
||||
|
||||
// Relationships
|
||||
embedding_models: CloudEmbeddingModel[];
|
||||
default_model?: CloudEmbeddingModel;
|
||||
}
|
||||
|
||||
// Embedding Models
|
||||
export interface EmbeddingModelDescriptor {
|
||||
model_name: string;
|
||||
model_dim: number;
|
||||
normalize: boolean;
|
||||
query_prefix: string;
|
||||
passage_prefix: string;
|
||||
cloud_provider_name?: string | null;
|
||||
description: string;
|
||||
}
|
||||
|
||||
export interface CloudEmbeddingModel extends EmbeddingModelDescriptor {
|
||||
cloud_provider_name: string | null;
|
||||
pricePerMillion: number;
|
||||
enabled?: boolean;
|
||||
mtebScore: number;
|
||||
maxContext: number;
|
||||
}
|
||||
|
||||
export interface HostedEmbeddingModel extends EmbeddingModelDescriptor {
|
||||
link?: string;
|
||||
model_dim: number;
|
||||
normalize: boolean;
|
||||
query_prefix: string;
|
||||
passage_prefix: string;
|
||||
isDefault?: boolean;
|
||||
}
|
||||
|
||||
// Responses
|
||||
export interface FullEmbeddingModelResponse {
|
||||
current_model_name: string;
|
||||
secondary_model_name: string | null;
|
||||
}
|
||||
|
||||
export interface CloudEmbeddingProviderFull extends CloudEmbeddingProvider {
|
||||
configured: boolean;
|
||||
}
|
||||
|
||||
export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
||||
{
|
||||
model_name: "intfloat/e5-base-v2",
|
||||
model_dim: 768,
|
||||
normalize: true,
|
||||
description:
|
||||
"The recommended default for most situations. If you aren't sure which model to use, this is probably the one.",
|
||||
isDefault: true,
|
||||
link: "https://huggingface.co/intfloat/e5-base-v2",
|
||||
query_prefix: "query: ",
|
||||
passage_prefix: "passage: ",
|
||||
},
|
||||
{
|
||||
model_name: "intfloat/e5-small-v2",
|
||||
model_dim: 384,
|
||||
normalize: true,
|
||||
description:
|
||||
"A smaller / faster version of the default model. If you're running Danswer on a resource constrained system, then this is a good choice.",
|
||||
link: "https://huggingface.co/intfloat/e5-small-v2",
|
||||
query_prefix: "query: ",
|
||||
passage_prefix: "passage: ",
|
||||
},
|
||||
{
|
||||
model_name: "intfloat/multilingual-e5-base",
|
||||
model_dim: 768,
|
||||
normalize: true,
|
||||
description:
|
||||
"If you have many documents in other languages besides English, this is the one to go for.",
|
||||
link: "https://huggingface.co/intfloat/multilingual-e5-base",
|
||||
query_prefix: "query: ",
|
||||
passage_prefix: "passage: ",
|
||||
},
|
||||
{
|
||||
model_name: "intfloat/multilingual-e5-small",
|
||||
model_dim: 384,
|
||||
normalize: true,
|
||||
description:
|
||||
"If you have many documents in other languages besides English, and you're running on a resource constrained system, then this is the one to go for.",
|
||||
link: "https://huggingface.co/intfloat/multilingual-e5-base",
|
||||
query_prefix: "query: ",
|
||||
passage_prefix: "passage: ",
|
||||
},
|
||||
];
|
||||
|
||||
export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
{
|
||||
id: 0,
|
||||
name: "OpenAI",
|
||||
website: "https://openai.com",
|
||||
icon: OpenAIIcon,
|
||||
description: "AI industry leader known for ChatGPT and DALL-E",
|
||||
apiLink: "https://platform.openai.com/api-keys",
|
||||
docsLink:
|
||||
"https://docs.danswer.dev/guides/embedding_providers#openai-models",
|
||||
costslink: "https://openai.com/pricing",
|
||||
embedding_models: [
|
||||
{
|
||||
model_name: "text-embedding-3-large",
|
||||
cloud_provider_name: "OpenAI",
|
||||
description:
|
||||
"OpenAI's large embedding model. Best performance, but more expensive.",
|
||||
pricePerMillion: 0.13,
|
||||
model_dim: 3072,
|
||||
normalize: false,
|
||||
query_prefix: "",
|
||||
passage_prefix: "",
|
||||
mtebScore: 64.6,
|
||||
maxContext: 8191,
|
||||
enabled: false,
|
||||
},
|
||||
{
|
||||
model_name: "text-embedding-3-small",
|
||||
cloud_provider_name: "OpenAI",
|
||||
model_dim: 1536,
|
||||
normalize: false,
|
||||
query_prefix: "",
|
||||
passage_prefix: "",
|
||||
description:
|
||||
"OpenAI's newer, more efficient embedding model. Good balance of performance and cost.",
|
||||
pricePerMillion: 0.02,
|
||||
enabled: false,
|
||||
mtebScore: 62.3,
|
||||
maxContext: 8191,
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: 1,
|
||||
name: "Cohere",
|
||||
website: "https://cohere.ai",
|
||||
icon: CohereIcon,
|
||||
docsLink:
|
||||
"https://docs.danswer.dev/guides/embedding_providers#cohere-models",
|
||||
description:
|
||||
"AI company specializing in NLP models for various text-based tasks",
|
||||
apiLink: "https://dashboard.cohere.ai/api-keys",
|
||||
costslink: "https://cohere.com/pricing",
|
||||
embedding_models: [
|
||||
{
|
||||
model_name: "embed-english-v3.0",
|
||||
cloud_provider_name: "Cohere",
|
||||
description:
|
||||
"Cohere's English embedding model. Good performance for English-language tasks.",
|
||||
pricePerMillion: 0.1,
|
||||
mtebScore: 64.5,
|
||||
maxContext: 512,
|
||||
enabled: false,
|
||||
model_dim: 1024,
|
||||
normalize: false,
|
||||
query_prefix: "",
|
||||
passage_prefix: "",
|
||||
},
|
||||
{
|
||||
model_name: "embed-english-light-v3.0",
|
||||
cloud_provider_name: "Cohere",
|
||||
description:
|
||||
"Cohere's lightweight English embedding model. Faster and more efficient for simpler tasks.",
|
||||
pricePerMillion: 0.1,
|
||||
mtebScore: 62,
|
||||
maxContext: 512,
|
||||
enabled: false,
|
||||
model_dim: 384,
|
||||
normalize: false,
|
||||
query_prefix: "",
|
||||
passage_prefix: "",
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
{
|
||||
id: 2,
|
||||
name: "Google",
|
||||
website: "https://ai.google",
|
||||
icon: GoogleIcon,
|
||||
docsLink:
|
||||
"https://docs.danswer.dev/guides/embedding_providers#vertex-ai-google-model",
|
||||
description:
|
||||
"Offers a wide range of AI services including language and vision models",
|
||||
apiLink: "https://console.cloud.google.com/apis/credentials",
|
||||
costslink: "https://cloud.google.com/vertex-ai/pricing",
|
||||
embedding_models: [
|
||||
{
|
||||
cloud_provider_name: "Google",
|
||||
model_name: "text-embedding-004",
|
||||
description: "Google's most recent text embedding model.",
|
||||
pricePerMillion: 0.025,
|
||||
mtebScore: 66.31,
|
||||
maxContext: 2048,
|
||||
enabled: false,
|
||||
model_dim: 768,
|
||||
normalize: false,
|
||||
query_prefix: "",
|
||||
passage_prefix: "",
|
||||
},
|
||||
{
|
||||
cloud_provider_name: "Google",
|
||||
model_name: "textembedding-gecko@003",
|
||||
description: "Google's Gecko embedding model. Powerful and efficient.",
|
||||
pricePerMillion: 0.025,
|
||||
mtebScore: 66.31,
|
||||
maxContext: 2048,
|
||||
enabled: false,
|
||||
model_dim: 768,
|
||||
normalize: false,
|
||||
query_prefix: "",
|
||||
passage_prefix: "",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: 3,
|
||||
name: "Voyage",
|
||||
website: "https://www.voyageai.com",
|
||||
icon: VoyageIcon,
|
||||
description: "Advanced NLP research startup born from Stanford AI Labs",
|
||||
docsLink:
|
||||
"https://docs.danswer.dev/guides/embedding_providers#voyage-models",
|
||||
apiLink: "https://www.voyageai.com/dashboard",
|
||||
costslink: "https://www.voyageai.com/pricing",
|
||||
embedding_models: [
|
||||
{
|
||||
cloud_provider_name: "Voyage",
|
||||
model_name: "voyage-large-2-instruct",
|
||||
description:
|
||||
"Voyage's large embedding model. High performance with instruction fine-tuning.",
|
||||
pricePerMillion: 0.12,
|
||||
mtebScore: 68.28,
|
||||
maxContext: 4000,
|
||||
enabled: false,
|
||||
model_dim: 1024,
|
||||
normalize: false,
|
||||
query_prefix: "",
|
||||
passage_prefix: "",
|
||||
},
|
||||
{
|
||||
cloud_provider_name: "Voyage",
|
||||
model_name: "voyage-light-2-instruct",
|
||||
description:
|
||||
"Voyage's lightweight embedding model. Good balance of performance and efficiency.",
|
||||
pricePerMillion: 0.12,
|
||||
mtebScore: 67.13,
|
||||
maxContext: 16000,
|
||||
enabled: false,
|
||||
model_dim: 1024,
|
||||
normalize: false,
|
||||
query_prefix: "",
|
||||
passage_prefix: "",
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
export const INVALID_OLD_MODEL = "thenlper/gte-small";
|
||||
|
||||
export function checkModelNameIsValid(
|
||||
modelName: string | undefined | null
|
||||
): boolean {
|
||||
return !!modelName && modelName !== INVALID_OLD_MODEL;
|
||||
}
|
@ -1,87 +0,0 @@
|
||||
export interface EmbeddingModelResponse {
|
||||
model_name: string | null;
|
||||
}
|
||||
|
||||
export interface FullEmbeddingModelResponse {
|
||||
current_model_name: string;
|
||||
secondary_model_name: string | null;
|
||||
}
|
||||
|
||||
export interface EmbeddingModelDescriptor {
|
||||
model_name: string;
|
||||
model_dim: number;
|
||||
normalize: boolean;
|
||||
query_prefix?: string;
|
||||
passage_prefix?: string;
|
||||
}
|
||||
|
||||
export interface FullEmbeddingModelDescriptor extends EmbeddingModelDescriptor {
|
||||
description: string;
|
||||
isDefault?: boolean;
|
||||
link?: string;
|
||||
}
|
||||
|
||||
export const AVAILABLE_MODELS: FullEmbeddingModelDescriptor[] = [
|
||||
{
|
||||
model_name: "intfloat/e5-base-v2",
|
||||
model_dim: 768,
|
||||
normalize: true,
|
||||
description:
|
||||
"The recommended default for most situations. If you aren't sure which model to use, this is probably the one.",
|
||||
isDefault: true,
|
||||
link: "https://huggingface.co/intfloat/e5-base-v2",
|
||||
query_prefix: "query: ",
|
||||
passage_prefix: "passage: ",
|
||||
},
|
||||
{
|
||||
model_name: "intfloat/e5-small-v2",
|
||||
model_dim: 384,
|
||||
normalize: true,
|
||||
description:
|
||||
"A smaller / faster version of the default model. If you're running Danswer on a resource constrained system, then this is a good choice.",
|
||||
link: "https://huggingface.co/intfloat/e5-small-v2",
|
||||
query_prefix: "query: ",
|
||||
passage_prefix: "passage: ",
|
||||
},
|
||||
{
|
||||
model_name: "intfloat/multilingual-e5-base",
|
||||
model_dim: 768,
|
||||
normalize: true,
|
||||
description:
|
||||
"If you have many documents in other languages besides English, this is the one to go for.",
|
||||
link: "https://huggingface.co/intfloat/multilingual-e5-base",
|
||||
query_prefix: "query: ",
|
||||
passage_prefix: "passage: ",
|
||||
},
|
||||
{
|
||||
model_name: "intfloat/multilingual-e5-small",
|
||||
model_dim: 384,
|
||||
normalize: true,
|
||||
description:
|
||||
"If you have many documents in other languages besides English, and you're running on a resource constrained system, then this is the one to go for.",
|
||||
link: "https://huggingface.co/intfloat/multilingual-e5-base",
|
||||
query_prefix: "query: ",
|
||||
passage_prefix: "passage: ",
|
||||
},
|
||||
];
|
||||
|
||||
export const INVALID_OLD_MODEL = "thenlper/gte-small";
|
||||
|
||||
export function checkModelNameIsValid(modelName: string | undefined | null) {
|
||||
if (!modelName) {
|
||||
return false;
|
||||
}
|
||||
if (modelName === INVALID_OLD_MODEL) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
export function fillOutEmeddingModelDescriptor(
|
||||
embeddingModel: EmbeddingModelDescriptor | FullEmbeddingModelDescriptor
|
||||
): FullEmbeddingModelDescriptor {
|
||||
return {
|
||||
...embeddingModel,
|
||||
description: "",
|
||||
};
|
||||
}
|
@ -0,0 +1,31 @@
|
||||
import React from "react";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Button, Text } from "@tremor/react";
|
||||
|
||||
import { CloudEmbeddingModel } from "../components/types";
|
||||
|
||||
export function AlreadyPickedModal({
|
||||
model,
|
||||
onClose,
|
||||
}: {
|
||||
model: CloudEmbeddingModel;
|
||||
onClose: () => void;
|
||||
}) {
|
||||
return (
|
||||
<Modal
|
||||
title={`${model.model_name} already chosen`}
|
||||
onOutsideClick={onClose}
|
||||
>
|
||||
<div className="mb-4">
|
||||
<Text className="text-sm mb-2">
|
||||
You can select a different one if you want!
|
||||
</Text>
|
||||
<div className="flex mt-8 justify-between">
|
||||
<Button color="blue" onClick={onClose}>
|
||||
Close
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
}
|
@ -0,0 +1,244 @@
|
||||
import React, { useRef, useState } from "react";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Button, Text, Callout, Subtitle, Divider } from "@tremor/react";
|
||||
import { Label, TextFormField } from "@/components/admin/connectors/Field";
|
||||
import { CloudEmbeddingProvider } from "../components/types";
|
||||
import {
|
||||
EMBEDDING_PROVIDERS_ADMIN_URL,
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
} from "../../llm/constants";
|
||||
import { mutate } from "swr";
|
||||
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { Field } from "formik";
|
||||
|
||||
export function ChangeCredentialsModal({
|
||||
provider,
|
||||
onConfirm,
|
||||
onCancel,
|
||||
onDeleted,
|
||||
useFileUpload,
|
||||
}: {
|
||||
provider: CloudEmbeddingProvider;
|
||||
onConfirm: () => void;
|
||||
onCancel: () => void;
|
||||
onDeleted: () => void;
|
||||
useFileUpload: boolean;
|
||||
}) {
|
||||
const [apiKey, setApiKey] = useState("");
|
||||
const [testError, setTestError] = useState<string>("");
|
||||
const [fileName, setFileName] = useState<string>("");
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
const [isProcessing, setIsProcessing] = useState(false);
|
||||
const [deletionError, setDeletionError] = useState<string>("");
|
||||
|
||||
const clearFileInput = () => {
|
||||
setFileName("");
|
||||
if (fileInputRef.current) {
|
||||
fileInputRef.current.value = "";
|
||||
}
|
||||
};
|
||||
|
||||
const handleFileUpload = async (
|
||||
event: React.ChangeEvent<HTMLInputElement>
|
||||
) => {
|
||||
const file = event.target.files?.[0];
|
||||
setFileName("");
|
||||
|
||||
if (file) {
|
||||
setFileName(file.name);
|
||||
try {
|
||||
setDeletionError("");
|
||||
const fileContent = await file.text();
|
||||
let jsonContent;
|
||||
try {
|
||||
jsonContent = JSON.parse(fileContent);
|
||||
setApiKey(JSON.stringify(jsonContent));
|
||||
} catch (parseError) {
|
||||
throw new Error(
|
||||
"Failed to parse JSON file. Please ensure it's a valid JSON."
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
setTestError(
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: "An unknown error occurred while processing the file."
|
||||
);
|
||||
setApiKey("");
|
||||
clearFileInput();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleDelete = async () => {
|
||||
setDeletionError("");
|
||||
setIsProcessing(true);
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.name}`,
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json();
|
||||
setDeletionError(errorData.detail);
|
||||
return;
|
||||
}
|
||||
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
onDeleted();
|
||||
} catch (error) {
|
||||
setDeletionError(
|
||||
error instanceof Error ? error.message : "An unknown error occurred"
|
||||
);
|
||||
} finally {
|
||||
setIsProcessing(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSubmit = async () => {
|
||||
setTestError("");
|
||||
|
||||
try {
|
||||
const body = JSON.stringify({
|
||||
api_key: apiKey,
|
||||
provider: provider.name.toLowerCase().split(" ")[0],
|
||||
default_model_id: provider.name,
|
||||
});
|
||||
|
||||
const testResponse = await fetch("/api/admin/embedding/test-embedding", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider: provider.name.toLowerCase().split(" ")[0],
|
||||
api_key: apiKey,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!testResponse.ok) {
|
||||
const errorMsg = (await testResponse.json()).detail;
|
||||
throw new Error(errorMsg);
|
||||
}
|
||||
|
||||
const updateResponse = await fetch(EMBEDDING_PROVIDERS_ADMIN_URL, {
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
name: provider.name,
|
||||
api_key: apiKey,
|
||||
is_default_provider: false,
|
||||
is_configured: true,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!updateResponse.ok) {
|
||||
const errorData = await updateResponse.json();
|
||||
throw new Error(
|
||||
errorData.detail || "Failed to update provider- check your API key"
|
||||
);
|
||||
}
|
||||
|
||||
onConfirm();
|
||||
} catch (error) {
|
||||
setTestError(
|
||||
error instanceof Error ? error.message : "An unknown error occurred"
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal
|
||||
icon={provider.icon}
|
||||
title={`Modify your ${provider.name} key`}
|
||||
onOutsideClick={onCancel}
|
||||
>
|
||||
<div className="mb-4">
|
||||
<Subtitle className="mt-4 font-bold text-lg mb-2">
|
||||
Want to swap out your key?
|
||||
</Subtitle>
|
||||
|
||||
<div className="flex flex-col gap-y-2">
|
||||
{useFileUpload ? (
|
||||
<>
|
||||
<Label>Upload JSON File</Label>
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
accept=".json"
|
||||
onChange={handleFileUpload}
|
||||
className="text-lg w-full p-1"
|
||||
/>
|
||||
{fileName && <p>Uploaded file: {fileName}</p>}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="flex gap-x-2 items-center">
|
||||
<Label>New API Key</Label>
|
||||
</div>
|
||||
<input
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
mt-1
|
||||
bg-background-emphasis
|
||||
`}
|
||||
value={apiKey}
|
||||
onChange={(e: any) => setApiKey(e.target.value)}
|
||||
placeholder="Paste your API key here"
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
<a
|
||||
href={provider.apiLink}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="underline cursor-pointer"
|
||||
>
|
||||
Visit API
|
||||
</a>
|
||||
</div>
|
||||
|
||||
{testError && (
|
||||
<Callout title="Error" color="red" className="mt-4">
|
||||
{testError}
|
||||
</Callout>
|
||||
)}
|
||||
|
||||
<div className="flex mt-8 justify-between">
|
||||
<Button
|
||||
color="blue"
|
||||
onClick={() => handleSubmit()}
|
||||
disabled={!apiKey}
|
||||
>
|
||||
Execute Key Swap
|
||||
</Button>
|
||||
</div>
|
||||
<Divider />
|
||||
|
||||
<Subtitle className="mt-4 font-bold text-lg mb-2">
|
||||
You can also delete your key.
|
||||
</Subtitle>
|
||||
<Text className="mb-2">
|
||||
This is only possible if you have already switched to a different
|
||||
embedding type!
|
||||
</Text>
|
||||
|
||||
<Button onClick={handleDelete} color="red">
|
||||
Delete key
|
||||
</Button>
|
||||
{deletionError && (
|
||||
<Callout title="Error" color="red" className="mt-4">
|
||||
{deletionError}
|
||||
</Callout>
|
||||
)}
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
}
|
@ -0,0 +1,41 @@
|
||||
import React from "react";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Button, Text, Callout } from "@tremor/react";
|
||||
import { CloudEmbeddingProvider } from "../components/types";
|
||||
|
||||
export function DeleteCredentialsModal({
|
||||
modelProvider,
|
||||
onConfirm,
|
||||
onCancel,
|
||||
}: {
|
||||
modelProvider: CloudEmbeddingProvider;
|
||||
onConfirm: () => void;
|
||||
onCancel: () => void;
|
||||
}) {
|
||||
return (
|
||||
<Modal
|
||||
title={`Nuke ${modelProvider.name} Credentials?`}
|
||||
onOutsideClick={onCancel}
|
||||
>
|
||||
<div className="mb-4">
|
||||
<Text className="text-lg mb-2">
|
||||
You're about to delete your {modelProvider.name} credentials. Are
|
||||
you sure?
|
||||
</Text>
|
||||
<Callout
|
||||
title="Point of No Return"
|
||||
color="red"
|
||||
className="mt-4"
|
||||
></Callout>
|
||||
<div className="flex mt-8 justify-between">
|
||||
<Button color="gray" onClick={onCancel}>
|
||||
Keep Credentaisl
|
||||
</Button>
|
||||
<Button color="red" onClick={onConfirm}>
|
||||
Delete Credentials
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
}
|
@ -0,0 +1,61 @@
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Button, Text, Callout } from "@tremor/react";
|
||||
import {
|
||||
EmbeddingModelDescriptor,
|
||||
HostedEmbeddingModel,
|
||||
} from "../components/types";
|
||||
|
||||
export function ModelSelectionConfirmationModal({
|
||||
selectedModel,
|
||||
isCustom,
|
||||
onConfirm,
|
||||
onCancel,
|
||||
}: {
|
||||
selectedModel: HostedEmbeddingModel;
|
||||
isCustom: boolean;
|
||||
onConfirm: () => void;
|
||||
onCancel: () => void;
|
||||
}) {
|
||||
return (
|
||||
<Modal title="Update Embedding Model" onOutsideClick={onCancel}>
|
||||
<div>
|
||||
<div className="mb-4">
|
||||
<Text className="text-lg mb-4">
|
||||
You have selected: <b>{selectedModel.model_name}</b>. Are you sure
|
||||
you want to update to this new embedding model?
|
||||
</Text>
|
||||
<Text className="text-lg mb-2">
|
||||
We will re-index all your documents in the background so you will be
|
||||
able to continue to use Danswer as normal with the old model in the
|
||||
meantime. Depending on how many documents you have indexed, this may
|
||||
take a while.
|
||||
</Text>
|
||||
<Text className="text-lg mb-2">
|
||||
<i>NOTE:</i> this re-indexing process will consume more resources
|
||||
than normal. If you are self-hosting, we recommend that you allocate
|
||||
at least 16GB of RAM to Danswer during this process.
|
||||
</Text>
|
||||
|
||||
{/* TODO Change this back- ensure functional */}
|
||||
{!isCustom && (
|
||||
<Callout title="IMPORTANT" color="yellow" className="mt-4">
|
||||
We've detected that this is a custom-specified embedding
|
||||
model. Since we have to download the model files before verifying
|
||||
the configuration's correctness, we won't be able to let
|
||||
you know if the configuration is valid until <b>after</b> we start
|
||||
re-indexing your documents. If there is an issue, it will show up
|
||||
on this page as an indexing error on this page after clicking
|
||||
Confirm.
|
||||
</Callout>
|
||||
)}
|
||||
|
||||
<div className="flex mt-8">
|
||||
<Button className="mx-auto" color="green" onClick={onConfirm}>
|
||||
Confirm
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
}
|
@ -0,0 +1,232 @@
|
||||
import React, { useRef, useState } from "react";
|
||||
import { Text, Button, Callout } from "@tremor/react";
|
||||
import { Formik, Form, Field } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import { Label, TextFormField } from "@/components/admin/connectors/Field";
|
||||
import { LoadingAnimation } from "@/components/Loading";
|
||||
import { CloudEmbeddingProvider } from "../components/types";
|
||||
import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../../llm/constants";
|
||||
import { Modal } from "@/components/Modal";
|
||||
|
||||
export function ProviderCreationModal({
|
||||
selectedProvider,
|
||||
onConfirm,
|
||||
onCancel,
|
||||
existingProvider,
|
||||
}: {
|
||||
selectedProvider: CloudEmbeddingProvider;
|
||||
onConfirm: () => void;
|
||||
onCancel: () => void;
|
||||
existingProvider?: CloudEmbeddingProvider;
|
||||
}) {
|
||||
const useFileUpload = selectedProvider.name == "Google";
|
||||
|
||||
const [isProcessing, setIsProcessing] = useState(false);
|
||||
const [errorMsg, setErrorMsg] = useState<string>("");
|
||||
const [fileName, setFileName] = useState<string>("");
|
||||
|
||||
const initialValues = {
|
||||
name: existingProvider?.name || selectedProvider.name,
|
||||
api_key: existingProvider?.api_key || "",
|
||||
custom_config: existingProvider?.custom_config
|
||||
? Object.entries(existingProvider.custom_config)
|
||||
: [],
|
||||
default_model_name: "",
|
||||
model_id: 0,
|
||||
};
|
||||
|
||||
const validationSchema = Yup.object({
|
||||
name: Yup.string().required("Name is required"),
|
||||
api_key: useFileUpload
|
||||
? Yup.string()
|
||||
: Yup.string().required("API Key is required"),
|
||||
custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)),
|
||||
});
|
||||
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const handleFileUpload = async (
|
||||
event: React.ChangeEvent<HTMLInputElement>,
|
||||
setFieldValue: (field: string, value: any) => void
|
||||
) => {
|
||||
const file = event.target.files?.[0];
|
||||
setFileName("");
|
||||
if (file) {
|
||||
setFileName(file.name);
|
||||
try {
|
||||
const fileContent = await file.text();
|
||||
let jsonContent;
|
||||
try {
|
||||
jsonContent = JSON.parse(fileContent);
|
||||
} catch (parseError) {
|
||||
throw new Error(
|
||||
"Failed to parse JSON file. Please ensure it's a valid JSON."
|
||||
);
|
||||
}
|
||||
setFieldValue("api_key", JSON.stringify(jsonContent));
|
||||
} catch (error) {
|
||||
setFieldValue("api_key", "");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleSubmit = async (
|
||||
values: any,
|
||||
{ setSubmitting }: { setSubmitting: (isSubmitting: boolean) => void }
|
||||
) => {
|
||||
setIsProcessing(true);
|
||||
setErrorMsg("");
|
||||
|
||||
try {
|
||||
const customConfig = Object.fromEntries(values.custom_config);
|
||||
|
||||
const initialResponse = await fetch(
|
||||
"/api/admin/embedding/test-embedding",
|
||||
{
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider: values.name.toLowerCase().split(" ")[0],
|
||||
api_key: values.api_key,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
if (!initialResponse.ok) {
|
||||
const errorMsg = (await initialResponse.json()).detail;
|
||||
setErrorMsg(errorMsg);
|
||||
setIsProcessing(false);
|
||||
setSubmitting(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const response = await fetch(EMBEDDING_PROVIDERS_ADMIN_URL, {
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
...values,
|
||||
custom_config: customConfig,
|
||||
is_default_provider: false,
|
||||
is_configured: true,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json();
|
||||
throw new Error(
|
||||
errorData.detail || "Failed to update provider- check your API key"
|
||||
);
|
||||
}
|
||||
|
||||
onConfirm();
|
||||
} catch (error: unknown) {
|
||||
if (error instanceof Error) {
|
||||
setErrorMsg(error.message);
|
||||
} else {
|
||||
setErrorMsg("An unknown error occurred");
|
||||
}
|
||||
} finally {
|
||||
setIsProcessing(false);
|
||||
setSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={`Configure ${selectedProvider.name}`}
|
||||
onOutsideClick={onCancel}
|
||||
icon={selectedProvider.icon}
|
||||
>
|
||||
<div>
|
||||
<Formik
|
||||
initialValues={initialValues}
|
||||
validationSchema={validationSchema}
|
||||
onSubmit={handleSubmit}
|
||||
>
|
||||
{({
|
||||
values,
|
||||
errors,
|
||||
touched,
|
||||
isSubmitting,
|
||||
handleSubmit,
|
||||
setFieldValue,
|
||||
}) => (
|
||||
<Form onSubmit={handleSubmit} className="space-y-4">
|
||||
<Text className="text-lg mb-2">
|
||||
You are setting the credentials for this provider. To access
|
||||
this information, follow the instructions{" "}
|
||||
<a
|
||||
className="cursor-pointer underline"
|
||||
target="_blank"
|
||||
href={selectedProvider.docsLink}
|
||||
>
|
||||
here
|
||||
</a>{" "}
|
||||
and gather your{" "}
|
||||
<a
|
||||
className="cursor-pointer underline"
|
||||
target="_blank"
|
||||
href={selectedProvider.apiLink}
|
||||
>
|
||||
API KEY
|
||||
</a>
|
||||
</Text>
|
||||
|
||||
<div className="flex flex-col gap-y-2">
|
||||
{useFileUpload ? (
|
||||
<>
|
||||
<Label>Upload JSON File</Label>
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
accept=".json"
|
||||
onChange={(e) => handleFileUpload(e, setFieldValue)}
|
||||
className="text-lg w-full p-1"
|
||||
/>
|
||||
{fileName && <p>Uploaded file: {fileName}</p>}
|
||||
</>
|
||||
) : (
|
||||
<TextFormField
|
||||
name="api_key"
|
||||
label="API Key"
|
||||
placeholder="API Key"
|
||||
type="password"
|
||||
/>
|
||||
)}
|
||||
|
||||
<a
|
||||
href={selectedProvider.apiLink}
|
||||
target="_blank"
|
||||
className="underline cursor-pointer"
|
||||
>
|
||||
Learn more here
|
||||
</a>
|
||||
</div>
|
||||
|
||||
{errorMsg && (
|
||||
<Callout title="Error" color="red">
|
||||
{errorMsg}
|
||||
</Callout>
|
||||
)}
|
||||
|
||||
<Button
|
||||
type="submit"
|
||||
color="blue"
|
||||
className="w-full"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isProcessing ? (
|
||||
<LoadingAnimation />
|
||||
) : existingProvider ? (
|
||||
"Update"
|
||||
) : (
|
||||
"Create"
|
||||
)}
|
||||
</Button>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
}
|
@ -0,0 +1,37 @@
|
||||
import React from "react";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Button, Text, Callout } from "@tremor/react";
|
||||
import { CloudEmbeddingModel } from "../components/types";
|
||||
|
||||
export function SelectModelModal({
|
||||
model,
|
||||
onConfirm,
|
||||
onCancel,
|
||||
}: {
|
||||
model: CloudEmbeddingModel;
|
||||
onConfirm: () => void;
|
||||
onCancel: () => void;
|
||||
}) {
|
||||
return (
|
||||
<Modal
|
||||
title={`Elevate Your Game with ${model.model_name}`}
|
||||
onOutsideClick={onCancel}
|
||||
>
|
||||
<div className="mb-4">
|
||||
<Text className="text-lg mb-2">
|
||||
You're about to set your embedding model to {model.model_name}.
|
||||
<br />
|
||||
Are you sure?
|
||||
</Text>
|
||||
<div className="flex mt-8 justify-between">
|
||||
<Button color="gray" onClick={onCancel}>
|
||||
Exit
|
||||
</Button>
|
||||
<Button color="green" onClick={onConfirm}>
|
||||
Continue
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
}
|
@ -3,28 +3,75 @@
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { Button, Card, Text, Title } from "@tremor/react";
|
||||
import { Button, Text, Title } from "@tremor/react";
|
||||
import { FiPackage } from "react-icons/fi";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import { ModelOption, ModelSelector } from "./ModelSelector";
|
||||
import { ModelOption, ModelPreview } from "./components/ModelSelector";
|
||||
import { useState } from "react";
|
||||
import { ModelSelectionConfirmaionModal } from "./ModelSelectionConfirmation";
|
||||
import { ReindexingProgressTable } from "./ReindexingProgressTable";
|
||||
import { ReindexingProgressTable } from "./components/ReindexingProgressTable";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import {
|
||||
CloudEmbeddingProvider,
|
||||
CloudEmbeddingModel,
|
||||
AVAILABLE_CLOUD_PROVIDERS,
|
||||
AVAILABLE_MODELS,
|
||||
EmbeddingModelDescriptor,
|
||||
INVALID_OLD_MODEL,
|
||||
fillOutEmeddingModelDescriptor,
|
||||
} from "./embeddingModels";
|
||||
HostedEmbeddingModel,
|
||||
EmbeddingModelDescriptor,
|
||||
} from "./components/types";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { Connector, ConnectorIndexingStatus } from "@/lib/types";
|
||||
import Link from "next/link";
|
||||
import { CustomModelForm } from "./CustomModelForm";
|
||||
import OpenEmbeddingPage from "./OpenEmbeddingPage";
|
||||
import CloudEmbeddingPage from "./CloudEmbeddingPage";
|
||||
import { ProviderCreationModal } from "./modals/ProviderCreationModal";
|
||||
|
||||
import { DeleteCredentialsModal } from "./modals/DeleteCredentialsModal";
|
||||
import { SelectModelModal } from "./modals/SelectModelModal";
|
||||
import { ChangeCredentialsModal } from "./modals/ChangeCredentialsModal";
|
||||
import { ModelSelectionConfirmationModal } from "./modals/ModelSelectionModal";
|
||||
import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../llm/constants";
|
||||
import { AlreadyPickedModal } from "./modals/AlreadyPickedModal";
|
||||
|
||||
export interface EmbeddingDetails {
|
||||
api_key: string;
|
||||
custom_config: any;
|
||||
default_model_id?: number;
|
||||
name: string;
|
||||
}
|
||||
|
||||
function Main() {
|
||||
const [tentativeNewEmbeddingModel, setTentativeNewEmbeddingModel] =
|
||||
useState<EmbeddingModelDescriptor | null>(null);
|
||||
const [openToggle, setOpenToggle] = useState(true);
|
||||
|
||||
// Cloud Provider based modals
|
||||
const [showTentativeProvider, setShowTentativeProvider] =
|
||||
useState<CloudEmbeddingProvider | null>(null);
|
||||
const [showUnconfiguredProvider, setShowUnconfiguredProvider] =
|
||||
useState<CloudEmbeddingProvider | null>(null);
|
||||
const [changeCredentialsProvider, setChangeCredentialsProvider] =
|
||||
useState<CloudEmbeddingProvider | null>(null);
|
||||
|
||||
// Cloud Model based modals
|
||||
const [alreadySelectedModel, setAlreadySelectedModel] =
|
||||
useState<CloudEmbeddingModel | null>(null);
|
||||
const [showTentativeModel, setShowTentativeModel] =
|
||||
useState<CloudEmbeddingModel | null>(null);
|
||||
|
||||
const [showModelInQueue, setShowModelInQueue] =
|
||||
useState<CloudEmbeddingModel | null>(null);
|
||||
|
||||
// Open Model based modals
|
||||
const [showTentativeOpenProvider, setShowTentativeOpenProvider] =
|
||||
useState<HostedEmbeddingModel | null>(null);
|
||||
|
||||
// Enabled / unenabled providers
|
||||
const [newEnabledProviders, setNewEnabledProviders] = useState<string[]>([]);
|
||||
const [newUnenabledProviders, setNewUnenabledProviders] = useState<string[]>(
|
||||
[]
|
||||
);
|
||||
|
||||
const [showDeleteCredentialsModal, setShowDeleteCredentialsModal] =
|
||||
useState<boolean>(false);
|
||||
const [isCancelling, setIsCancelling] = useState<boolean>(false);
|
||||
const [showAddConnectorPopup, setShowAddConnectorPopup] =
|
||||
useState<boolean>(false);
|
||||
@ -33,16 +80,22 @@ function Main() {
|
||||
data: currentEmeddingModel,
|
||||
isLoading: isLoadingCurrentModel,
|
||||
error: currentEmeddingModelError,
|
||||
} = useSWR<EmbeddingModelDescriptor>(
|
||||
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
|
||||
"/api/secondary-index/get-current-embedding-model",
|
||||
errorHandlingFetcher,
|
||||
{ refreshInterval: 5000 } // 5 seconds
|
||||
);
|
||||
|
||||
const { data: embeddingProviderDetails } = useSWR<EmbeddingDetails[]>(
|
||||
EMBEDDING_PROVIDERS_ADMIN_URL,
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
const {
|
||||
data: futureEmbeddingModel,
|
||||
isLoading: isLoadingFutureModel,
|
||||
error: futureEmeddingModelError,
|
||||
} = useSWR<EmbeddingModelDescriptor | null>(
|
||||
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
|
||||
"/api/secondary-index/get-secondary-embedding-model",
|
||||
errorHandlingFetcher,
|
||||
{ refreshInterval: 5000 } // 5 seconds
|
||||
@ -61,27 +114,41 @@ function Main() {
|
||||
{ refreshInterval: 5000 } // 5 seconds
|
||||
);
|
||||
|
||||
const onSelect = async (model: EmbeddingModelDescriptor) => {
|
||||
if (currentEmeddingModel?.model_name === INVALID_OLD_MODEL) {
|
||||
await onConfirm(model);
|
||||
} else {
|
||||
setTentativeNewEmbeddingModel(model);
|
||||
}
|
||||
};
|
||||
const onConfirm = async (
|
||||
model: CloudEmbeddingModel | HostedEmbeddingModel
|
||||
) => {
|
||||
let newModel: EmbeddingModelDescriptor;
|
||||
|
||||
if ("cloud_provider_name" in model) {
|
||||
// This is a CloudEmbeddingModel
|
||||
newModel = {
|
||||
...model,
|
||||
model_name: model.model_name,
|
||||
cloud_provider_name: model.cloud_provider_name,
|
||||
// cloud_provider_id: model.cloud_provider_id || 0,
|
||||
};
|
||||
} else {
|
||||
// This is an EmbeddingModelDescriptor
|
||||
newModel = {
|
||||
...model,
|
||||
model_name: model.model_name!,
|
||||
description: "",
|
||||
cloud_provider_name: null,
|
||||
};
|
||||
}
|
||||
|
||||
const onConfirm = async (model: EmbeddingModelDescriptor) => {
|
||||
const response = await fetch(
|
||||
"/api/secondary-index/set-new-embedding-model",
|
||||
{
|
||||
method: "POST",
|
||||
body: JSON.stringify(model),
|
||||
body: JSON.stringify(newModel),
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setTentativeNewEmbeddingModel(null);
|
||||
setShowTentativeModel(null);
|
||||
mutate("/api/secondary-index/get-secondary-embedding-model");
|
||||
if (!connectors || !connectors.length) {
|
||||
setShowAddConnectorPopup(true);
|
||||
@ -96,14 +163,13 @@ function Main() {
|
||||
method: "POST",
|
||||
});
|
||||
if (response.ok) {
|
||||
setTentativeNewEmbeddingModel(null);
|
||||
setShowTentativeModel(null);
|
||||
mutate("/api/secondary-index/get-secondary-embedding-model");
|
||||
} else {
|
||||
alert(
|
||||
`Failed to cancel embedding model update - ${await response.text()}`
|
||||
);
|
||||
}
|
||||
|
||||
setIsCancelling(false);
|
||||
};
|
||||
|
||||
@ -119,216 +185,328 @@ function Main() {
|
||||
return <ErrorCallout errorTitle="Failed to fetch embedding model status" />;
|
||||
}
|
||||
|
||||
const currentModelName = currentEmeddingModel.model_name;
|
||||
const currentModel =
|
||||
AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) ||
|
||||
fillOutEmeddingModelDescriptor(currentEmeddingModel);
|
||||
const onConfirmSelection = async (model: EmbeddingModelDescriptor) => {
|
||||
const response = await fetch(
|
||||
"/api/secondary-index/set-new-embedding-model",
|
||||
{
|
||||
method: "POST",
|
||||
body: JSON.stringify(model),
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
setShowTentativeModel(null);
|
||||
mutate("/api/secondary-index/get-secondary-embedding-model");
|
||||
if (!connectors || !connectors.length) {
|
||||
setShowAddConnectorPopup(true);
|
||||
}
|
||||
} else {
|
||||
alert(`Failed to update embedding model - ${await response.text()}`);
|
||||
}
|
||||
};
|
||||
|
||||
const newModelSelection = futureEmbeddingModel
|
||||
? AVAILABLE_MODELS.find(
|
||||
(model) => model.model_name === futureEmbeddingModel.model_name
|
||||
) || fillOutEmeddingModelDescriptor(futureEmbeddingModel)
|
||||
: null;
|
||||
const currentModelName = currentEmeddingModel?.model_name;
|
||||
const AVAILABLE_CLOUD_PROVIDERS_FLATTENED = AVAILABLE_CLOUD_PROVIDERS.flatMap(
|
||||
(provider) =>
|
||||
provider.embedding_models.map((model) => ({
|
||||
...model,
|
||||
cloud_provider_id: provider.id,
|
||||
model_name: model.model_name, // Ensure model_name is set for consistency
|
||||
}))
|
||||
);
|
||||
|
||||
const currentModel: CloudEmbeddingModel | HostedEmbeddingModel =
|
||||
AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) ||
|
||||
AVAILABLE_CLOUD_PROVIDERS_FLATTENED.find(
|
||||
(model) => model.model_name === currentEmeddingModel.model_name
|
||||
)!;
|
||||
// ||
|
||||
// fillOutEmeddingModelDescriptor(currentEmeddingModel);
|
||||
|
||||
const onSelectOpenSource = async (model: HostedEmbeddingModel) => {
|
||||
if (currentEmeddingModel?.model_name === INVALID_OLD_MODEL) {
|
||||
await onConfirmSelection(model);
|
||||
} else {
|
||||
setShowTentativeOpenProvider(model);
|
||||
}
|
||||
};
|
||||
|
||||
const selectedModel = AVAILABLE_CLOUD_PROVIDERS[0];
|
||||
const clientsideAddProvider = (provider: CloudEmbeddingProvider) => {
|
||||
const providerName = provider.name;
|
||||
setNewEnabledProviders((newEnabledProviders) => [
|
||||
...newEnabledProviders,
|
||||
providerName,
|
||||
]);
|
||||
setNewUnenabledProviders((newUnenabledProviders) =>
|
||||
newUnenabledProviders.filter(
|
||||
(givenProvidername) => givenProvidername != providerName
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
const clientsideRemoveProvider = (provider: CloudEmbeddingProvider) => {
|
||||
const providerName = provider.name;
|
||||
setNewEnabledProviders((newEnabledProviders) =>
|
||||
newEnabledProviders.filter(
|
||||
(givenProvidername) => givenProvidername != providerName
|
||||
)
|
||||
);
|
||||
setNewUnenabledProviders((newUnenabledProviders) => [
|
||||
...newUnenabledProviders,
|
||||
providerName,
|
||||
]);
|
||||
};
|
||||
|
||||
return (
|
||||
<div>
|
||||
{tentativeNewEmbeddingModel && (
|
||||
<ModelSelectionConfirmaionModal
|
||||
selectedModel={tentativeNewEmbeddingModel}
|
||||
isCustom={
|
||||
AVAILABLE_MODELS.find(
|
||||
(model) =>
|
||||
model.model_name === tentativeNewEmbeddingModel.model_name
|
||||
) === undefined
|
||||
}
|
||||
onConfirm={() => onConfirm(tentativeNewEmbeddingModel)}
|
||||
onCancel={() => setTentativeNewEmbeddingModel(null)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{showAddConnectorPopup && (
|
||||
<Modal>
|
||||
<div>
|
||||
<div>
|
||||
<b className="text-base">Embeding model successfully selected</b>{" "}
|
||||
🙌
|
||||
<br />
|
||||
<br />
|
||||
To complete the initial setup, let's add a connector!
|
||||
<br />
|
||||
<br />
|
||||
Connectors are the way that Danswer gets data from your
|
||||
organization's various data sources. Once setup, we'll
|
||||
automatically sync data from your apps and docs into Danswer, so
|
||||
you can search all through all of them in one place.
|
||||
</div>
|
||||
<div className="flex">
|
||||
<Link className="mx-auto mt-2 w-fit" href="/admin/add-connector">
|
||||
<Button className="mt-3 mx-auto" size="xs">
|
||||
Add Connector
|
||||
</Button>
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
)}
|
||||
|
||||
{isCancelling && (
|
||||
<Modal
|
||||
onOutsideClick={() => setIsCancelling(false)}
|
||||
title="Cancel Embedding Model Switch"
|
||||
>
|
||||
<div>
|
||||
<div>
|
||||
Are you sure you want to cancel?
|
||||
<br />
|
||||
<br />
|
||||
Cancelling will revert to the previous model and all progress will
|
||||
be lost.
|
||||
</div>
|
||||
<div className="flex">
|
||||
<Button onClick={onCancel} className="mt-3 mx-auto" color="green">
|
||||
Confirm
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
)}
|
||||
|
||||
<div className="h-screen">
|
||||
<Text>
|
||||
Embedding models are used to generate embeddings for your documents,
|
||||
which then power Danswer's search.
|
||||
</Text>
|
||||
|
||||
{alreadySelectedModel && (
|
||||
<AlreadyPickedModal
|
||||
model={alreadySelectedModel}
|
||||
onClose={() => setAlreadySelectedModel(null)}
|
||||
/>
|
||||
)}
|
||||
{showTentativeOpenProvider && (
|
||||
<ModelSelectionConfirmationModal
|
||||
selectedModel={showTentativeOpenProvider}
|
||||
isCustom={
|
||||
AVAILABLE_MODELS.find(
|
||||
(model) =>
|
||||
model.model_name === showTentativeOpenProvider.model_name
|
||||
) === undefined
|
||||
}
|
||||
onConfirm={() => onConfirm(showTentativeOpenProvider)}
|
||||
onCancel={() => setShowTentativeOpenProvider(null)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{showTentativeProvider && (
|
||||
<ProviderCreationModal
|
||||
selectedProvider={showTentativeProvider}
|
||||
onConfirm={() => {
|
||||
setShowTentativeProvider(showUnconfiguredProvider);
|
||||
clientsideAddProvider(showTentativeProvider);
|
||||
if (showModelInQueue) {
|
||||
setShowTentativeModel(showModelInQueue);
|
||||
}
|
||||
}}
|
||||
onCancel={() => {
|
||||
setShowModelInQueue(null);
|
||||
setShowTentativeProvider(null);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{changeCredentialsProvider && (
|
||||
<ChangeCredentialsModal
|
||||
// setPopup={setPopup}
|
||||
useFileUpload={changeCredentialsProvider.name == "Google"}
|
||||
onDeleted={() => {
|
||||
clientsideRemoveProvider(changeCredentialsProvider);
|
||||
setChangeCredentialsProvider(null);
|
||||
}}
|
||||
provider={changeCredentialsProvider}
|
||||
onConfirm={() => setChangeCredentialsProvider(null)}
|
||||
onCancel={() => setChangeCredentialsProvider(null)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{showTentativeModel && (
|
||||
<SelectModelModal
|
||||
model={showTentativeModel}
|
||||
onConfirm={() => {
|
||||
setShowModelInQueue(null);
|
||||
onConfirm(showTentativeModel);
|
||||
}}
|
||||
onCancel={() => {
|
||||
setShowModelInQueue(null);
|
||||
setShowTentativeModel(null);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
{showDeleteCredentialsModal && (
|
||||
<DeleteCredentialsModal
|
||||
modelProvider={showTentativeProvider!}
|
||||
onConfirm={() => {
|
||||
setShowDeleteCredentialsModal(false);
|
||||
}}
|
||||
onCancel={() => setShowDeleteCredentialsModal(false)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{currentModel ? (
|
||||
<>
|
||||
<Title className="mt-8 mb-2">Current Embedding Model</Title>
|
||||
|
||||
<Text>
|
||||
<ModelOption model={currentModel} />
|
||||
<ModelPreview model={currentModel} />
|
||||
</Text>
|
||||
</>
|
||||
) : (
|
||||
newModelSelection &&
|
||||
(!connectors || !connectors.length) && (
|
||||
<>
|
||||
<Title className="mt-8 mb-2">Current Embedding Model</Title>
|
||||
<Title className="mt-8 mb-4">Choose your Embedding Model</Title>
|
||||
)}
|
||||
|
||||
<Text>
|
||||
<ModelOption model={newModelSelection} />
|
||||
</Text>
|
||||
</>
|
||||
)
|
||||
{!(futureEmbeddingModel && connectors && connectors.length > 0) && (
|
||||
<>
|
||||
<Title className="mt-8">Switch your Embedding Model</Title>
|
||||
<Text className="mb-4">
|
||||
If the current model is not working for you, you can update your
|
||||
model choice below. Note that this will require a complete
|
||||
re-indexing of all your documents across every connected source. We
|
||||
will take care of this in the background, but depending on the size
|
||||
of your corpus, this could take hours, day, or even weeks. You can
|
||||
monitor the progress of the re-indexing on this page.
|
||||
</Text>
|
||||
|
||||
<div className="mt-8 text-sm mr-auto mb-12 divide-x-2 flex ">
|
||||
<button
|
||||
onClick={() => setOpenToggle(true)}
|
||||
className={` mx-2 p-2 font-bold ${openToggle ? "rounded bg-neutral-900 text-neutral-100 underline" : "hover:underline"}`}
|
||||
>
|
||||
Self-hosted
|
||||
</button>
|
||||
<div className="px-2 ">
|
||||
<button
|
||||
onClick={() => setOpenToggle(false)}
|
||||
className={`mx-2 p-2 font-bold ${!openToggle ? "rounded bg-neutral-900 text-neutral-100 underline" : " hover:underline"}`}
|
||||
>
|
||||
Cloud-based
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
{!showAddConnectorPopup &&
|
||||
(!newModelSelection ? (
|
||||
<div>
|
||||
{currentModel ? (
|
||||
<>
|
||||
<Title className="mt-8">Switch your Embedding Model</Title>
|
||||
|
||||
<Text className="mb-4">
|
||||
If the current model is not working for you, you can update
|
||||
your model choice below. Note that this will require a
|
||||
complete re-indexing of all your documents across every
|
||||
connected source. We will take care of this in the background,
|
||||
but depending on the size of your corpus, this could take
|
||||
hours, day, or even weeks. You can monitor the progress of the
|
||||
re-indexing on this page.
|
||||
</Text>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Title className="mt-8 mb-4">Choose your Embedding Model</Title>
|
||||
</>
|
||||
)}
|
||||
|
||||
<Text className="mb-4">
|
||||
Below are a curated selection of quality models that we recommend
|
||||
you choose from.
|
||||
</Text>
|
||||
|
||||
<ModelSelector
|
||||
modelOptions={AVAILABLE_MODELS.filter(
|
||||
(modelOption) => modelOption.model_name !== currentModelName
|
||||
)}
|
||||
setSelectedModel={onSelect}
|
||||
/>
|
||||
|
||||
<Text className="mt-6">
|
||||
Alternatively, (if you know what you're doing) you can
|
||||
specify a{" "}
|
||||
<a
|
||||
target="_blank"
|
||||
href="https://www.sbert.net/"
|
||||
className="text-link"
|
||||
>
|
||||
SentenceTransformers
|
||||
</a>
|
||||
-compatible model of your choice below. The rough list of
|
||||
supported models can be found{" "}
|
||||
<a
|
||||
target="_blank"
|
||||
href="https://huggingface.co/models?library=sentence-transformers&sort=trending"
|
||||
className="text-link"
|
||||
>
|
||||
here
|
||||
</a>
|
||||
.
|
||||
<br />
|
||||
<b>NOTE:</b> not all models listed will work with Danswer, since
|
||||
some have unique interfaces or special requirements. If in doubt,
|
||||
reach out to the Danswer team.
|
||||
</Text>
|
||||
|
||||
<div className="w-full flex">
|
||||
<Card className="mt-4 2xl:w-4/6 mx-auto">
|
||||
<CustomModelForm onSubmit={onSelect} />
|
||||
</Card>
|
||||
</div>
|
||||
</div>
|
||||
!futureEmbeddingModel &&
|
||||
(openToggle ? (
|
||||
<OpenEmbeddingPage
|
||||
onSelectOpenSource={onSelectOpenSource}
|
||||
currentModelName={currentModelName!}
|
||||
/>
|
||||
) : (
|
||||
connectors &&
|
||||
connectors.length > 0 && (
|
||||
<div>
|
||||
<Title className="mt-8">Current Upgrade Status</Title>
|
||||
<div className="mt-4">
|
||||
<div className="italic text-sm mb-2">
|
||||
Currently in the process of switching to:
|
||||
</div>
|
||||
<ModelOption model={newModelSelection} />
|
||||
|
||||
<Button
|
||||
color="red"
|
||||
size="xs"
|
||||
className="mt-4"
|
||||
onClick={() => setIsCancelling(true)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
|
||||
<Text className="my-4">
|
||||
The table below shows the re-indexing progress of all existing
|
||||
connectors. Once all connectors have been re-indexed
|
||||
successfully, the new model will be used for all search
|
||||
queries. Until then, we will use the old model so that no
|
||||
downtime is necessary during this transition.
|
||||
</Text>
|
||||
|
||||
{isLoadingOngoingReIndexingStatus ? (
|
||||
<ThreeDotsLoader />
|
||||
) : ongoingReIndexingStatus ? (
|
||||
<ReindexingProgressTable
|
||||
reindexingProgress={ongoingReIndexingStatus}
|
||||
/>
|
||||
) : (
|
||||
<ErrorCallout errorTitle="Failed to fetch re-indexing progress" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
<CloudEmbeddingPage
|
||||
setShowModelInQueue={setShowModelInQueue}
|
||||
setShowTentativeModel={setShowTentativeModel}
|
||||
currentModel={currentModel}
|
||||
setAlreadySelectedModel={setAlreadySelectedModel}
|
||||
embeddingProviderDetails={embeddingProviderDetails}
|
||||
newEnabledProviders={newEnabledProviders}
|
||||
newUnenabledProviders={newUnenabledProviders}
|
||||
setShowTentativeProvider={setShowTentativeProvider}
|
||||
selectedModel={selectedModel}
|
||||
setChangeCredentialsProvider={setChangeCredentialsProvider}
|
||||
/>
|
||||
))}
|
||||
|
||||
{openToggle && (
|
||||
<>
|
||||
{showAddConnectorPopup && (
|
||||
<Modal>
|
||||
<div>
|
||||
<div>
|
||||
<b className="text-base">
|
||||
Embedding model successfully selected
|
||||
</b>{" "}
|
||||
🙌
|
||||
<br />
|
||||
<br />
|
||||
To complete the initial setup, let's add a connector!
|
||||
<br />
|
||||
<br />
|
||||
Connectors are the way that Danswer gets data from your
|
||||
organization's various data sources. Once setup,
|
||||
we'll automatically sync data from your apps and docs
|
||||
into Danswer, so you can search all through all of them in one
|
||||
place.
|
||||
</div>
|
||||
<div className="flex">
|
||||
<Link
|
||||
className="mx-auto mt-2 w-fit"
|
||||
href="/admin/add-connector"
|
||||
>
|
||||
<Button className="mt-3 mx-auto" size="xs">
|
||||
Add Connector
|
||||
</Button>
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
)}
|
||||
|
||||
{isCancelling && (
|
||||
<Modal
|
||||
onOutsideClick={() => setIsCancelling(false)}
|
||||
title="Cancel Embedding Model Switch"
|
||||
>
|
||||
<div>
|
||||
<div>
|
||||
Are you sure you want to cancel?
|
||||
<br />
|
||||
<br />
|
||||
Cancelling will revert to the previous model and all progress
|
||||
will be lost.
|
||||
</div>
|
||||
<div className="flex">
|
||||
<Button
|
||||
onClick={onCancel}
|
||||
className="mt-3 mx-auto"
|
||||
color="green"
|
||||
>
|
||||
Confirm
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
{futureEmbeddingModel && connectors && connectors.length > 0 && (
|
||||
<div>
|
||||
<Title className="mt-8">Current Upgrade Status</Title>
|
||||
<div className="mt-4">
|
||||
<div className="italic text-lg mb-2">
|
||||
Currently in the process of switching to:{" "}
|
||||
{futureEmbeddingModel.model_name}
|
||||
</div>
|
||||
{/* <ModelOption model={futureEmbeddingModel} /> */}
|
||||
|
||||
<Button
|
||||
color="red"
|
||||
size="xs"
|
||||
className="mt-4"
|
||||
onClick={() => setIsCancelling(true)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
|
||||
<Text className="my-4">
|
||||
The table below shows the re-indexing progress of all existing
|
||||
connectors. Once all connectors have been re-indexed successfully,
|
||||
the new model will be used for all search queries. Until then, we
|
||||
will use the old model so that no downtime is necessary during
|
||||
this transition.
|
||||
</Text>
|
||||
|
||||
{isLoadingOngoingReIndexingStatus ? (
|
||||
<ThreeDotsLoader />
|
||||
) : ongoingReIndexingStatus ? (
|
||||
<ReindexingProgressTable
|
||||
reindexingProgress={ongoingReIndexingStatus}
|
||||
/>
|
||||
) : (
|
||||
<ErrorCallout errorTitle="Failed to fetch re-indexing progress" />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
@ -1 +1,4 @@
|
||||
export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
|
||||
|
||||
export const EMBEDDING_PROVIDERS_ADMIN_URL =
|
||||
"/api/admin/embedding/embedding-provider";
|
||||
|
@ -20,7 +20,7 @@ import {
|
||||
import { unstable_noStore as noStore } from "next/cache";
|
||||
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
||||
import { personaComparator } from "../admin/assistants/lib";
|
||||
import { FullEmbeddingModelResponse } from "../admin/models/embedding/embeddingModels";
|
||||
import { FullEmbeddingModelResponse } from "../admin/models/embedding/components/types";
|
||||
import { NoSourcesModal } from "@/components/initialSetup/search/NoSourcesModal";
|
||||
import { NoCompleteSourcesModal } from "@/components/initialSetup/search/NoCompleteSourceModal";
|
||||
import { ChatPopup } from "../chat/ChatPopup";
|
||||
|
@ -1,7 +1,9 @@
|
||||
import { Divider } from "@tremor/react";
|
||||
import { FiX } from "react-icons/fi";
|
||||
import { IconProps } from "./icons/icons";
|
||||
|
||||
interface ModalProps {
|
||||
icon?: ({ size, className }: IconProps) => JSX.Element;
|
||||
children: JSX.Element | string;
|
||||
title?: JSX.Element | string;
|
||||
onOutsideClick?: () => void;
|
||||
@ -13,6 +15,7 @@ interface ModalProps {
|
||||
}
|
||||
|
||||
export function Modal({
|
||||
icon,
|
||||
children,
|
||||
title,
|
||||
onOutsideClick,
|
||||
@ -44,10 +47,15 @@ export function Modal({
|
||||
<>
|
||||
<div className="flex mb-4">
|
||||
<h2
|
||||
className={"my-auto font-bold " + (titleSize || "text-2xl")}
|
||||
className={
|
||||
"my-auto flex content-start gap-x-4 font-bold " +
|
||||
(titleSize || "text-2xl")
|
||||
}
|
||||
>
|
||||
{title}
|
||||
{icon && icon({ size: 30 })}
|
||||
</h2>
|
||||
|
||||
{onOutsideClick && (
|
||||
<div
|
||||
onClick={onOutsideClick}
|
||||
|
@ -72,11 +72,16 @@ import sharepointIcon from "../../../public/Sharepoint.png";
|
||||
import teamsIcon from "../../../public/Teams.png";
|
||||
import mediawikiIcon from "../../../public/MediaWiki.svg";
|
||||
import wikipediaIcon from "../../../public/Wikipedia.svg";
|
||||
|
||||
import discourseIcon from "../../../public/Discourse.png";
|
||||
import clickupIcon from "../../../public/Clickup.svg";
|
||||
import cohereIcon from "../../../public/Cohere.svg";
|
||||
import voyageIcon from "../../../public/Voyage.png";
|
||||
import googleIcon from "../../../public/Google.webp";
|
||||
|
||||
import { FaRobot } from "react-icons/fa";
|
||||
|
||||
interface IconProps {
|
||||
export interface IconProps {
|
||||
size?: number;
|
||||
className?: string;
|
||||
}
|
||||
@ -84,20 +89,6 @@ interface IconProps {
|
||||
export const defaultTailwindCSS = "my-auto flex flex-shrink-0 text-default";
|
||||
export const defaultTailwindCSSBlue = "my-auto flex flex-shrink-0 text-link";
|
||||
|
||||
export const OpenAIIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<div
|
||||
style={{ width: `${size + 4}px`, height: `${size + 4}px` }}
|
||||
className={`w-[${size + 4}px] h-[${size + 4}px] -m-0.5 ` + className}
|
||||
>
|
||||
<Image src={openAISVG} alt="Logo" width="96" height="96" />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export const OpenSourceIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
@ -528,6 +519,62 @@ export const ZulipIcon = ({
|
||||
);
|
||||
};
|
||||
|
||||
export const OpenAIIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<div
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
>
|
||||
<Image src={openAISVG} alt="Logo" width="96" height="96" />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export const VoyageIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<div
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
>
|
||||
<Image src={voyageIcon} alt="Logo" width="96" height="96" />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export const GoogleIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<div
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
>
|
||||
<Image src={googleIcon} alt="Logo" width="96" height="96" />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export const CohereIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<div
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
>
|
||||
<Image src={cohereIcon} alt="Logo" width="96" height="96" />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export const GoogleStorageIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
|
@ -13,7 +13,7 @@ import {
|
||||
} from "@/lib/types";
|
||||
import { ChatSession } from "@/app/chat/interfaces";
|
||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import { FullEmbeddingModelResponse } from "@/app/admin/models/embedding/embeddingModels";
|
||||
import { FullEmbeddingModelResponse } from "@/app/admin/models/embedding/components/types";
|
||||
import { Settings } from "@/app/admin/settings/interfaces";
|
||||
import { fetchLLMProvidersSS } from "@/lib/llm/fetchLLMs";
|
||||
import { LLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";
|
||||
|
Loading…
x
Reference in New Issue
Block a user