add third party embedding models (#1818)

pablodanswer 2024-07-14 10:19:53 -07:00 committed by GitHub
parent b6bd818e60
commit e7f81d1688
46 changed files with 2293 additions and 453 deletions


@ -0,0 +1,65 @@
"""add cloud embedding model and update embedding_model
Revision ID: 44f856ae2a4a
Revises: d716b0791ddd
Create Date: 2024-06-28 20:01:05.927647
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "44f856ae2a4a"
down_revision = "d716b0791ddd"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Create embedding_provider table
op.create_table(
"embedding_provider",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("api_key", sa.LargeBinary(), nullable=True),
sa.Column("default_model_id", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
# Add cloud_provider_id to embedding_model table
op.add_column(
"embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
)
# Add foreign key constraints
op.create_foreign_key(
"fk_embedding_model_cloud_provider",
"embedding_model",
"embedding_provider",
["cloud_provider_id"],
["id"],
)
op.create_foreign_key(
"fk_embedding_provider_default_model",
"embedding_provider",
"embedding_model",
["default_model_id"],
["id"],
)
def downgrade() -> None:
# Remove foreign key constraints
op.drop_constraint(
"fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
)
op.drop_constraint(
"fk_embedding_provider_default_model", "embedding_provider", type_="foreignkey"
)
# Remove cloud_provider_id column
op.drop_column("embedding_model", "cloud_provider_id")
# Drop embedding_provider table
op.drop_table("embedding_provider")


@ -10,8 +10,8 @@ from alembic import op
# revision identifiers, used by Alembic.
revision = "d716b0791ddd"
down_revision = "7aea705850d5"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:


@ -98,7 +98,6 @@ def _run_indexing(
3. Updates Postgres to record the indexed documents + the outcome of this run
"""
start_time = time.time()
db_embedding_model = index_attempt.embedding_model
index_name = db_embedding_model.index_name
@ -116,6 +115,8 @@ def _run_indexing(
normalize=db_embedding_model.normalize,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
api_key=db_embedding_model.api_key,
provider_type=db_embedding_model.provider_type,
)
indexing_pipeline = build_indexing_pipeline(
@ -287,6 +288,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
db_session=db_session,
index_attempt_id=index_attempt_id,
)
if attempt is None:
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")


@ -343,13 +343,15 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
# So that first-time users aren't surprised by the really slow speed of the
# first batch of documents indexed
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
if db_embedding_model.cloud_provider_id is None:
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient


@ -469,13 +469,13 @@ if __name__ == "__main__":
# or the tokens have updated (set up for the first time)
with Session(get_sqlalchemy_engine()) as db_session:
embedding_model = get_current_db_embedding_model(db_session)
warm_up_encoders(
model_name=embedding_model.model_name,
normalize=embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
if embedding_model.cloud_provider_id is None:
warm_up_encoders(
model_name=embedding_model.model_name,
normalize=embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
slack_bot_tokens = latest_slack_bot_tokens
# potentially may cause a message to be dropped, but it is complicated


@ -10,10 +10,15 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelDetail
from danswer.search.search_nlp_models import clean_model_name
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -31,6 +36,7 @@ def create_embedding_model(
query_prefix=model_details.query_prefix,
passage_prefix=model_details.passage_prefix,
status=status,
cloud_provider_id=model_details.cloud_provider_id,
# Every embedding model except the initial one from the migrations uses this index name format
# The initial one from the migrations is simply called "danswer_chunk"
index_name=f"danswer_chunk_{clean_model_name(model_details.model_name)}",
@ -42,6 +48,42 @@ def create_embedding_model(
return embedding_model
def get_model_id_from_name(
db_session: Session, embedding_provider_name: str
) -> int | None:
query = select(CloudEmbeddingProvider).where(
CloudEmbeddingProvider.name == embedding_provider_name
)
provider = db_session.execute(query).scalars().first()
return provider.id if provider else None
def get_current_db_embedding_provider(
db_session: Session,
) -> ServerCloudEmbeddingProvider | None:
current_embedding_model = EmbeddingModelDetail.from_model(
get_current_db_embedding_model(db_session=db_session)
)
if (
current_embedding_model is None
or current_embedding_model.cloud_provider_id is None
):
return None
embedding_provider = fetch_embedding_provider(
db_session=db_session, provider_id=current_embedding_model.cloud_provider_id
)
if embedding_provider is None:
raise RuntimeError("No embedding provider exists for this model.")
current_embedding_provider = ServerCloudEmbeddingProvider.from_request(
cloud_provider_model=embedding_provider
)
return current_embedding_provider
def get_current_db_embedding_model(db_session: Session) -> EmbeddingModel:
query = (
select(EmbeddingModel)


@ -2,11 +2,34 @@ from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
def upsert_cloud_embedding_provider(
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
) -> CloudEmbeddingProvider:
existing_provider = (
db_session.query(CloudEmbeddingProviderModel)
.filter_by(name=provider.name)
.first()
)
if existing_provider:
for key, value in provider.dict().items():
setattr(existing_provider, key, value)
else:
new_provider = CloudEmbeddingProviderModel(**provider.dict())
db_session.add(new_provider)
existing_provider = new_provider
db_session.commit()
db_session.refresh(existing_provider)
return CloudEmbeddingProvider.from_request(existing_provider)
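
A minimal caller-side sketch of the new upsert path (not part of this diff). It assumes the get_sqlalchemy_engine helper from danswer.db.engine used elsewhere in this codebase, and the API key is a placeholder.

from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.llm import fetch_existing_embedding_providers
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.server.manage.embedding.models import (
    CloudEmbeddingProviderCreationRequest,
)

# Store (or update) an OpenAI provider row, then list all configured providers
with Session(get_sqlalchemy_engine()) as db_session:
    provider = upsert_cloud_embedding_provider(
        db_session,
        CloudEmbeddingProviderCreationRequest(name="OpenAI", api_key="sk-..."),  # placeholder key
    )
    print(provider.id, provider.name)
    print([p.name for p in fetch_existing_embedding_providers(db_session)])
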
def upsert_llm_provider(
db_session: Session, llm_provider: LLMProviderUpsertRequest
) -> FullLLMProvider:
@ -26,7 +49,6 @@ def upsert_llm_provider(
existing_llm_provider.model_names = llm_provider.model_names
db_session.commit()
return FullLLMProvider.from_model(existing_llm_provider)
# if it does not exist, create a new entry
llm_provider_model = LLMProviderModel(
name=llm_provider.name,
@ -46,10 +68,26 @@ def upsert_llm_provider(
return FullLLMProvider.from_model(llm_provider_model)
def fetch_existing_embedding_providers(
db_session: Session,
) -> list[CloudEmbeddingProviderModel]:
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
return list(db_session.scalars(select(LLMProviderModel)).all())
def fetch_embedding_provider(
db_session: Session, provider_id: int
) -> CloudEmbeddingProviderModel | None:
return db_session.scalar(
select(CloudEmbeddingProviderModel).where(
CloudEmbeddingProviderModel.id == provider_id
)
)
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
@ -70,6 +108,16 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
return FullLLMProvider.from_model(provider_model)
def remove_embedding_provider(
db_session: Session, embedding_provider_name: str
) -> None:
db_session.execute(
delete(CloudEmbeddingProviderModel).where(
CloudEmbeddingProviderModel.name == embedding_provider_name
)
)
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
db_session.execute(
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)


@ -130,6 +130,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
chat_folders: Mapped[list["ChatFolder"]] = relationship(
"ChatFolder", back_populates="user"
)
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
# Personas owned by this user
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
@ -469,7 +470,7 @@ class Credential(Base):
class EmbeddingModel(Base):
__tablename__ = "embedding_model"
# The ID is also used to indicate the order in which the models were configured by the admin
id: Mapped[int] = mapped_column(primary_key=True)
model_name: Mapped[str] = mapped_column(String)
model_dim: Mapped[int] = mapped_column(Integer)
@ -481,6 +482,16 @@ class EmbeddingModel(Base):
)
index_name: Mapped[str] = mapped_column(String)
# New field for cloud provider relationship
cloud_provider_id: Mapped[int | None] = mapped_column(
ForeignKey("embedding_provider.id")
)
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
"CloudEmbeddingProvider",
back_populates="embedding_models",
foreign_keys=[cloud_provider_id],
)
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
"IndexAttempt", back_populates="embedding_model"
)
@ -500,6 +511,18 @@ class EmbeddingModel(Base):
),
)
def __repr__(self) -> str:
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
@property
def api_key(self) -> str | None:
return self.cloud_provider.api_key if self.cloud_provider else None
@property
def provider_type(self) -> str | None:
return self.cloud_provider.name if self.cloud_provider else None
class IndexAttempt(Base):
"""
@ -519,6 +542,7 @@ class IndexAttempt(Base):
ForeignKey("credential.id"),
nullable=True,
)
# Some index attempts that run from beginning will still have this as False
# This is only for attempts that are explicitly marked as from the start via
# the run once API
@ -879,11 +903,6 @@ class ChatMessageFeedback(Base):
)
"""
Structures, Organizational, Configurations Tables
"""
class LLMProvider(Base):
__tablename__ = "llm_provider"
@ -912,6 +931,29 @@ class LLMProvider(Base):
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
class CloudEmbeddingProvider(Base):
__tablename__ = "embedding_provider"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
api_key: Mapped[str | None] = mapped_column(EncryptedString())
default_model_id: Mapped[int | None] = mapped_column(
Integer, ForeignKey("embedding_model.id"), nullable=True
)
embedding_models: Mapped[list["EmbeddingModel"]] = relationship(
"EmbeddingModel",
back_populates="cloud_provider",
foreign_keys="EmbeddingModel.cloud_provider_id",
)
default_model: Mapped["EmbeddingModel"] = relationship(
"EmbeddingModel", foreign_keys=[default_model_id]
)
def __repr__(self) -> str:
return f"<EmbeddingProvider(name='{self.name}')>"
class DocumentSet(Base):
__tablename__ = "document_set"
@ -1194,6 +1236,7 @@ class SlackBotConfig(Base):
response_type: Mapped[SlackBotResponseType] = mapped_column(
Enum(SlackBotResponseType, native_enum=False), nullable=False
)
enable_auto_filters: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)


@ -50,6 +50,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
normalize: bool,
query_prefix: str | None,
passage_prefix: str | None,
api_key: str | None = None,
provider_type: str | None = None,
):
super().__init__(model_name, normalize, query_prefix, passage_prefix)
self.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE # Currently not customizable
@ -59,6 +61,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
query_prefix=query_prefix,
passage_prefix=passage_prefix,
normalize=normalize,
api_key=api_key,
provider_type=provider_type,
# The below are globally set, this flow always uses the indexing one
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,


@ -97,13 +97,19 @@ class EmbeddingModelDetail(BaseModel):
normalize: bool
query_prefix: str | None
passage_prefix: str | None
cloud_provider_id: int | None = None
cloud_provider_name: str | None = None
@classmethod
def from_model(cls, embedding_model: "EmbeddingModel") -> "EmbeddingModelDetail":
def from_model(
cls,
embedding_model: "EmbeddingModel",
) -> "EmbeddingModelDetail":
return cls(
model_name=embedding_model.model_name,
model_dim=embedding_model.model_dim,
normalize=embedding_model.normalize,
query_prefix=embedding_model.query_prefix,
passage_prefix=embedding_model.passage_prefix,
cloud_provider_id=embedding_model.cloud_provider_id,
)


@ -67,6 +67,8 @@ from danswer.server.features.tool.api import admin_router as admin_tool_router
from danswer.server.features.tool.api import router as tool_router
from danswer.server.gpts.api import router as gpts_router
from danswer.server.manage.administrative import router as admin_router
from danswer.server.manage.embedding.api import admin_router as embedding_admin_router
from danswer.server.manage.embedding.api import basic_router as embedding_router
from danswer.server.manage.get_state import router as state_router
from danswer.server.manage.llm.api import admin_router as llm_admin_router
from danswer.server.manage.llm.api import basic_router as llm_router
@ -247,12 +249,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
time.sleep(wait_time)
logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
if db_embedding_model.cloud_provider_id is None:
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
yield
@ -291,6 +294,8 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, settings_admin_router)
include_router_with_global_prefix_prepended(application, llm_admin_router)
include_router_with_global_prefix_prepended(application, llm_router)
include_router_with_global_prefix_prepended(application, embedding_admin_router)
include_router_with_global_prefix_prepended(application, embedding_router)
include_router_with_global_prefix_prepended(
application, token_rate_limit_settings_router
)


@ -168,6 +168,7 @@ def stream_answer_objects(
max_tokens=max_document_tokens,
use_sections=query_req.chunks_above > 0 or query_req.chunks_below > 0,
)
search_tool = SearchTool(
db_session=db_session,
user=user,


@ -131,6 +131,8 @@ def doc_index_retrieval(
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
normalize=db_embedding_model.normalize,
api_key=db_embedding_model.api_key,
provider_type=db_embedding_model.provider_type,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,


@ -84,20 +84,24 @@ def build_model_server_url(
class EmbeddingModel:
def __init__(
self,
model_name: str,
query_prefix: str | None,
passage_prefix: str | None,
normalize: bool,
server_host: str, # Changes depending on indexing or inference
server_port: int,
model_name: str | None,
normalize: bool,
query_prefix: str | None,
passage_prefix: str | None,
api_key: str | None,
provider_type: str | None,
# The following globals are currently not configurable
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> None:
self.model_name = model_name
self.api_key = api_key
self.provider_type = provider_type
self.max_seq_length = max_seq_length
self.query_prefix = query_prefix
self.passage_prefix = passage_prefix
self.normalize = normalize
self.model_name = model_name
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@ -111,10 +115,13 @@ class EmbeddingModel:
prefixed_texts = texts
embed_request = EmbedRequest(
texts=prefixed_texts,
model_name=self.model_name,
texts=prefixed_texts,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
)
response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
@ -187,6 +194,8 @@ def warm_up_encoders(
passage_prefix=None,
server_host=model_server_host,
server_port=model_server_port,
api_key=None,
provider_type=None,
)
# First time downloading the models it may take even longer, but just in case,
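
For reference, a hedged sketch of how the extended client can now target a cloud provider through the model server. MODEL_SERVER_HOST / MODEL_SERVER_PORT come from shared_configs.configs as in the test-embedding endpoint added by this PR; the key is a placeholder.

from danswer.search.enums import EmbedTextType
from danswer.search.search_nlp_models import EmbeddingModel
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT

# With api_key / provider_type set, the model server skips the local
# SentenceTransformer and calls the cloud provider instead
cloud_embedder = EmbeddingModel(
    server_host=MODEL_SERVER_HOST,
    server_port=MODEL_SERVER_PORT,
    model_name="text-embedding-3-small",  # None would fall back to the provider default
    normalize=False,
    query_prefix=None,
    passage_prefix=None,
    api_key="sk-...",  # placeholder
    provider_type="openai",
)
embeddings = cloud_embedder.encode(
    ["what changed in this PR?"], text_type=EmbedTextType.QUERY
)
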


@ -0,0 +1,93 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.db.embedding_model import get_current_db_embedding_provider
from danswer.db.engine import get_session
from danswer.db.llm import fetch_existing_embedding_providers
from danswer.db.llm import remove_embedding_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.models import User
from danswer.search.enums import EmbedTextType
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.embedding.models import TestEmbeddingRequest
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
admin_router = APIRouter(prefix="/admin/embedding")
basic_router = APIRouter(prefix="/embedding")
@admin_router.post("/test-embedding")
def test_embedding_configuration(
test_llm_request: TestEmbeddingRequest,
_: User | None = Depends(current_admin_user),
) -> None:
try:
test_model = EmbeddingModel(
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
api_key=test_llm_request.api_key,
provider_type=test_llm_request.provider,
normalize=False,
query_prefix=None,
passage_prefix=None,
model_name=None,
)
test_model.encode(["Test String"], text_type=EmbedTextType.QUERY)
except ValueError as e:
error_msg = f"Not a valid embedding model. Exception thrown: {e}"
logger.error(error_msg)
raise ValueError(error_msg)
except Exception as e:
error_msg = "An error occurred while testing your embedding model. Please check your configuration."
logger.error(f"{error_msg} Error message: {e}", exc_info=True)
raise HTTPException(status_code=400, detail=error_msg)
@admin_router.get("/embedding-provider")
def list_embedding_providers(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[CloudEmbeddingProvider]:
return [
CloudEmbeddingProvider.from_request(embedding_provider_model)
for embedding_provider_model in fetch_existing_embedding_providers(db_session)
]
@admin_router.delete("/embedding-provider/{embedding_provider_name}")
def delete_embedding_provider(
embedding_provider_name: str,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
embedding_provider = get_current_db_embedding_provider(db_session=db_session)
if (
embedding_provider is not None
and embedding_provider_name == embedding_provider.name
):
raise HTTPException(
status_code=400, detail="You can't delete a currently active model"
)
remove_embedding_provider(db_session, embedding_provider_name)
@admin_router.put("/embedding-provider")
def put_cloud_embedding_provider(
provider: CloudEmbeddingProviderCreationRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> CloudEmbeddingProvider:
return upsert_cloud_embedding_provider(db_session, provider)
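
A rough HTTP sketch of the new admin routes (not part of this diff). It assumes the routers are mounted behind the same /api prefix the frontend in this PR calls; the base URL, cookie, and API key values are placeholders.

import requests

BASE = "http://localhost:3000/api/admin/embedding"  # assumed deployment URL
cookies = {"session": "<admin-session-cookie>"}  # placeholder auth

# 1) Validate the key without persisting anything
requests.post(
    f"{BASE}/test-embedding",
    json={"provider": "openai", "api_key": "sk-..."},
    cookies=cookies,
).raise_for_status()

# 2) Create or update the provider, then list what is configured
requests.put(
    f"{BASE}/embedding-provider",
    json={"name": "OpenAI", "api_key": "sk-..."},
    cookies=cookies,
).raise_for_status()
print(requests.get(f"{BASE}/embedding-provider", cookies=cookies).json())
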


@ -0,0 +1,35 @@
from typing import TYPE_CHECKING
from pydantic import BaseModel
if TYPE_CHECKING:
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
class TestEmbeddingRequest(BaseModel):
provider: str
api_key: str | None = None
class CloudEmbeddingProvider(BaseModel):
name: str
api_key: str | None = None
default_model_id: int | None = None
id: int
@classmethod
def from_request(
cls, cloud_provider_model: "CloudEmbeddingProviderModel"
) -> "CloudEmbeddingProvider":
return cls(
id=cloud_provider_model.id,
name=cloud_provider_model.name,
api_key=cloud_provider_model.api_key,
default_model_id=cloud_provider_model.default_model_id,
)
class CloudEmbeddingProviderCreationRequest(BaseModel):
name: str
api_key: str | None = None
default_model_id: int | None = None


@ -4,6 +4,7 @@ from pydantic import BaseModel
from danswer.llm.llm_provider_options import fetch_models_for_provider
if TYPE_CHECKING:
from danswer.db.models import LLMProvider as LLMProviderModel


@ -11,6 +11,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.embedding_model import create_embedding_model
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_model_id_from_name
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.embedding_model import update_embedding_model_status
from danswer.db.engine import get_session
@ -38,6 +39,19 @@ def set_new_embedding_model(
"""
current_model = get_current_db_embedding_model(db_session)
if embed_model_details.cloud_provider_name is not None:
cloud_id = get_model_id_from_name(
db_session, embed_model_details.cloud_provider_name
)
if cloud_id is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="No ID exists for given provider name",
)
embed_model_details.cloud_provider_id = cloud_id
if embed_model_details.model_name == current_model.model_name:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,


@ -1 +1,38 @@
from enum import Enum
from danswer.search.enums import EmbedTextType
MODEL_WARM_UP_STRING = "hi " * 512
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
DEFAULT_VERTEX_MODEL = "text-embedding-004"
class EmbeddingProvider(Enum):
OPENAI = "openai"
COHERE = "cohere"
VOYAGE = "voyage"
GOOGLE = "google"
class EmbeddingModelTextType:
PROVIDER_TEXT_TYPE_MAP = {
EmbeddingProvider.COHERE: {
EmbedTextType.QUERY: "search_query",
EmbedTextType.PASSAGE: "search_document",
},
EmbeddingProvider.VOYAGE: {
EmbedTextType.QUERY: "query",
EmbedTextType.PASSAGE: "document",
},
EmbeddingProvider.GOOGLE: {
EmbedTextType.QUERY: "RETRIEVAL_QUERY",
EmbedTextType.PASSAGE: "RETRIEVAL_DOCUMENT",
},
}
@staticmethod
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
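
A quick illustration of the lookup above, mapping Danswer's generic text types to the input_type strings each cloud API expects:

from danswer.search.enums import EmbedTextType

from model_server.constants import EmbeddingModelTextType
from model_server.constants import EmbeddingProvider

assert (
    EmbeddingModelTextType.get_type(EmbeddingProvider.COHERE, EmbedTextType.QUERY)
    == "search_query"
)
assert (
    EmbeddingModelTextType.get_type(EmbeddingProvider.VOYAGE, EmbedTextType.PASSAGE)
    == "document"
)
# OpenAI has no entry in the map; embed() short-circuits before this lookup for it
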


@ -1,12 +1,28 @@
import gc
import json
from typing import Any
from typing import Optional
import openai
import vertexai # type: ignore
import voyageai # type: ignore
from cohere import Client as CohereClient
from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from vertexai.language_models import TextEmbeddingInput # type: ignore
from vertexai.language_models import TextEmbeddingModel # type: ignore
from danswer.search.enums import EmbedTextType
from danswer.utils.logger import setup_logger
from model_server.constants import DEFAULT_COHERE_MODEL
from model_server.constants import DEFAULT_OPENAI_MODEL
from model_server.constants import DEFAULT_VERTEX_MODEL
from model_server.constants import DEFAULT_VOYAGE_MODEL
from model_server.constants import EmbeddingModelTextType
from model_server.constants import EmbeddingProvider
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.utils import simple_log_function_time
from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE
@ -17,6 +33,7 @@ from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
logger = setup_logger()
router = APIRouter(prefix="/encoder")
@ -25,6 +42,117 @@ _GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
class CloudEmbedding:
def __init__(self, api_key: str, provider: str, model: str | None = None):
self.api_key = api_key
# Only needed for Google, since it is required at client setup
self.model = model
try:
self.provider = EmbeddingProvider(provider.lower())
except ValueError:
raise ValueError(f"Unsupported provider: {provider}")
self.client = self._initialize_client()
def _initialize_client(self) -> Any:
if self.provider == EmbeddingProvider.OPENAI:
return openai.OpenAI(api_key=self.api_key)
elif self.provider == EmbeddingProvider.COHERE:
return CohereClient(api_key=self.api_key)
elif self.provider == EmbeddingProvider.VOYAGE:
return voyageai.Client(api_key=self.api_key)
elif self.provider == EmbeddingProvider.GOOGLE:
credentials = service_account.Credentials.from_service_account_info(
json.loads(self.api_key)
)
project_id = json.loads(self.api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
return TextEmbeddingModel.from_pretrained(
self.model or DEFAULT_VERTEX_MODEL
)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
def encode(
self, texts: list[str], model_name: str | None, text_type: EmbedTextType
) -> list[list[float]]:
return [
self.embed(text=text, text_type=text_type, model=model_name)
for text in texts
]
def embed(
self, *, text: str, text_type: EmbedTextType, model: str | None = None
) -> list[float]:
logger.debug(f"Embedding text with provider: {self.provider}")
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(text, model)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(text, model, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(text, model, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(text, model, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
def _embed_openai(self, text: str, model: str | None) -> list[float]:
if model is None:
model = DEFAULT_OPENAI_MODEL
response = self.client.embeddings.create(input=text, model=model)
return response.data[0].embedding
def _embed_cohere(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
if model is None:
model = DEFAULT_COHERE_MODEL
response = self.client.embed(
texts=[text],
model=model,
input_type=embedding_type,
)
return response.embeddings[0]
def _embed_voyage(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
if model is None:
model = DEFAULT_VOYAGE_MODEL
response = self.client.embed(text, model=model, input_type=embedding_type)
return response.embeddings[0]
def _embed_vertex(
self, text: str, model: str | None, embedding_type: str
) -> list[float]:
if model is None:
model = DEFAULT_VERTEX_MODEL
embedding = self.client.get_embeddings(
[
TextEmbeddingInput(
text,
embedding_type,
)
]
)
return embedding[0].values
@staticmethod
def create(
api_key: str, provider: str, model: str | None = None
) -> "CloudEmbedding":
logger.debug(f"Creating Embedding instance for provider: {provider}")
return CloudEmbedding(api_key, provider, model)
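
A standalone usage sketch of the wrapper above (placeholder keys; running it makes real network calls): one query embedding via OpenAI and one passage embedding via Cohere.

from danswer.search.enums import EmbedTextType

openai_embedder = CloudEmbedding.create(api_key="sk-...", provider="openai")
query_vec = openai_embedder.embed(
    text="what is danswer?",
    text_type=EmbedTextType.QUERY,  # model defaults to DEFAULT_OPENAI_MODEL
)

cohere_embedder = CloudEmbedding(api_key="<cohere-key>", provider="cohere")
passage_vecs = cohere_embedder.encode(
    ["Danswer is an open source enterprise search tool."],
    model_name=None,  # falls back to DEFAULT_COHERE_MODEL
    text_type=EmbedTextType.PASSAGE,
)
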
def get_embedding_model(
model_name: str,
max_context_length: int,
@ -78,18 +206,35 @@ def warm_up_cross_encoders() -> None:
@simple_log_function_time()
def embed_text(
texts: list[str],
model_name: str,
text_type: EmbedTextType,
model_name: str | None,
max_context_length: int,
normalize_embeddings: bool,
api_key: str | None,
provider_type: str | None,
) -> list[list[float]]:
model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)
if provider_type is not None:
if api_key is None:
raise RuntimeError("API key not provided for cloud model")
cloud_model = CloudEmbedding(
api_key=api_key, provider=provider_type, model=model_name
)
embeddings = cloud_model.encode(texts, model_name, text_type)
elif model_name is not None:
hosted_model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
embeddings = hosted_model.encode(
texts, normalize_embeddings=normalize_embeddings
)
if embeddings is None:
raise RuntimeError("Embeddings were not created")
if not isinstance(embeddings, list):
embeddings = embeddings.tolist()
return embeddings
@ -113,6 +258,9 @@ async def process_embed_request(
model_name=embed_request.model_name,
max_context_length=embed_request.max_context_length,
normalize_embeddings=embed_request.normalize_embeddings,
api_key=embed_request.api_key,
provider_type=embed_request.provider_type,
text_type=embed_request.text_type,
)
return EmbedResponse(embeddings=embeddings)
except Exception as e:


@ -7,3 +7,7 @@ tensorflow==2.15.0
torch==2.0.1
transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
openai==1.14.3
cohere==5.5.8
google-cloud-aiplatform==1.58.0


@ -1,12 +1,19 @@
from pydantic import BaseModel
from danswer.search.enums import EmbedTextType
class EmbedRequest(BaseModel):
# These already include any prefixes; the text is passed directly to the model
texts: list[str]
model_name: str
# May be None for cloud embedding model requests; error handling exists for the other cases
model_name: str | None
max_context_length: int
normalize_embeddings: bool
api_key: str | None
provider_type: str | None
text_type: EmbedTextType
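
An example payload sketch for the extended schema: the api_key / provider_type pair switches the model server to a cloud provider, and model_name may be None to use that provider's default (the key below is a placeholder).

cloud_request = EmbedRequest(
    texts=["query: what changed in this PR?"],
    model_name=None,
    max_context_length=512,
    normalize_embeddings=False,
    api_key="sk-...",  # placeholder
    provider_type="openai",
    text_type=EmbedTextType.QUERY,
)
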
class EmbedResponse(BaseModel):

web/public/Cohere.svg Normal file

@ -0,0 +1,30 @@
<svg version="1.1" id="Layer_1" xmlns:x="ns_extend;" xmlns:i="ns_ai;" xmlns:graph="ns_graphs;" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px" viewBox="0 0 75 75" style="enable-background:new 0 0 75 75;" xml:space="preserve">
<style type="text/css">
.st0{fill-rule:evenodd;clip-rule:evenodd;fill:#39594D;}
.st1{fill-rule:evenodd;clip-rule:evenodd;fill:#D18EE2;}
.st2{fill:#FF7759;}
</style>
<metadata>
<sfw xmlns="ns_sfw;">
<slices>
</slices>
<sliceSourceBounds bottomLeftOrigin="true" height="75" width="75" x="-347.6" y="0.5">
</sliceSourceBounds>
</sfw>
</metadata>
<g>
<g>
<g>
<path class="st0" d="M24.3,44.7c2,0,6-0.1,11.6-2.4c6.5-2.7,19.3-7.5,28.6-12.5c6.5-3.5,9.3-8.1,9.3-14.3C73.8,7,66.9,0,58.3,0
h-36C10,0,0,10,0,22.3S9.4,44.7,24.3,44.7z">
</path>
<path class="st1" d="M30.4,60c0-6,3.6-11.5,9.2-13.8l11.3-4.7C62.4,36.8,75,45.2,75,57.6C75,67.2,67.2,75,57.6,75l-12.3,0
C37.1,75,30.4,68.3,30.4,60z">
</path>
<path class="st2" d="M12.9,47.6L12.9,47.6C5.8,47.6,0,53.4,0,60.5v1.7C0,69.2,5.8,75,12.9,75h0c7.1,0,12.9-5.8,12.9-12.9v-1.7
C25.7,53.4,20,47.6,12.9,47.6z">
</path>
</g>
</g>
</g>
</svg>


web/public/Google.webp Normal file
Binary file not shown.


@ -1 +1 @@
<svg viewBox="0 0 320 320" xmlns="http://www.w3.org/2000/svg"><path d="m297.06 130.97c7.26-21.79 4.76-45.66-6.85-65.48-17.46-30.4-52.56-46.04-86.84-38.68-15.25-17.18-37.16-26.95-60.13-26.81-35.04-.08-66.13 22.48-76.91 55.82-22.51 4.61-41.94 18.7-53.31 38.67-17.59 30.32-13.58 68.54 9.92 94.54-7.26 21.79-4.76 45.66 6.85 65.48 17.46 30.4 52.56 46.04 86.84 38.68 15.24 17.18 37.16 26.95 60.13 26.8 35.06.09 66.16-22.49 76.94-55.86 22.51-4.61 41.94-18.7 53.31-38.67 17.57-30.32 13.55-68.51-9.94-94.51zm-120.28 168.11c-14.03.02-27.62-4.89-38.39-13.88.49-.26 1.34-.73 1.89-1.07l63.72-36.8c3.26-1.85 5.26-5.32 5.24-9.07v-89.83l26.93 15.55c.29.14.48.42.52.74v74.39c-.04 33.08-26.83 59.9-59.91 59.97zm-128.84-55.03c-7.03-12.14-9.56-26.37-7.15-40.18.47.28 1.3.79 1.89 1.13l63.72 36.8c3.23 1.89 7.23 1.89 10.47 0l77.79-44.92v31.1c.02.32-.13.63-.38.83l-64.41 37.19c-28.69 16.52-65.33 6.7-81.92-21.95zm-16.77-139.09c7-12.16 18.05-21.46 31.21-26.29 0 .55-.03 1.52-.03 2.2v73.61c-.02 3.74 1.98 7.21 5.23 9.06l77.79 44.91-26.93 15.55c-.27.18-.61.21-.91.08l-64.42-37.22c-28.63-16.58-38.45-53.21-21.95-81.89zm221.26 51.49-77.79-44.92 26.93-15.54c.27-.18.61-.21.91-.08l64.42 37.19c28.68 16.57 38.51 53.26 21.94 81.94-7.01 12.14-18.05 21.44-31.2 26.28v-75.81c.03-3.74-1.96-7.2-5.2-9.06zm26.8-40.34c-.47-.29-1.3-.79-1.89-1.13l-63.72-36.8c-3.23-1.89-7.23-1.89-10.47 0l-77.79 44.92v-31.1c-.02-.32.13-.63.38-.83l64.41-37.16c28.69-16.55 65.37-6.7 81.91 22 6.99 12.12 9.52 26.31 7.15 40.1zm-168.51 55.43-26.94-15.55c-.29-.14-.48-.42-.52-.74v-74.39c.02-33.12 26.89-59.96 60.01-59.94 14.01 0 27.57 4.92 38.34 13.88-.49.26-1.33.73-1.89 1.07l-63.72 36.8c-3.26 1.85-5.26 5.31-5.24 9.06l-.04 89.79zm14.63-31.54 34.65-20.01 34.65 20v40.01l-34.65 20-34.65-20z"/></svg>
<svg viewBox="0 0 320 320" xmlns="http://www.w3.org/2000/svg"><path d="m297.06 130.97c7.26-21.79 4.76-45.66-6.85-65.48-17.46-30.4-52.56-46.04-86.84-38.68-15.25-17.18-37.16-26.95-60.13-26.81-35.04-.08-66.13 22.48-76.91 55.82-22.51 4.61-41.94 18.7-53.31 38.67-17.59 30.32-13.58 68.54 9.92 94.54-7.26 21.79-4.76 45.66 6.85 65.48 17.46 30.4 52.56 46.04 86.84 38.68 15.24 17.18 37.16 26.95 60.13 26.8 35.06.09 66.16-22.49 76.94-55.86 22.51-4.61 41.94-18.7 53.31-38.67 17.57-30.32 13.55-68.51-9.94-94.51zm-120.28 168.11c-14.03.02-27.62-4.89-38.39-13.88.49-.26 1.34-.73 1.89-1.07l63.72-36.8c3.26-1.85 5.26-5.32 5.24-9.07v-89.83l26.93 15.55c.29.14.48.42.52.74v74.39c-.04 33.08-26.83 59.9-59.91 59.97zm-128.84-55.03c-7.03-12.14-9.56-26.37-7.15-40.18.47.28 1.3.79 1.89 1.13l63.72 36.8c3.23 1.89 7.23 1.89 10.47 0l77.79-44.92v31.1c.02.32-.13.63-.38.83l-64.41 37.19c-28.69 16.52-65.33 6.7-81.92-21.95zm-16.77-139.09c7-12.16 18.05-21.46 31.21-26.29 0 .55-.03 1.52-.03 2.2v73.61c-.02 3.74 1.98 7.21 5.23 9.06l77.79 44.91-26.93 15.55c-.27.18-.61.21-.91.08l-64.42-37.22c-28.63-16.58-38.45-53.21-21.95-81.89zm221.26 51.49-77.79-44.92 26.93-15.54c.27-.18.61-.21.91-.08l64.42 37.19c28.68 16.57 38.51 53.26 21.94 81.94-7.01 12.14-18.05 21.44-31.2 26.28v-75.81c.03-3.74-1.96-7.2-5.2-9.06zm26.8-40.34c-.47-.29-1.3-.79-1.89-1.13l-63.72-36.8c-3.23-1.89-7.23-1.89-10.47 0l-77.79 44.92v-31.1c-.02-.32.13-.63.38-.83l64.41-37.16c28.69-16.55 65.37-6.7 81.91 22 6.99 12.12 9.52 26.31 7.15 40.1zm-168.51 55.43-26.94-15.55c-.29-.14-.48-.42-.52-.74v-74.39c.02-33.12 26.89-59.96 60.01-59.94 14.01 0 27.57 4.92 38.34 13.88-.49.26-1.33.73-1.89 1.07l-63.72 36.8c-3.26 1.85-5.26 5.31-5.24 9.06l-.04 89.79zm14.63-31.54 34.65-20.01 34.65 20v40.01l-34.65 20-34.65-20z"/></svg>


web/public/Voyage.png Normal file
Binary file not shown.


@ -0,0 +1,163 @@
"use client";
import { Text, Title } from "@tremor/react";
import {
CloudEmbeddingProvider,
CloudEmbeddingModel,
AVAILABLE_CLOUD_PROVIDERS,
CloudEmbeddingProviderFull,
EmbeddingModelDescriptor,
} from "./components/types";
import { EmbeddingDetails } from "./page";
import { FiInfo } from "react-icons/fi";
import { HoverPopup } from "@/components/HoverPopup";
import { Dispatch, SetStateAction } from "react";
export default function CloudEmbeddingPage({
currentModel,
embeddingProviderDetails,
newEnabledProviders,
newUnenabledProviders,
setShowTentativeProvider,
setChangeCredentialsProvider,
setAlreadySelectedModel,
setShowTentativeModel,
setShowModelInQueue,
}: {
setShowModelInQueue: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
setShowTentativeModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
currentModel: EmbeddingModelDescriptor | CloudEmbeddingModel;
setAlreadySelectedModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
newUnenabledProviders: string[];
embeddingProviderDetails?: EmbeddingDetails[];
newEnabledProviders: string[];
selectedModel: CloudEmbeddingProvider;
// create modal functions
setShowTentativeProvider: React.Dispatch<
React.SetStateAction<CloudEmbeddingProvider | null>
>;
setChangeCredentialsProvider: React.Dispatch<
React.SetStateAction<CloudEmbeddingProvider | null>
>;
}) {
function hasNameInArray(
arr: Array<{ name: string }>,
searchName: string
): boolean {
return arr.some(
(item) => item.name.toLowerCase() === searchName.toLowerCase()
);
}
let providers: CloudEmbeddingProviderFull[] = [];
AVAILABLE_CLOUD_PROVIDERS.forEach((model, ind) => {
let temporary_model: CloudEmbeddingProviderFull = {
...model,
configured:
!newUnenabledProviders.includes(model.name) &&
(newEnabledProviders.includes(model.name) ||
(embeddingProviderDetails &&
hasNameInArray(embeddingProviderDetails, model.name))!),
};
providers.push(temporary_model);
});
return (
<div>
<Title className="mt-8">
Here are some cloud-based models to choose from.
</Title>
<Text className="mb-4">
They require API keys and run in the respective providers&apos; clouds.
</Text>
<div className="gap-4 mt-2 pb-10 flex content-start flex-wrap">
{providers.map((provider, ind) => (
<div
key={ind}
className="p-4 border border-border rounded-lg shadow-md bg-hover-light w-96 flex flex-col"
>
<div className="font-bold text-neutral-900 text-lg items-center py-1 gap-x-2 flex">
{provider.icon({ size: 40 })}
<p className="my-auto">{provider.name}</p>
<button
onClick={() => {
setShowTentativeProvider(provider);
}}
className="cursor-pointer ml-auto"
>
<a className="my-auto hover:underline cursor-pointer">
<HoverPopup
mainContent={
<FiInfo className="cusror-pointer" size={20} />
}
popupContent={
<div className="text-sm text-neutral-800 w-52 flex">
<div className="flex mx-auto">
<div className="my-auto">{provider.description}</div>
</div>
</div>
}
direction="left-top"
style="dark"
/>
</a>
</button>
</div>
<div>
{provider.embedding_models.map((model, index) => {
const enabled = model.model_name == currentModel.model_name;
return (
<div
key={index}
className={`p-3 my-2 border-2 border-neutral-300 border-opacity-40 rounded-md cursor-pointer
${!provider.configured ? "opacity-80 hover:opacity-100" : enabled ? "bg-background-stronger" : "hover:bg-background-strong"}`}
onClick={() => {
if (enabled) {
setAlreadySelectedModel(model);
} else if (provider.configured) {
setShowTentativeModel(model);
} else {
setShowModelInQueue(model);
setShowTentativeProvider(provider);
}
}}
>
<div className="flex justify-between">
<div className="font-medium text-sm">
{model.model_name}
</div>
<p className="text-sm flex-none">
${model.pricePerMillion}/M tokens
</p>
</div>
<div className="text-sm text-gray-600">
{model.description}
</div>
</div>
);
})}
</div>
<button
onClick={() => {
if (!provider.configured) {
setShowTentativeProvider(provider);
} else {
setChangeCredentialsProvider(provider);
}
}}
className="hover:underline mb-1 text-sm mr-auto cursor-pointer"
>
{provider.configured && "Modify credentials"}
</button>
</div>
))}
</div>
</div>
);
}


@ -1,74 +0,0 @@
import { Modal } from "@/components/Modal";
import { Button, Text, Callout } from "@tremor/react";
import { EmbeddingModelDescriptor } from "./embeddingModels";
export function ModelSelectionConfirmaion({
selectedModel,
isCustom,
onConfirm,
}: {
selectedModel: EmbeddingModelDescriptor;
isCustom: boolean;
onConfirm: () => void;
}) {
return (
<div className="mb-4">
<Text className="text-lg mb-4">
You have selected: <b>{selectedModel.model_name}</b>. Are you sure you
want to update to this new embedding model?
</Text>
<Text className="text-lg mb-2">
We will re-index all your documents in the background so you will be
able to continue to use Danswer as normal with the old model in the
meantime. Depending on how many documents you have indexed, this may
take a while.
</Text>
<Text className="text-lg mb-2">
<i>NOTE:</i> this re-indexing process will consume more resources than
normal. If you are self-hosting, we recommend that you allocate at least
16GB of RAM to Danswer during this process.
</Text>
{isCustom && (
<Callout title="IMPORTANT" color="yellow" className="mt-4">
We&apos;ve detected that this is a custom-specified embedding model.
Since we have to download the model files before verifying the
configuration&apos;s correctness, we won&apos;t be able to let you
know if the configuration is valid until <b>after</b> we start
re-indexing your documents. If there is an issue, it will show up on
this page as an indexing error on this page after clicking Confirm.
</Callout>
)}
<div className="flex mt-8">
<Button className="mx-auto" color="green" onClick={onConfirm}>
Confirm
</Button>
</div>
</div>
);
}
export function ModelSelectionConfirmaionModal({
selectedModel,
isCustom,
onConfirm,
onCancel,
}: {
selectedModel: EmbeddingModelDescriptor;
isCustom: boolean;
onConfirm: () => void;
onCancel: () => void;
}) {
return (
<Modal title="Update Embedding Model" onOutsideClick={onCancel}>
<div>
<ModelSelectionConfirmaion
selectedModel={selectedModel}
isCustom={isCustom}
onConfirm={onConfirm}
/>
</div>
</Modal>
);
}


@ -0,0 +1,55 @@
"use client";
import { Card, Text, Title } from "@tremor/react";
import { ModelSelector } from "./components/ModelSelector";
import {
AVAILABLE_MODELS,
EmbeddingModelDescriptor,
HostedEmbeddingModel,
} from "./components/types";
import { CustomModelForm } from "./components/CustomModelForm";
export default function OpenEmbeddingPage({
onSelectOpenSource,
currentModelName,
}: {
currentModelName: string;
onSelectOpenSource: (model: HostedEmbeddingModel) => Promise<void>;
}) {
return (
<div>
<ModelSelector
modelOptions={AVAILABLE_MODELS.filter(
(modelOption) => modelOption.model_name !== currentModelName
)}
setSelectedModel={onSelectOpenSource}
/>
<Text className="mt-6">
Alternatively, (if you know what you&apos;re doing) you can specify a{" "}
<a target="_blank" href="https://www.sbert.net/" className="text-link">
SentenceTransformers
</a>
-compatible model of your choice below. The rough list of supported
models can be found{" "}
<a
target="_blank"
href="https://huggingface.co/models?library=sentence-transformers&sort=trending"
className="text-link"
>
here
</a>
.
<br />
<b>NOTE:</b> not all models listed will work with Danswer, since some
have unique interfaces or special requirements. If in doubt, reach out
to the Danswer team.
</Text>
<div className="w-full flex">
<Card className="mt-4 2xl:w-4/6 mx-auto">
<CustomModelForm onSubmit={onSelectOpenSource} />
</Card>
</div>
</div>
);
}


@ -2,16 +2,15 @@ import {
BooleanFormField,
TextFormField,
} from "@/components/admin/connectors/Field";
import { Button, Divider, Text } from "@tremor/react";
import { Button } from "@tremor/react";
import { Form, Formik } from "formik";
import * as Yup from "yup";
import { EmbeddingModelDescriptor } from "./embeddingModels";
import { EmbeddingModelDescriptor, HostedEmbeddingModel } from "./types";
export function CustomModelForm({
onSubmit,
}: {
onSubmit: (model: EmbeddingModelDescriptor) => void;
onSubmit: (model: HostedEmbeddingModel) => void;
}) {
return (
<div>
@ -21,6 +20,7 @@ export function CustomModelForm({
model_dim: "",
query_prefix: "",
passage_prefix: "",
description: "",
normalize: true,
}}
validationSchema={Yup.object().shape({
@ -62,6 +62,13 @@ export function CustomModelForm({
}
}}
/>
<TextFormField
name="description"
label="Description:"
subtext="Description of your model"
placeholder=""
autoCompleteDisabled={true}
/>
<TextFormField
name="query_prefix"
@ -77,7 +84,6 @@ export function CustomModelForm({
placeholder="E.g. 'query: '"
autoCompleteDisabled={true}
/>
<TextFormField
name="passage_prefix"
label="[Optional] Passage Prefix:"


@ -1,18 +1,29 @@
import { DefaultDropdown, StringOrNumberOption } from "@/components/Dropdown";
import { Title, Text, Divider, Card } from "@tremor/react";
import {
EmbeddingModelDescriptor,
FullEmbeddingModelDescriptor,
} from "./embeddingModels";
import { EmbeddingModelDescriptor, HostedEmbeddingModel } from "./types";
import { FiStar } from "react-icons/fi";
import { CustomModelForm } from "./CustomModelForm";
export function ModelPreview({ model }: { model: EmbeddingModelDescriptor }) {
return (
<div
className={
"p-2 border border-border rounded shadow-md bg-hover-light w-96 flex flex-col"
}
>
<div className="font-bold text-lg flex">{model.model_name}</div>
<div className="text-sm mt-1 mx-1">
{model.description
? model.description
: "Custom model—no description is available."}
</div>
</div>
);
}
export function ModelOption({
model,
onSelect,
}: {
model: FullEmbeddingModelDescriptor;
onSelect?: (model: EmbeddingModelDescriptor) => void;
model: HostedEmbeddingModel;
onSelect?: (model: HostedEmbeddingModel) => void;
}) {
return (
<div
@ -68,8 +79,8 @@ export function ModelSelector({
modelOptions,
setSelectedModel,
}: {
modelOptions: FullEmbeddingModelDescriptor[];
setSelectedModel: (model: EmbeddingModelDescriptor) => void;
modelOptions: HostedEmbeddingModel[];
setSelectedModel: (model: HostedEmbeddingModel) => void;
}) {
return (
<div>


@ -0,0 +1,286 @@
import {
CohereIcon,
GoogleIcon,
IconProps,
OpenAIIcon,
VoyageIcon,
} from "@/components/icons/icons";
// Cloud Provider (not needed for hosted ones)
export interface CloudEmbeddingProvider {
id: number;
name: string;
api_key?: string;
custom_config?: Record<string, string>;
docsLink?: string;
// Frontend-specific properties
website: string;
icon: ({ size, className }: IconProps) => JSX.Element;
description: string;
apiLink: string;
costslink?: string;
// Relationships
embedding_models: CloudEmbeddingModel[];
default_model?: CloudEmbeddingModel;
}
// Embedding Models
export interface EmbeddingModelDescriptor {
model_name: string;
model_dim: number;
normalize: boolean;
query_prefix: string;
passage_prefix: string;
cloud_provider_name?: string | null;
description: string;
}
export interface CloudEmbeddingModel extends EmbeddingModelDescriptor {
cloud_provider_name: string | null;
pricePerMillion: number;
enabled?: boolean;
mtebScore: number;
maxContext: number;
}
export interface HostedEmbeddingModel extends EmbeddingModelDescriptor {
link?: string;
model_dim: number;
normalize: boolean;
query_prefix: string;
passage_prefix: string;
isDefault?: boolean;
}
// Responses
export interface FullEmbeddingModelResponse {
current_model_name: string;
secondary_model_name: string | null;
}
export interface CloudEmbeddingProviderFull extends CloudEmbeddingProvider {
configured: boolean;
}
export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
{
model_name: "intfloat/e5-base-v2",
model_dim: 768,
normalize: true,
description:
"The recommended default for most situations. If you aren't sure which model to use, this is probably the one.",
isDefault: true,
link: "https://huggingface.co/intfloat/e5-base-v2",
query_prefix: "query: ",
passage_prefix: "passage: ",
},
{
model_name: "intfloat/e5-small-v2",
model_dim: 384,
normalize: true,
description:
"A smaller / faster version of the default model. If you're running Danswer on a resource constrained system, then this is a good choice.",
link: "https://huggingface.co/intfloat/e5-small-v2",
query_prefix: "query: ",
passage_prefix: "passage: ",
},
{
model_name: "intfloat/multilingual-e5-base",
model_dim: 768,
normalize: true,
description:
"If you have many documents in other languages besides English, this is the one to go for.",
link: "https://huggingface.co/intfloat/multilingual-e5-base",
query_prefix: "query: ",
passage_prefix: "passage: ",
},
{
model_name: "intfloat/multilingual-e5-small",
model_dim: 384,
normalize: true,
description:
"If you have many documents in other languages besides English, and you're running on a resource constrained system, then this is the one to go for.",
link: "https://huggingface.co/intfloat/multilingual-e5-base",
query_prefix: "query: ",
passage_prefix: "passage: ",
},
];
export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
{
id: 0,
name: "OpenAI",
website: "https://openai.com",
icon: OpenAIIcon,
description: "AI industry leader known for ChatGPT and DALL-E",
apiLink: "https://platform.openai.com/api-keys",
docsLink:
"https://docs.danswer.dev/guides/embedding_providers#openai-models",
costslink: "https://openai.com/pricing",
embedding_models: [
{
model_name: "text-embedding-3-large",
cloud_provider_name: "OpenAI",
description:
"OpenAI's large embedding model. Best performance, but more expensive.",
pricePerMillion: 0.13,
model_dim: 3072,
normalize: false,
query_prefix: "",
passage_prefix: "",
mtebScore: 64.6,
maxContext: 8191,
enabled: false,
},
{
model_name: "text-embedding-3-small",
cloud_provider_name: "OpenAI",
model_dim: 1536,
normalize: false,
query_prefix: "",
passage_prefix: "",
description:
"OpenAI's newer, more efficient embedding model. Good balance of performance and cost.",
pricePerMillion: 0.02,
enabled: false,
mtebScore: 62.3,
maxContext: 8191,
},
],
},
{
id: 1,
name: "Cohere",
website: "https://cohere.ai",
icon: CohereIcon,
docsLink:
"https://docs.danswer.dev/guides/embedding_providers#cohere-models",
description:
"AI company specializing in NLP models for various text-based tasks",
apiLink: "https://dashboard.cohere.ai/api-keys",
costslink: "https://cohere.com/pricing",
embedding_models: [
{
model_name: "embed-english-v3.0",
cloud_provider_name: "Cohere",
description:
"Cohere's English embedding model. Good performance for English-language tasks.",
pricePerMillion: 0.1,
mtebScore: 64.5,
maxContext: 512,
enabled: false,
model_dim: 1024,
normalize: false,
query_prefix: "",
passage_prefix: "",
},
{
model_name: "embed-english-light-v3.0",
cloud_provider_name: "Cohere",
description:
"Cohere's lightweight English embedding model. Faster and more efficient for simpler tasks.",
pricePerMillion: 0.1,
mtebScore: 62,
maxContext: 512,
enabled: false,
model_dim: 384,
normalize: false,
query_prefix: "",
passage_prefix: "",
},
],
},
{
id: 2,
name: "Google",
website: "https://ai.google",
icon: GoogleIcon,
docsLink:
"https://docs.danswer.dev/guides/embedding_providers#vertex-ai-google-model",
description:
"Offers a wide range of AI services including language and vision models",
apiLink: "https://console.cloud.google.com/apis/credentials",
costslink: "https://cloud.google.com/vertex-ai/pricing",
embedding_models: [
{
cloud_provider_name: "Google",
model_name: "text-embedding-004",
description: "Google's most recent text embedding model.",
pricePerMillion: 0.025,
mtebScore: 66.31,
maxContext: 2048,
enabled: false,
model_dim: 768,
normalize: false,
query_prefix: "",
passage_prefix: "",
},
{
cloud_provider_name: "Google",
model_name: "textembedding-gecko@003",
description: "Google's Gecko embedding model. Powerful and efficient.",
pricePerMillion: 0.025,
mtebScore: 66.31,
maxContext: 2048,
enabled: false,
model_dim: 768,
normalize: false,
query_prefix: "",
passage_prefix: "",
},
],
},
{
id: 3,
name: "Voyage",
website: "https://www.voyageai.com",
icon: VoyageIcon,
description: "Advanced NLP research startup born from Stanford AI Labs",
docsLink:
"https://docs.danswer.dev/guides/embedding_providers#voyage-models",
apiLink: "https://www.voyageai.com/dashboard",
costslink: "https://www.voyageai.com/pricing",
embedding_models: [
{
cloud_provider_name: "Voyage",
model_name: "voyage-large-2-instruct",
description:
"Voyage's large embedding model. High performance with instruction fine-tuning.",
pricePerMillion: 0.12,
mtebScore: 68.28,
maxContext: 4000,
enabled: false,
model_dim: 1024,
normalize: false,
query_prefix: "",
passage_prefix: "",
},
{
cloud_provider_name: "Voyage",
model_name: "voyage-light-2-instruct",
description:
"Voyage's lightweight embedding model. Good balance of performance and efficiency.",
pricePerMillion: 0.12,
mtebScore: 67.13,
maxContext: 16000,
enabled: false,
model_dim: 1024,
normalize: false,
query_prefix: "",
passage_prefix: "",
},
],
},
];
export const INVALID_OLD_MODEL = "thenlper/gte-small";
export function checkModelNameIsValid(
modelName: string | undefined | null
): boolean {
return !!modelName && modelName !== INVALID_OLD_MODEL;
}


@ -1,87 +0,0 @@
export interface EmbeddingModelResponse {
model_name: string | null;
}
export interface FullEmbeddingModelResponse {
current_model_name: string;
secondary_model_name: string | null;
}
export interface EmbeddingModelDescriptor {
model_name: string;
model_dim: number;
normalize: boolean;
query_prefix?: string;
passage_prefix?: string;
}
export interface FullEmbeddingModelDescriptor extends EmbeddingModelDescriptor {
description: string;
isDefault?: boolean;
link?: string;
}
export const AVAILABLE_MODELS: FullEmbeddingModelDescriptor[] = [
{
model_name: "intfloat/e5-base-v2",
model_dim: 768,
normalize: true,
description:
"The recommended default for most situations. If you aren't sure which model to use, this is probably the one.",
isDefault: true,
link: "https://huggingface.co/intfloat/e5-base-v2",
query_prefix: "query: ",
passage_prefix: "passage: ",
},
{
model_name: "intfloat/e5-small-v2",
model_dim: 384,
normalize: true,
description:
"A smaller / faster version of the default model. If you're running Danswer on a resource constrained system, then this is a good choice.",
link: "https://huggingface.co/intfloat/e5-small-v2",
query_prefix: "query: ",
passage_prefix: "passage: ",
},
{
model_name: "intfloat/multilingual-e5-base",
model_dim: 768,
normalize: true,
description:
"If you have many documents in other languages besides English, this is the one to go for.",
link: "https://huggingface.co/intfloat/multilingual-e5-base",
query_prefix: "query: ",
passage_prefix: "passage: ",
},
{
model_name: "intfloat/multilingual-e5-small",
model_dim: 384,
normalize: true,
description:
"If you have many documents in other languages besides English, and you're running on a resource constrained system, then this is the one to go for.",
link: "https://huggingface.co/intfloat/multilingual-e5-base",
query_prefix: "query: ",
passage_prefix: "passage: ",
},
];
export const INVALID_OLD_MODEL = "thenlper/gte-small";
export function checkModelNameIsValid(modelName: string | undefined | null) {
if (!modelName) {
return false;
}
if (modelName === INVALID_OLD_MODEL) {
return false;
}
return true;
}
export function fillOutEmeddingModelDescriptor(
embeddingModel: EmbeddingModelDescriptor | FullEmbeddingModelDescriptor
): FullEmbeddingModelDescriptor {
return {
...embeddingModel,
description: "",
};
}


@ -0,0 +1,31 @@
import React from "react";
import { Modal } from "@/components/Modal";
import { Button, Text } from "@tremor/react";
import { CloudEmbeddingModel } from "../components/types";
export function AlreadyPickedModal({
model,
onClose,
}: {
model: CloudEmbeddingModel;
onClose: () => void;
}) {
return (
<Modal
title={`${model.model_name} already chosen`}
onOutsideClick={onClose}
>
<div className="mb-4">
<Text className="text-sm mb-2">
You can select a different one if you want!
</Text>
<div className="flex mt-8 justify-between">
<Button color="blue" onClick={onClose}>
Close
</Button>
</div>
</div>
</Modal>
);
}

View File

@ -0,0 +1,244 @@
import React, { useRef, useState } from "react";
import { Modal } from "@/components/Modal";
import { Button, Text, Callout, Subtitle, Divider } from "@tremor/react";
import { Label } from "@/components/admin/connectors/Field";
import { CloudEmbeddingProvider } from "../components/types";
import {
EMBEDDING_PROVIDERS_ADMIN_URL,
LLM_PROVIDERS_ADMIN_URL,
} from "../../llm/constants";
import { mutate } from "swr";
export function ChangeCredentialsModal({
provider,
onConfirm,
onCancel,
onDeleted,
useFileUpload,
}: {
provider: CloudEmbeddingProvider;
onConfirm: () => void;
onCancel: () => void;
onDeleted: () => void;
useFileUpload: boolean;
}) {
const [apiKey, setApiKey] = useState("");
const [testError, setTestError] = useState<string>("");
const [fileName, setFileName] = useState<string>("");
const fileInputRef = useRef<HTMLInputElement>(null);
const [isProcessing, setIsProcessing] = useState(false);
const [deletionError, setDeletionError] = useState<string>("");
const clearFileInput = () => {
setFileName("");
if (fileInputRef.current) {
fileInputRef.current.value = "";
}
};
const handleFileUpload = async (
event: React.ChangeEvent<HTMLInputElement>
) => {
const file = event.target.files?.[0];
setFileName("");
if (file) {
setFileName(file.name);
try {
setDeletionError("");
const fileContent = await file.text();
let jsonContent;
try {
jsonContent = JSON.parse(fileContent);
setApiKey(JSON.stringify(jsonContent));
} catch (parseError) {
throw new Error(
"Failed to parse JSON file. Please ensure it's a valid JSON."
);
}
} catch (error) {
setTestError(
error instanceof Error
? error.message
: "An unknown error occurred while processing the file."
);
setApiKey("");
clearFileInput();
}
}
};
const handleDelete = async () => {
setDeletionError("");
setIsProcessing(true);
try {
const response = await fetch(
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.name}`,
{
method: "DELETE",
}
);
if (!response.ok) {
const errorData = await response.json();
setDeletionError(errorData.detail);
return;
}
mutate(LLM_PROVIDERS_ADMIN_URL);
onDeleted();
} catch (error) {
setDeletionError(
error instanceof Error ? error.message : "An unknown error occurred"
);
} finally {
setIsProcessing(false);
}
};
const handleSubmit = async () => {
setTestError("");
try {
const testResponse = await fetch("/api/admin/embedding/test-embedding", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider: provider.name.toLowerCase().split(" ")[0],
api_key: apiKey,
}),
});
if (!testResponse.ok) {
const errorMsg = (await testResponse.json()).detail;
throw new Error(errorMsg);
}
const updateResponse = await fetch(EMBEDDING_PROVIDERS_ADMIN_URL, {
method: "PUT",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
name: provider.name,
api_key: apiKey,
is_default_provider: false,
is_configured: true,
}),
});
if (!updateResponse.ok) {
const errorData = await updateResponse.json();
throw new Error(
errorData.detail || "Failed to update provider- check your API key"
);
}
onConfirm();
} catch (error) {
setTestError(
error instanceof Error ? error.message : "An unknown error occurred"
);
}
};
return (
<Modal
icon={provider.icon}
title={`Modify your ${provider.name} key`}
onOutsideClick={onCancel}
>
<div className="mb-4">
<Subtitle className="mt-4 font-bold text-lg mb-2">
Want to swap out your key?
</Subtitle>
<div className="flex flex-col gap-y-2">
{useFileUpload ? (
<>
<Label>Upload JSON File</Label>
<input
ref={fileInputRef}
type="file"
accept=".json"
onChange={handleFileUpload}
className="text-lg w-full p-1"
/>
{fileName && <p>Uploaded file: {fileName}</p>}
</>
) : (
<>
<div className="flex gap-x-2 items-center">
<Label>New API Key</Label>
</div>
<input
className={`
border
border-border
rounded
w-full
py-2
px-3
mt-1
bg-background-emphasis
`}
value={apiKey}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => setApiKey(e.target.value)}
placeholder="Paste your API key here"
/>
</>
)}
<a
href={provider.apiLink}
target="_blank"
rel="noopener noreferrer"
className="underline cursor-pointer"
>
Visit API
</a>
</div>
{testError && (
<Callout title="Error" color="red" className="mt-4">
{testError}
</Callout>
)}
<div className="flex mt-8 justify-between">
<Button
color="blue"
onClick={() => handleSubmit()}
disabled={!apiKey}
>
Execute Key Swap
</Button>
</div>
<Divider />
<Subtitle className="mt-4 font-bold text-lg mb-2">
You can also delete your key.
</Subtitle>
<Text className="mb-2">
This is only possible if you have already switched to a different
embedding type!
</Text>
<Button onClick={handleDelete} color="red">
Delete key
</Button>
{deletionError && (
<Callout title="Error" color="red" className="mt-4">
{deletionError}
</Callout>
)}
</div>
</Modal>
);
}

View File

@ -0,0 +1,41 @@
import React from "react";
import { Modal } from "@/components/Modal";
import { Button, Text, Callout } from "@tremor/react";
import { CloudEmbeddingProvider } from "../components/types";
export function DeleteCredentialsModal({
modelProvider,
onConfirm,
onCancel,
}: {
modelProvider: CloudEmbeddingProvider;
onConfirm: () => void;
onCancel: () => void;
}) {
return (
<Modal
title={`Nuke ${modelProvider.name} Credentials?`}
onOutsideClick={onCancel}
>
<div className="mb-4">
<Text className="text-lg mb-2">
You&apos;re about to delete your {modelProvider.name} credentials. Are
you sure?
</Text>
<Callout
title="Point of No Return"
color="red"
className="mt-4"
/>
<div className="flex mt-8 justify-between">
<Button color="gray" onClick={onCancel}>
Keep Credentials
</Button>
<Button color="red" onClick={onConfirm}>
Delete Credentials
</Button>
</div>
</div>
</Modal>
);
}

View File

@ -0,0 +1,61 @@
import { Modal } from "@/components/Modal";
import { Button, Text, Callout } from "@tremor/react";
import {
EmbeddingModelDescriptor,
HostedEmbeddingModel,
} from "../components/types";
export function ModelSelectionConfirmationModal({
selectedModel,
isCustom,
onConfirm,
onCancel,
}: {
selectedModel: HostedEmbeddingModel;
isCustom: boolean;
onConfirm: () => void;
onCancel: () => void;
}) {
return (
<Modal title="Update Embedding Model" onOutsideClick={onCancel}>
<div>
<div className="mb-4">
<Text className="text-lg mb-4">
You have selected: <b>{selectedModel.model_name}</b>. Are you sure
you want to update to this new embedding model?
</Text>
<Text className="text-lg mb-2">
We will re-index all your documents in the background, so you can
continue to use Danswer as normal with the old model in the
meantime. Depending on how many documents you have indexed, this may
take a while.
</Text>
<Text className="text-lg mb-2">
<i>NOTE:</i> this re-indexing process will consume more resources
than normal. If you are self-hosting, we recommend that you allocate
at least 16GB of RAM to Danswer during this process.
</Text>
{isCustom && (
<Callout title="IMPORTANT" color="yellow" className="mt-4">
We&apos;ve detected that this is a custom-specified embedding
model. Since we have to download the model files before verifying
the configuration&apos;s correctness, we won&apos;t be able to let
you know if the configuration is valid until <b>after</b> we start
re-indexing your documents. If there is an issue, it will show up
on this page as an indexing error on this page after clicking
Confirm.
</Callout>
)}
<div className="flex mt-8">
<Button className="mx-auto" color="green" onClick={onConfirm}>
Confirm
</Button>
</div>
</div>
</div>
</Modal>
);
}

View File

@ -0,0 +1,232 @@
import React, { useRef, useState } from "react";
import { Text, Button, Callout } from "@tremor/react";
import { Formik, Form } from "formik";
import * as Yup from "yup";
import { Label, TextFormField } from "@/components/admin/connectors/Field";
import { LoadingAnimation } from "@/components/Loading";
import { CloudEmbeddingProvider } from "../components/types";
import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../../llm/constants";
import { Modal } from "@/components/Modal";
export function ProviderCreationModal({
selectedProvider,
onConfirm,
onCancel,
existingProvider,
}: {
selectedProvider: CloudEmbeddingProvider;
onConfirm: () => void;
onCancel: () => void;
existingProvider?: CloudEmbeddingProvider;
}) {
const useFileUpload = selectedProvider.name === "Google";
const [isProcessing, setIsProcessing] = useState(false);
const [errorMsg, setErrorMsg] = useState<string>("");
const [fileName, setFileName] = useState<string>("");
const initialValues = {
name: existingProvider?.name || selectedProvider.name,
api_key: existingProvider?.api_key || "",
custom_config: existingProvider?.custom_config
? Object.entries(existingProvider.custom_config)
: [],
default_model_name: "",
model_id: 0,
};
const validationSchema = Yup.object({
name: Yup.string().required("Name is required"),
api_key: useFileUpload
? Yup.string()
: Yup.string().required("API Key is required"),
custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)),
});
const fileInputRef = useRef<HTMLInputElement>(null);
const handleFileUpload = async (
event: React.ChangeEvent<HTMLInputElement>,
setFieldValue: (field: string, value: any) => void
) => {
const file = event.target.files?.[0];
setFileName("");
if (file) {
setFileName(file.name);
try {
const fileContent = await file.text();
let jsonContent;
try {
jsonContent = JSON.parse(fileContent);
} catch (parseError) {
throw new Error(
"Failed to parse JSON file. Please ensure it's a valid JSON."
);
}
setFieldValue("api_key", JSON.stringify(jsonContent));
} catch (error) {
  setErrorMsg(
    error instanceof Error
      ? error.message
      : "An unknown error occurred while processing the file."
  );
  setFieldValue("api_key", "");
}
}
};
const handleSubmit = async (
values: any,
{ setSubmitting }: { setSubmitting: (isSubmitting: boolean) => void }
) => {
setIsProcessing(true);
setErrorMsg("");
try {
const customConfig = Object.fromEntries(values.custom_config);
const initialResponse = await fetch(
"/api/admin/embedding/test-embedding",
{
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider: values.name.toLowerCase().split(" ")[0],
api_key: values.api_key,
}),
}
);
if (!initialResponse.ok) {
const errorMsg = (await initialResponse.json()).detail;
setErrorMsg(errorMsg);
setIsProcessing(false);
setSubmitting(false);
return;
}
const response = await fetch(EMBEDDING_PROVIDERS_ADMIN_URL, {
method: "PUT",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
...values,
custom_config: customConfig,
is_default_provider: false,
is_configured: true,
}),
});
if (!response.ok) {
const errorData = await response.json();
throw new Error(
errorData.detail || "Failed to update provider- check your API key"
);
}
onConfirm();
} catch (error: unknown) {
if (error instanceof Error) {
setErrorMsg(error.message);
} else {
setErrorMsg("An unknown error occurred");
}
} finally {
setIsProcessing(false);
setSubmitting(false);
}
};
return (
<Modal
title={`Configure ${selectedProvider.name}`}
onOutsideClick={onCancel}
icon={selectedProvider.icon}
>
<div>
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
onSubmit={handleSubmit}
>
{({
values,
errors,
touched,
isSubmitting,
handleSubmit,
setFieldValue,
}) => (
<Form onSubmit={handleSubmit} className="space-y-4">
<Text className="text-lg mb-2">
You are setting the credentials for this provider. To access
this information, follow the instructions{" "}
<a
className="cursor-pointer underline"
target="_blank"
href={selectedProvider.docsLink}
>
here
</a>{" "}
and gather your{" "}
<a
className="cursor-pointer underline"
target="_blank"
href={selectedProvider.apiLink}
>
API KEY
</a>
</Text>
<div className="flex flex-col gap-y-2">
{useFileUpload ? (
<>
<Label>Upload JSON File</Label>
<input
ref={fileInputRef}
type="file"
accept=".json"
onChange={(e) => handleFileUpload(e, setFieldValue)}
className="text-lg w-full p-1"
/>
{fileName && <p>Uploaded file: {fileName}</p>}
</>
) : (
<TextFormField
name="api_key"
label="API Key"
placeholder="API Key"
type="password"
/>
)}
<a
href={selectedProvider.apiLink}
target="_blank"
className="underline cursor-pointer"
>
Learn more here
</a>
</div>
{errorMsg && (
<Callout title="Error" color="red">
{errorMsg}
</Callout>
)}
<Button
type="submit"
color="blue"
className="w-full"
disabled={isSubmitting}
>
{isProcessing ? (
<LoadingAnimation />
) : existingProvider ? (
"Update"
) : (
"Create"
)}
</Button>
</Form>
)}
</Formik>
</div>
</Modal>
);
}
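For readers skimming the diff, the following standalone helper is a hedged sketch (not part of this commit) of the two-step flow ProviderCreationModal performs: validate the key against the new test endpoint, then upsert the provider record. Endpoint paths and payload fields are copied from the component above.

// Illustrative only -- mirrors the modal's submit logic in plain async form.
async function configureCloudEmbeddingProvider(
  name: string,
  apiKey: string
): Promise<void> {
  // Step 1: confirm the key actually works before persisting anything.
  const test = await fetch("/api/admin/embedding/test-embedding", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      provider: name.toLowerCase().split(" ")[0],
      api_key: apiKey,
    }),
  });
  if (!test.ok) {
    throw new Error((await test.json()).detail);
  }

  // Step 2: create or update the provider record.
  const upsert = await fetch("/api/admin/embedding/embedding-provider", {
    method: "PUT",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      name,
      api_key: apiKey,
      is_default_provider: false,
      is_configured: true,
    }),
  });
  if (!upsert.ok) {
    throw new Error(
      (await upsert.json()).detail || "Failed to update provider"
    );
  }
}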

View File

@ -0,0 +1,37 @@
import React from "react";
import { Modal } from "@/components/Modal";
import { Button, Text, Callout } from "@tremor/react";
import { CloudEmbeddingModel } from "../components/types";
export function SelectModelModal({
model,
onConfirm,
onCancel,
}: {
model: CloudEmbeddingModel;
onConfirm: () => void;
onCancel: () => void;
}) {
return (
<Modal
title={`Elevate Your Game with ${model.model_name}`}
onOutsideClick={onCancel}
>
<div className="mb-4">
<Text className="text-lg mb-2">
You&apos;re about to set your embedding model to {model.model_name}.
<br />
Are you sure?
</Text>
<div className="flex mt-8 justify-between">
<Button color="gray" onClick={onCancel}>
Exit
</Button>
<Button color="green" onClick={onConfirm}>
Continue
</Button>
</div>
</div>
</Modal>
);
}

View File

@ -3,28 +3,75 @@
import { ThreeDotsLoader } from "@/components/Loading";
import { AdminPageTitle } from "@/components/admin/Title";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { Button, Card, Text, Title } from "@tremor/react";
import { Button, Text, Title } from "@tremor/react";
import { FiPackage } from "react-icons/fi";
import useSWR, { mutate } from "swr";
import { ModelOption, ModelSelector } from "./ModelSelector";
import { ModelOption, ModelPreview } from "./components/ModelSelector";
import { useState } from "react";
import { ModelSelectionConfirmaionModal } from "./ModelSelectionConfirmation";
import { ReindexingProgressTable } from "./ReindexingProgressTable";
import { ReindexingProgressTable } from "./components/ReindexingProgressTable";
import { Modal } from "@/components/Modal";
import {
CloudEmbeddingProvider,
CloudEmbeddingModel,
AVAILABLE_CLOUD_PROVIDERS,
AVAILABLE_MODELS,
EmbeddingModelDescriptor,
INVALID_OLD_MODEL,
fillOutEmeddingModelDescriptor,
} from "./embeddingModels";
HostedEmbeddingModel,
EmbeddingModelDescriptor,
} from "./components/types";
import { ErrorCallout } from "@/components/ErrorCallout";
import { Connector, ConnectorIndexingStatus } from "@/lib/types";
import Link from "next/link";
import { CustomModelForm } from "./CustomModelForm";
import OpenEmbeddingPage from "./OpenEmbeddingPage";
import CloudEmbeddingPage from "./CloudEmbeddingPage";
import { ProviderCreationModal } from "./modals/ProviderCreationModal";
import { DeleteCredentialsModal } from "./modals/DeleteCredentialsModal";
import { SelectModelModal } from "./modals/SelectModelModal";
import { ChangeCredentialsModal } from "./modals/ChangeCredentialsModal";
import { ModelSelectionConfirmationModal } from "./modals/ModelSelectionModal";
import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../llm/constants";
import { AlreadyPickedModal } from "./modals/AlreadyPickedModal";
export interface EmbeddingDetails {
api_key: string;
custom_config: any;
default_model_id?: number;
name: string;
}
function Main() {
const [tentativeNewEmbeddingModel, setTentativeNewEmbeddingModel] =
useState<EmbeddingModelDescriptor | null>(null);
const [openToggle, setOpenToggle] = useState(true);
// Cloud Provider based modals
const [showTentativeProvider, setShowTentativeProvider] =
useState<CloudEmbeddingProvider | null>(null);
const [showUnconfiguredProvider, setShowUnconfiguredProvider] =
useState<CloudEmbeddingProvider | null>(null);
const [changeCredentialsProvider, setChangeCredentialsProvider] =
useState<CloudEmbeddingProvider | null>(null);
// Cloud Model based modals
const [alreadySelectedModel, setAlreadySelectedModel] =
useState<CloudEmbeddingModel | null>(null);
const [showTentativeModel, setShowTentativeModel] =
useState<CloudEmbeddingModel | null>(null);
const [showModelInQueue, setShowModelInQueue] =
useState<CloudEmbeddingModel | null>(null);
// Open Model based modals
const [showTentativeOpenProvider, setShowTentativeOpenProvider] =
useState<HostedEmbeddingModel | null>(null);
// Enabled / unenabled providers
const [newEnabledProviders, setNewEnabledProviders] = useState<string[]>([]);
const [newUnenabledProviders, setNewUnenabledProviders] = useState<string[]>(
[]
);
const [showDeleteCredentialsModal, setShowDeleteCredentialsModal] =
useState<boolean>(false);
const [isCancelling, setIsCancelling] = useState<boolean>(false);
const [showAddConnectorPopup, setShowAddConnectorPopup] =
useState<boolean>(false);
@ -33,16 +80,22 @@ function Main() {
data: currentEmeddingModel,
isLoading: isLoadingCurrentModel,
error: currentEmeddingModelError,
} = useSWR<EmbeddingModelDescriptor>(
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
"/api/secondary-index/get-current-embedding-model",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
const { data: embeddingProviderDetails } = useSWR<EmbeddingDetails[]>(
EMBEDDING_PROVIDERS_ADMIN_URL,
errorHandlingFetcher
);
const {
data: futureEmbeddingModel,
isLoading: isLoadingFutureModel,
error: futureEmeddingModelError,
} = useSWR<EmbeddingModelDescriptor | null>(
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
"/api/secondary-index/get-secondary-embedding-model",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
@ -61,27 +114,41 @@ function Main() {
{ refreshInterval: 5000 } // 5 seconds
);
const onSelect = async (model: EmbeddingModelDescriptor) => {
if (currentEmeddingModel?.model_name === INVALID_OLD_MODEL) {
await onConfirm(model);
} else {
setTentativeNewEmbeddingModel(model);
}
};
const onConfirm = async (
model: CloudEmbeddingModel | HostedEmbeddingModel
) => {
let newModel: EmbeddingModelDescriptor;
if ("cloud_provider_name" in model) {
// This is a CloudEmbeddingModel
newModel = {
...model,
model_name: model.model_name,
cloud_provider_name: model.cloud_provider_name,
// cloud_provider_id: model.cloud_provider_id || 0,
};
} else {
// This is a HostedEmbeddingModel
newModel = {
...model,
model_name: model.model_name!,
description: "",
cloud_provider_name: null,
};
}
const onConfirm = async (model: EmbeddingModelDescriptor) => {
const response = await fetch(
"/api/secondary-index/set-new-embedding-model",
{
method: "POST",
body: JSON.stringify(model),
body: JSON.stringify(newModel),
headers: {
"Content-Type": "application/json",
},
}
);
if (response.ok) {
setTentativeNewEmbeddingModel(null);
setShowTentativeModel(null);
mutate("/api/secondary-index/get-secondary-embedding-model");
if (!connectors || !connectors.length) {
setShowAddConnectorPopup(true);
@ -96,14 +163,13 @@ function Main() {
method: "POST",
});
if (response.ok) {
setTentativeNewEmbeddingModel(null);
setShowTentativeModel(null);
mutate("/api/secondary-index/get-secondary-embedding-model");
} else {
alert(
`Failed to cancel embedding model update - ${await response.text()}`
);
}
setIsCancelling(false);
};
@ -119,216 +185,328 @@ function Main() {
return <ErrorCallout errorTitle="Failed to fetch embedding model status" />;
}
const currentModelName = currentEmeddingModel.model_name;
const currentModel =
AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) ||
fillOutEmeddingModelDescriptor(currentEmeddingModel);
const onConfirmSelection = async (model: EmbeddingModelDescriptor) => {
const response = await fetch(
"/api/secondary-index/set-new-embedding-model",
{
method: "POST",
body: JSON.stringify(model),
headers: {
"Content-Type": "application/json",
},
}
);
if (response.ok) {
setShowTentativeModel(null);
mutate("/api/secondary-index/get-secondary-embedding-model");
if (!connectors || !connectors.length) {
setShowAddConnectorPopup(true);
}
} else {
alert(`Failed to update embedding model - ${await response.text()}`);
}
};
const newModelSelection = futureEmbeddingModel
? AVAILABLE_MODELS.find(
(model) => model.model_name === futureEmbeddingModel.model_name
) || fillOutEmeddingModelDescriptor(futureEmbeddingModel)
: null;
const currentModelName = currentEmeddingModel?.model_name;
const AVAILABLE_CLOUD_PROVIDERS_FLATTENED = AVAILABLE_CLOUD_PROVIDERS.flatMap(
(provider) =>
provider.embedding_models.map((model) => ({
...model,
cloud_provider_id: provider.id,
model_name: model.model_name, // Ensure model_name is set for consistency
}))
);
const currentModel: CloudEmbeddingModel | HostedEmbeddingModel =
AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) ||
AVAILABLE_CLOUD_PROVIDERS_FLATTENED.find(
(model) => model.model_name === currentEmeddingModel.model_name
)!;
// ||
// fillOutEmeddingModelDescriptor(currentEmeddingModel);
const onSelectOpenSource = async (model: HostedEmbeddingModel) => {
if (currentEmeddingModel?.model_name === INVALID_OLD_MODEL) {
await onConfirmSelection(model);
} else {
setShowTentativeOpenProvider(model);
}
};
const selectedModel = AVAILABLE_CLOUD_PROVIDERS[0];
const clientsideAddProvider = (provider: CloudEmbeddingProvider) => {
const providerName = provider.name;
setNewEnabledProviders((newEnabledProviders) => [
...newEnabledProviders,
providerName,
]);
setNewUnenabledProviders((newUnenabledProviders) =>
newUnenabledProviders.filter(
(givenProviderName) => givenProviderName !== providerName
)
);
};
const clientsideRemoveProvider = (provider: CloudEmbeddingProvider) => {
const providerName = provider.name;
setNewEnabledProviders((newEnabledProviders) =>
newEnabledProviders.filter(
(givenProviderName) => givenProviderName !== providerName
)
);
setNewUnenabledProviders((newUnenabledProviders) => [
...newUnenabledProviders,
providerName,
]);
};
return (
<div>
{tentativeNewEmbeddingModel && (
<ModelSelectionConfirmaionModal
selectedModel={tentativeNewEmbeddingModel}
isCustom={
AVAILABLE_MODELS.find(
(model) =>
model.model_name === tentativeNewEmbeddingModel.model_name
) === undefined
}
onConfirm={() => onConfirm(tentativeNewEmbeddingModel)}
onCancel={() => setTentativeNewEmbeddingModel(null)}
/>
)}
{showAddConnectorPopup && (
<Modal>
<div>
<div>
<b className="text-base">Embeding model successfully selected</b>{" "}
🙌
<br />
<br />
To complete the initial setup, let&apos;s add a connector!
<br />
<br />
Connectors are the way that Danswer gets data from your
organization&apos;s various data sources. Once setup, we&apos;ll
automatically sync data from your apps and docs into Danswer, so
you can search all through all of them in one place.
</div>
<div className="flex">
<Link className="mx-auto mt-2 w-fit" href="/admin/add-connector">
<Button className="mt-3 mx-auto" size="xs">
Add Connector
</Button>
</Link>
</div>
</div>
</Modal>
)}
{isCancelling && (
<Modal
onOutsideClick={() => setIsCancelling(false)}
title="Cancel Embedding Model Switch"
>
<div>
<div>
Are you sure you want to cancel?
<br />
<br />
Cancelling will revert to the previous model and all progress will
be lost.
</div>
<div className="flex">
<Button onClick={onCancel} className="mt-3 mx-auto" color="green">
Confirm
</Button>
</div>
</div>
</Modal>
)}
<div className="h-screen">
<Text>
Embedding models are used to generate embeddings for your documents,
which then power Danswer&apos;s search.
</Text>
{alreadySelectedModel && (
<AlreadyPickedModal
model={alreadySelectedModel}
onClose={() => setAlreadySelectedModel(null)}
/>
)}
{showTentativeOpenProvider && (
<ModelSelectionConfirmationModal
selectedModel={showTentativeOpenProvider}
isCustom={
AVAILABLE_MODELS.find(
(model) =>
model.model_name === showTentativeOpenProvider.model_name
) === undefined
}
onConfirm={() => onConfirm(showTentativeOpenProvider)}
onCancel={() => setShowTentativeOpenProvider(null)}
/>
)}
{showTentativeProvider && (
<ProviderCreationModal
selectedProvider={showTentativeProvider}
onConfirm={() => {
setShowTentativeProvider(showUnconfiguredProvider);
clientsideAddProvider(showTentativeProvider);
if (showModelInQueue) {
setShowTentativeModel(showModelInQueue);
}
}}
onCancel={() => {
setShowModelInQueue(null);
setShowTentativeProvider(null);
}}
/>
)}
{changeCredentialsProvider && (
<ChangeCredentialsModal
// setPopup={setPopup}
useFileUpload={changeCredentialsProvider.name === "Google"}
onDeleted={() => {
clientsideRemoveProvider(changeCredentialsProvider);
setChangeCredentialsProvider(null);
}}
provider={changeCredentialsProvider}
onConfirm={() => setChangeCredentialsProvider(null)}
onCancel={() => setChangeCredentialsProvider(null)}
/>
)}
{showTentativeModel && (
<SelectModelModal
model={showTentativeModel}
onConfirm={() => {
setShowModelInQueue(null);
onConfirm(showTentativeModel);
}}
onCancel={() => {
setShowModelInQueue(null);
setShowTentativeModel(null);
}}
/>
)}
{showDeleteCredentialsModal && (
<DeleteCredentialsModal
modelProvider={showTentativeProvider!}
onConfirm={() => {
setShowDeleteCredentialsModal(false);
}}
onCancel={() => setShowDeleteCredentialsModal(false)}
/>
)}
{currentModel ? (
<>
<Title className="mt-8 mb-2">Current Embedding Model</Title>
<Text>
<ModelOption model={currentModel} />
<ModelPreview model={currentModel} />
</Text>
</>
) : (
newModelSelection &&
(!connectors || !connectors.length) && (
<>
<Title className="mt-8 mb-2">Current Embedding Model</Title>
<Title className="mt-8 mb-4">Choose your Embedding Model</Title>
)}
<Text>
<ModelOption model={newModelSelection} />
</Text>
</>
)
{!(futureEmbeddingModel && connectors && connectors.length > 0) && (
<>
<Title className="mt-8">Switch your Embedding Model</Title>
<Text className="mb-4">
If the current model is not working for you, you can update your
model choice below. Note that this will require a complete
re-indexing of all your documents across every connected source. We
will take care of this in the background, but depending on the size
of your corpus, this could take hours, days, or even weeks. You can
monitor the progress of the re-indexing on this page.
</Text>
<div className="mt-8 text-sm mr-auto mb-12 divide-x-2 flex ">
<button
onClick={() => setOpenToggle(true)}
className={` mx-2 p-2 font-bold ${openToggle ? "rounded bg-neutral-900 text-neutral-100 underline" : "hover:underline"}`}
>
Self-hosted
</button>
<div className="px-2 ">
<button
onClick={() => setOpenToggle(false)}
className={`mx-2 p-2 font-bold ${!openToggle ? "rounded bg-neutral-900 text-neutral-100 underline" : " hover:underline"}`}
>
Cloud-based
</button>
</div>
</div>
</>
)}
{!showAddConnectorPopup &&
(!newModelSelection ? (
<div>
{currentModel ? (
<>
<Title className="mt-8">Switch your Embedding Model</Title>
<Text className="mb-4">
If the current model is not working for you, you can update
your model choice below. Note that this will require a
complete re-indexing of all your documents across every
connected source. We will take care of this in the background,
but depending on the size of your corpus, this could take
hours, day, or even weeks. You can monitor the progress of the
re-indexing on this page.
</Text>
</>
) : (
<>
<Title className="mt-8 mb-4">Choose your Embedding Model</Title>
</>
)}
<Text className="mb-4">
Below are a curated selection of quality models that we recommend
you choose from.
</Text>
<ModelSelector
modelOptions={AVAILABLE_MODELS.filter(
(modelOption) => modelOption.model_name !== currentModelName
)}
setSelectedModel={onSelect}
/>
<Text className="mt-6">
Alternatively, (if you know what you&apos;re doing) you can
specify a{" "}
<a
target="_blank"
href="https://www.sbert.net/"
className="text-link"
>
SentenceTransformers
</a>
-compatible model of your choice below. The rough list of
supported models can be found{" "}
<a
target="_blank"
href="https://huggingface.co/models?library=sentence-transformers&sort=trending"
className="text-link"
>
here
</a>
.
<br />
<b>NOTE:</b> not all models listed will work with Danswer, since
some have unique interfaces or special requirements. If in doubt,
reach out to the Danswer team.
</Text>
<div className="w-full flex">
<Card className="mt-4 2xl:w-4/6 mx-auto">
<CustomModelForm onSubmit={onSelect} />
</Card>
</div>
</div>
!futureEmbeddingModel &&
(openToggle ? (
<OpenEmbeddingPage
onSelectOpenSource={onSelectOpenSource}
currentModelName={currentModelName!}
/>
) : (
connectors &&
connectors.length > 0 && (
<div>
<Title className="mt-8">Current Upgrade Status</Title>
<div className="mt-4">
<div className="italic text-sm mb-2">
Currently in the process of switching to:
</div>
<ModelOption model={newModelSelection} />
<Button
color="red"
size="xs"
className="mt-4"
onClick={() => setIsCancelling(true)}
>
Cancel
</Button>
<Text className="my-4">
The table below shows the re-indexing progress of all existing
connectors. Once all connectors have been re-indexed
successfully, the new model will be used for all search
queries. Until then, we will use the old model so that no
downtime is necessary during this transition.
</Text>
{isLoadingOngoingReIndexingStatus ? (
<ThreeDotsLoader />
) : ongoingReIndexingStatus ? (
<ReindexingProgressTable
reindexingProgress={ongoingReIndexingStatus}
/>
) : (
<ErrorCallout errorTitle="Failed to fetch re-indexing progress" />
)}
</div>
</div>
)
<CloudEmbeddingPage
setShowModelInQueue={setShowModelInQueue}
setShowTentativeModel={setShowTentativeModel}
currentModel={currentModel}
setAlreadySelectedModel={setAlreadySelectedModel}
embeddingProviderDetails={embeddingProviderDetails}
newEnabledProviders={newEnabledProviders}
newUnenabledProviders={newUnenabledProviders}
setShowTentativeProvider={setShowTentativeProvider}
selectedModel={selectedModel}
setChangeCredentialsProvider={setChangeCredentialsProvider}
/>
))}
{openToggle && (
<>
{showAddConnectorPopup && (
<Modal>
<div>
<div>
<b className="text-base">
Embedding model successfully selected
</b>{" "}
🙌
<br />
<br />
To complete the initial setup, let&apos;s add a connector!
<br />
<br />
Connectors are the way that Danswer gets data from your
organization&apos;s various data sources. Once set up,
we&apos;ll automatically sync data from your apps and docs
into Danswer, so you can search through all of them in one
place.
</div>
<div className="flex">
<Link
className="mx-auto mt-2 w-fit"
href="/admin/add-connector"
>
<Button className="mt-3 mx-auto" size="xs">
Add Connector
</Button>
</Link>
</div>
</div>
</Modal>
)}
{isCancelling && (
<Modal
onOutsideClick={() => setIsCancelling(false)}
title="Cancel Embedding Model Switch"
>
<div>
<div>
Are you sure you want to cancel?
<br />
<br />
Cancelling will revert to the previous model and all progress
will be lost.
</div>
<div className="flex">
<Button
onClick={onCancel}
className="mt-3 mx-auto"
color="green"
>
Confirm
</Button>
</div>
</div>
</Modal>
)}
</>
)}
{futureEmbeddingModel && connectors && connectors.length > 0 && (
<div>
<Title className="mt-8">Current Upgrade Status</Title>
<div className="mt-4">
<div className="italic text-lg mb-2">
Currently in the process of switching to:{" "}
{futureEmbeddingModel.model_name}
</div>
{/* <ModelOption model={futureEmbeddingModel} /> */}
<Button
color="red"
size="xs"
className="mt-4"
onClick={() => setIsCancelling(true)}
>
Cancel
</Button>
<Text className="my-4">
The table below shows the re-indexing progress of all existing
connectors. Once all connectors have been re-indexed successfully,
the new model will be used for all search queries. Until then, we
will use the old model so that no downtime is necessary during
this transition.
</Text>
{isLoadingOngoingReIndexingStatus ? (
<ThreeDotsLoader />
) : ongoingReIndexingStatus ? (
<ReindexingProgressTable
reindexingProgress={ongoingReIndexingStatus}
/>
) : (
<ErrorCallout errorTitle="Failed to fetch re-indexing progress" />
)}
</div>
</div>
)}
</div>
);
}

View File

@ -1 +1,4 @@
export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
export const EMBEDDING_PROVIDERS_ADMIN_URL =
"/api/admin/embedding/embedding-provider";

View File

@ -20,7 +20,7 @@ import {
import { unstable_noStore as noStore } from "next/cache";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
import { personaComparator } from "../admin/assistants/lib";
import { FullEmbeddingModelResponse } from "../admin/models/embedding/embeddingModels";
import { FullEmbeddingModelResponse } from "../admin/models/embedding/components/types";
import { NoSourcesModal } from "@/components/initialSetup/search/NoSourcesModal";
import { NoCompleteSourcesModal } from "@/components/initialSetup/search/NoCompleteSourceModal";
import { ChatPopup } from "../chat/ChatPopup";

View File

@ -1,7 +1,9 @@
import { Divider } from "@tremor/react";
import { FiX } from "react-icons/fi";
import { IconProps } from "./icons/icons";
interface ModalProps {
icon?: ({ size, className }: IconProps) => JSX.Element;
children: JSX.Element | string;
title?: JSX.Element | string;
onOutsideClick?: () => void;
@ -13,6 +15,7 @@ interface ModalProps {
}
export function Modal({
icon,
children,
title,
onOutsideClick,
@ -44,10 +47,15 @@ export function Modal({
<>
<div className="flex mb-4">
<h2
className={"my-auto font-bold " + (titleSize || "text-2xl")}
className={
"my-auto flex content-start gap-x-4 font-bold " +
(titleSize || "text-2xl")
}
>
{title}
{icon && icon({ size: 30 })}
</h2>
{onOutsideClick && (
<div
onClick={onOutsideClick}

View File

@ -72,11 +72,16 @@ import sharepointIcon from "../../../public/Sharepoint.png";
import teamsIcon from "../../../public/Teams.png";
import mediawikiIcon from "../../../public/MediaWiki.svg";
import wikipediaIcon from "../../../public/Wikipedia.svg";
import discourseIcon from "../../../public/Discourse.png";
import clickupIcon from "../../../public/Clickup.svg";
import cohereIcon from "../../../public/Cohere.svg";
import voyageIcon from "../../../public/Voyage.png";
import googleIcon from "../../../public/Google.webp";
import { FaRobot } from "react-icons/fa";
interface IconProps {
export interface IconProps {
size?: number;
className?: string;
}
@ -84,20 +89,6 @@ interface IconProps {
export const defaultTailwindCSS = "my-auto flex flex-shrink-0 text-default";
export const defaultTailwindCSSBlue = "my-auto flex flex-shrink-0 text-link";
export const OpenAIIcon = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<div
style={{ width: `${size + 4}px`, height: `${size + 4}px` }}
className={`w-[${size + 4}px] h-[${size + 4}px] -m-0.5 ` + className}
>
<Image src={openAISVG} alt="Logo" width="96" height="96" />
</div>
);
};
export const OpenSourceIcon = ({
size = 16,
className = defaultTailwindCSS,
@ -528,6 +519,62 @@ export const ZulipIcon = ({
);
};
export const OpenAIIcon = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<div
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
>
<Image src={openAISVG} alt="Logo" width="96" height="96" />
</div>
);
};
export const VoyageIcon = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<div
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
>
<Image src={voyageIcon} alt="Logo" width="96" height="96" />
</div>
);
};
export const GoogleIcon = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<div
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
>
<Image src={googleIcon} alt="Logo" width="96" height="96" />
</div>
);
};
export const CohereIcon = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<div
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
>
<Image src={cohereIcon} alt="Logo" width="96" height="96" />
</div>
);
};
export const GoogleStorageIcon = ({
size = 16,
className = defaultTailwindCSS,

View File

@ -13,7 +13,7 @@ import {
} from "@/lib/types";
import { ChatSession } from "@/app/chat/interfaces";
import { Persona } from "@/app/admin/assistants/interfaces";
import { FullEmbeddingModelResponse } from "@/app/admin/models/embedding/embeddingModels";
import { FullEmbeddingModelResponse } from "@/app/admin/models/embedding/components/types";
import { Settings } from "@/app/admin/settings/interfaces";
import { fetchLLMProvidersSS } from "@/lib/llm/fetchLLMs";
import { LLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";