diff --git a/backend/alembic/versions/44f856ae2a4a_add_cloud_embedding_model.py b/backend/alembic/versions/44f856ae2a4a_add_cloud_embedding_model.py new file mode 100644 index 000000000..2d0e1a32f --- /dev/null +++ b/backend/alembic/versions/44f856ae2a4a_add_cloud_embedding_model.py @@ -0,0 +1,65 @@ +"""add cloud embedding model and update embedding_model + +Revision ID: 44f856ae2a4a +Revises: d716b0791ddd +Create Date: 2024-06-28 20:01:05.927647 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "44f856ae2a4a" +down_revision = "d716b0791ddd" +branch_labels: None = None +depends_on: None = None + + +def upgrade() -> None: + # Create embedding_provider table (default_model_id gets its foreign key further down) + op.create_table( + "embedding_provider", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("api_key", sa.LargeBinary(), nullable=True), + sa.Column("default_model_id", sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name"), + ) + + # Add cloud_provider_id to embedding_model table + op.add_column( + "embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True) + ) + + # Add foreign key constraints (the two tables reference each other, so both FKs are created only after the table and the column above exist) + op.create_foreign_key( + "fk_embedding_model_cloud_provider", + "embedding_model", + "embedding_provider", + ["cloud_provider_id"], + ["id"], + ) + op.create_foreign_key( + "fk_embedding_provider_default_model", + "embedding_provider", + "embedding_model", + ["default_model_id"], + ["id"], + ) + + +def downgrade() -> None: + # Remove foreign key constraints first; the steps below undo upgrade() in reverse order + op.drop_constraint( + "fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey" + ) + op.drop_constraint( + "fk_embedding_provider_default_model", "embedding_provider", type_="foreignkey" + ) + + # Remove cloud_provider_id column + op.drop_column("embedding_model", "cloud_provider_id") + + # Drop embedding_provider table + op.drop_table("embedding_provider") diff --git 
a/backend/alembic/versions/d716b0791ddd_combined_slack_id_fields.py b/backend/alembic/versions/d716b0791ddd_combined_slack_id_fields.py index 3f13d7c55..6510d8b39 100644 --- a/backend/alembic/versions/d716b0791ddd_combined_slack_id_fields.py +++ b/backend/alembic/versions/d716b0791ddd_combined_slack_id_fields.py @@ -10,8 +10,8 @@ from alembic import op # revision identifiers, used by Alembic. revision = "d716b0791ddd" down_revision = "7aea705850d5" -branch_labels = None -depends_on = None +branch_labels: None = None +depends_on: None = None def upgrade() -> None: diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index fa684f020..dd529cbf5 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -98,7 +98,6 @@ def _run_indexing( 3. Updates Postgres to record the indexed documents + the outcome of this run """ start_time = time.time() - db_embedding_model = index_attempt.embedding_model index_name = db_embedding_model.index_name @@ -116,6 +115,8 @@ def _run_indexing( normalize=db_embedding_model.normalize, query_prefix=db_embedding_model.query_prefix, passage_prefix=db_embedding_model.passage_prefix, + api_key=db_embedding_model.api_key, + provider_type=db_embedding_model.provider_type, ) indexing_pipeline = build_indexing_pipeline( @@ -287,6 +288,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA db_session=db_session, index_attempt_id=index_attempt_id, ) + if attempt is None: raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'") diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 9ca65f8b3..fe9ed6008 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -343,13 +343,15 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non # So that the first time users 
aren't surprised by really slow speed of first # batch of documents indexed - logger.info("Running a first inference to warm up embedding model") - warm_up_encoders( - model_name=db_embedding_model.model_name, - normalize=db_embedding_model.normalize, - model_server_host=INDEXING_MODEL_SERVER_HOST, - model_server_port=MODEL_SERVER_PORT, - ) + + if db_embedding_model.cloud_provider_id is None: + logger.info("Running a first inference to warm up embedding model") + warm_up_encoders( + model_name=db_embedding_model.model_name, + normalize=db_embedding_model.normalize, + model_server_host=INDEXING_MODEL_SERVER_HOST, + model_server_port=MODEL_SERVER_PORT, + ) client_primary: Client | SimpleJobClient client_secondary: Client | SimpleJobClient diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py index 4f6df7654..4ec876407 100644 --- a/backend/danswer/danswerbot/slack/listener.py +++ b/backend/danswer/danswerbot/slack/listener.py @@ -469,13 +469,13 @@ if __name__ == "__main__": # or the tokens have updated (set up for the first time) with Session(get_sqlalchemy_engine()) as db_session: embedding_model = get_current_db_embedding_model(db_session) - - warm_up_encoders( - model_name=embedding_model.model_name, - normalize=embedding_model.normalize, - model_server_host=MODEL_SERVER_HOST, - model_server_port=MODEL_SERVER_PORT, - ) + if embedding_model.cloud_provider_id is None: + warm_up_encoders( + model_name=embedding_model.model_name, + normalize=embedding_model.normalize, + model_server_host=MODEL_SERVER_HOST, + model_server_port=MODEL_SERVER_PORT, + ) slack_bot_tokens = latest_slack_bot_tokens # potentially may cause a message to be dropped, but it is complicated diff --git a/backend/danswer/db/embedding_model.py b/backend/danswer/db/embedding_model.py index ae2b98d51..9388a3e3c 100644 --- a/backend/danswer/db/embedding_model.py +++ b/backend/danswer/db/embedding_model.py @@ -10,10 +10,15 @@ from 
danswer.configs.model_configs import NORMALIZE_EMBEDDINGS from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS +from danswer.db.llm import fetch_embedding_provider +from danswer.db.models import CloudEmbeddingProvider from danswer.db.models import EmbeddingModel from danswer.db.models import IndexModelStatus from danswer.indexing.models import EmbeddingModelDetail from danswer.search.search_nlp_models import clean_model_name +from danswer.server.manage.embedding.models import ( + CloudEmbeddingProvider as ServerCloudEmbeddingProvider, +) from danswer.utils.logger import setup_logger logger = setup_logger() @@ -31,6 +36,7 @@ def create_embedding_model( query_prefix=model_details.query_prefix, passage_prefix=model_details.passage_prefix, status=status, + cloud_provider_id=model_details.cloud_provider_id, # Every single embedding model except the initial one from migrations has this name # The initial one from migration is called "danswer_chunk" index_name=f"danswer_chunk_{clean_model_name(model_details.model_name)}", @@ -42,6 +48,42 @@ def create_embedding_model( return embedding_model +def get_model_id_from_name( + db_session: Session, embedding_provider_name: str +) -> int | None: # NOTE: the id returned is the CloudEmbeddingProvider row's id, or None if no row has this name + query = select(CloudEmbeddingProvider).where( + CloudEmbeddingProvider.name == embedding_provider_name + ) + provider = db_session.execute(query).scalars().first() + return provider.id if provider else None + + +def get_current_db_embedding_provider( + db_session: Session, +) -> ServerCloudEmbeddingProvider | None: # None when the current embedding model has no cloud_provider_id + current_embedding_model = EmbeddingModelDetail.from_model( + get_current_db_embedding_model(db_session=db_session) + ) + + if ( + current_embedding_model is None # NOTE(review): from_model() always builds an instance, so this arm looks unreachable + or current_embedding_model.cloud_provider_id is None + ): + return None + + embedding_provider = fetch_embedding_provider( + 
db_session=db_session, provider_id=current_embedding_model.cloud_provider_id + ) + if embedding_provider is None: + raise RuntimeError("No embedding provider exists for this model.") # cloud_provider_id is set but no provider row has that id + + current_embedding_provider = ServerCloudEmbeddingProvider.from_request( + cloud_provider_model=embedding_provider + ) + + return current_embedding_provider + + def get_current_db_embedding_model(db_session: Session) -> EmbeddingModel: query = ( select(EmbeddingModel) diff --git a/backend/danswer/db/llm.py b/backend/danswer/db/llm.py index f969dbf68..76c94ee16 100644 --- a/backend/danswer/db/llm.py +++ b/backend/danswer/db/llm.py @@ -2,11 +2,34 @@ from sqlalchemy import delete from sqlalchemy import select from sqlalchemy.orm import Session +from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel from danswer.db.models import LLMProvider as LLMProviderModel +from danswer.server.manage.embedding.models import CloudEmbeddingProvider +from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest from danswer.server.manage.llm.models import FullLLMProvider from danswer.server.manage.llm.models import LLMProviderUpsertRequest +def upsert_cloud_embedding_provider( + db_session: Session, provider: CloudEmbeddingProviderCreationRequest +) -> CloudEmbeddingProvider: # updates the row with this name if one exists, otherwise inserts; commits either way + existing_provider = ( + db_session.query(CloudEmbeddingProviderModel) + .filter_by(name=provider.name) + .first() + ) + if existing_provider: + for key, value in provider.dict().items(): # copy every request field onto the existing row + setattr(existing_provider, key, value) + else: + new_provider = CloudEmbeddingProviderModel(**provider.dict()) + db_session.add(new_provider) + existing_provider = new_provider + db_session.commit() + db_session.refresh(existing_provider) + return CloudEmbeddingProvider.from_request(existing_provider) + + def upsert_llm_provider( db_session: Session, llm_provider: LLMProviderUpsertRequest ) -> FullLLMProvider: @@ -26,7 +49,6 @@ def upsert_llm_provider( existing_llm_provider.model_names = 
llm_provider.model_names db_session.commit() return FullLLMProvider.from_model(existing_llm_provider) - # if it does not exist, create a new entry llm_provider_model = LLMProviderModel( name=llm_provider.name, @@ -46,10 +68,26 @@ def upsert_llm_provider( return FullLLMProvider.from_model(llm_provider_model) +def fetch_existing_embedding_providers( + db_session: Session, +) -> list[CloudEmbeddingProviderModel]: # every embedding_provider row, as ORM objects + return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all()) + + def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]: return list(db_session.scalars(select(LLMProviderModel)).all()) +def fetch_embedding_provider( + db_session: Session, provider_id: int +) -> CloudEmbeddingProviderModel | None: # None when no provider row has this id + return db_session.scalar( + select(CloudEmbeddingProviderModel).where( + CloudEmbeddingProviderModel.id == provider_id + ) + ) + + def fetch_default_provider(db_session: Session) -> FullLLMProvider | None: provider_model = db_session.scalar( select(LLMProviderModel).where( @@ -70,6 +108,17 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | return FullLLMProvider.from_model(provider_model) +def remove_embedding_provider( + db_session: Session, embedding_provider_name: str +) -> None: # deletes the provider row with this name (no-op if none matches) and commits + db_session.execute( + delete(CloudEmbeddingProviderModel).where( + CloudEmbeddingProviderModel.name == embedding_provider_name + ) + ) + db_session.commit() # persist the delete here, as the upsert helper does for writes; the calling route does not commit + + def remove_llm_provider(db_session: Session, provider_id: int) -> None: db_session.execute( delete(LLMProviderModel).where(LLMProviderModel.id == provider_id) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 3d33596a1..66a68a473 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -130,6 +130,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base): chat_folders: Mapped[list["ChatFolder"]] = relationship( "ChatFolder", back_populates="user" ) + prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user") # 
Personas owned by this user personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user") @@ -469,7 +470,7 @@ class Credential(Base): class EmbeddingModel(Base): __tablename__ = "embedding_model" - # ID is used also to indicate the order that the models are configured by the admin + id: Mapped[int] = mapped_column(primary_key=True) model_name: Mapped[str] = mapped_column(String) model_dim: Mapped[int] = mapped_column(Integer) @@ -481,6 +482,16 @@ class EmbeddingModel(Base): ) index_name: Mapped[str] = mapped_column(String) + # New field for cloud provider relationship (NULL when the model is not tied to a cloud provider) + cloud_provider_id: Mapped[int | None] = mapped_column( + ForeignKey("embedding_provider.id") + ) + cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship( + "CloudEmbeddingProvider", + back_populates="embedding_models", + foreign_keys=[cloud_provider_id], + ) + index_attempts: Mapped[list["IndexAttempt"]] = relationship( "IndexAttempt", back_populates="embedding_model" ) @@ -500,6 +511,18 @@ class EmbeddingModel(Base): ), ) + def __repr__(self) -> str: + return f"<EmbeddingModel(model_name='{self.model_name}', cloud_provider_id={self.cloud_provider_id})>" + + @property + def api_key(self) -> str | None: # the linked provider's API key, or None without a provider + return self.cloud_provider.api_key if self.cloud_provider else None + + @property + def provider_type(self) -> str | None: # the linked provider's name doubles as the provider type + return self.cloud_provider.name if self.cloud_provider else None + class IndexAttempt(Base): """ @@ -519,6 +542,7 @@ class IndexAttempt(Base): ForeignKey("credential.id"), nullable=True, ) + # Some index attempts that run from beginning will still have this as False # This is only for attempts that are explicitly marked as from the start via # the run once API @@ -879,11 +903,6 @@ class ChatMessageFeedback(Base): ) -""" -Structures, Organizational, Configurations Tables -""" - - class LLMProvider(Base): __tablename__ = "llm_provider" @@ -912,6 +931,29 @@ class LLMProvider(Base): is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True) +class CloudEmbeddingProvider(Base): + __tablename__ = "embedding_provider" + + 
id: Mapped[int] = mapped_column(Integer, primary_key=True) + name: Mapped[str] = mapped_column(String, unique=True) + api_key: Mapped[str | None] = mapped_column(EncryptedString()) + default_model_id: Mapped[int | None] = mapped_column( + Integer, ForeignKey("embedding_model.id"), nullable=True + ) + + embedding_models: Mapped[list["EmbeddingModel"]] = relationship( + "EmbeddingModel", + back_populates="cloud_provider", + foreign_keys="EmbeddingModel.cloud_provider_id", + ) + default_model: Mapped["EmbeddingModel"] = relationship( + "EmbeddingModel", foreign_keys=[default_model_id] + ) + + def __repr__(self) -> str: + return f"<EmbeddingProvider(name='{self.name}')>" + + class DocumentSet(Base): __tablename__ = "document_set" @@ -1194,6 +1236,7 @@ class SlackBotConfig(Base): response_type: Mapped[SlackBotResponseType] = mapped_column( Enum(SlackBotResponseType, native_enum=False), nullable=False ) + enable_auto_filters: Mapped[bool] = mapped_column( Boolean, nullable=False, default=False ) diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py index 5e6f7a729..542ff783e 100644 --- a/backend/danswer/indexing/embedder.py +++ b/backend/danswer/indexing/embedder.py @@ -50,6 +50,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder): normalize: bool, query_prefix: str | None, passage_prefix: str | None, + api_key: str | None = None, + provider_type: str | None = None, ): super().__init__(model_name, normalize, query_prefix, passage_prefix) self.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE # Currently not customizable @@ -59,6 +61,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder): query_prefix=query_prefix, passage_prefix=passage_prefix, normalize=normalize, + api_key=api_key, + provider_type=provider_type, # The below are globally set, this flow always uses the indexing one server_host=INDEXING_MODEL_SERVER_HOST, server_port=INDEXING_MODEL_SERVER_PORT, diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py index 2667665d7..ee91997a3 100644 --- 
a/backend/danswer/indexing/models.py +++ b/backend/danswer/indexing/models.py @@ -97,13 +97,19 @@ class EmbeddingModelDetail(BaseModel): normalize: bool query_prefix: str | None passage_prefix: str | None + cloud_provider_id: int | None = None + cloud_provider_name: str | None = None @classmethod - def from_model(cls, embedding_model: "EmbeddingModel") -> "EmbeddingModelDetail": + def from_model( + cls, + embedding_model: "EmbeddingModel", + ) -> "EmbeddingModelDetail": return cls( model_name=embedding_model.model_name, model_dim=embedding_model.model_dim, normalize=embedding_model.normalize, query_prefix=embedding_model.query_prefix, passage_prefix=embedding_model.passage_prefix, + cloud_provider_id=embedding_model.cloud_provider_id, ) diff --git a/backend/danswer/main.py b/backend/danswer/main.py index c41d1eb29..3c6b31a72 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -67,6 +67,8 @@ from danswer.server.features.tool.api import admin_router as admin_tool_router from danswer.server.features.tool.api import router as tool_router from danswer.server.gpts.api import router as gpts_router from danswer.server.manage.administrative import router as admin_router +from danswer.server.manage.embedding.api import admin_router as embedding_admin_router +from danswer.server.manage.embedding.api import basic_router as embedding_router from danswer.server.manage.get_state import router as state_router from danswer.server.manage.llm.api import admin_router as llm_admin_router from danswer.server.manage.llm.api import basic_router as llm_router @@ -247,12 +249,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: time.sleep(wait_time) logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}") - warm_up_encoders( - model_name=db_embedding_model.model_name, - normalize=db_embedding_model.normalize, - model_server_host=MODEL_SERVER_HOST, - model_server_port=MODEL_SERVER_PORT, - ) + if db_embedding_model.cloud_provider_id is None: + 
warm_up_encoders( + model_name=db_embedding_model.model_name, + normalize=db_embedding_model.normalize, + model_server_host=MODEL_SERVER_HOST, + model_server_port=MODEL_SERVER_PORT, + ) optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__}) yield @@ -291,6 +294,8 @@ def get_application() -> FastAPI: include_router_with_global_prefix_prepended(application, settings_admin_router) include_router_with_global_prefix_prepended(application, llm_admin_router) include_router_with_global_prefix_prepended(application, llm_router) + include_router_with_global_prefix_prepended(application, embedding_admin_router) + include_router_with_global_prefix_prepended(application, embedding_router) include_router_with_global_prefix_prepended( application, token_rate_limit_settings_router ) diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 08953b1b7..8767bc03a 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -168,6 +168,7 @@ def stream_answer_objects( max_tokens=max_document_tokens, use_sections=query_req.chunks_above > 0 or query_req.chunks_below > 0, ) + search_tool = SearchTool( db_session=db_session, user=user, diff --git a/backend/danswer/search/retrieval/search_runner.py b/backend/danswer/search/retrieval/search_runner.py index 8f52f9f59..1298f8e54 100644 --- a/backend/danswer/search/retrieval/search_runner.py +++ b/backend/danswer/search/retrieval/search_runner.py @@ -131,6 +131,8 @@ def doc_index_retrieval( query_prefix=db_embedding_model.query_prefix, passage_prefix=db_embedding_model.passage_prefix, normalize=db_embedding_model.normalize, + api_key=db_embedding_model.api_key, + provider_type=db_embedding_model.provider_type, # The below are globally set, this flow always uses the indexing one server_host=MODEL_SERVER_HOST, server_port=MODEL_SERVER_PORT, diff --git 
a/backend/danswer/search/search_nlp_models.py b/backend/danswer/search/search_nlp_models.py index 761d9aa79..a88e82f2d 100644 --- a/backend/danswer/search/search_nlp_models.py +++ b/backend/danswer/search/search_nlp_models.py @@ -84,20 +84,24 @@ def build_model_server_url( class EmbeddingModel: def __init__( self, - model_name: str, - query_prefix: str | None, - passage_prefix: str | None, - normalize: bool, server_host: str, # Changes depending on indexing or inference server_port: int, + model_name: str | None, + normalize: bool, + query_prefix: str | None, + passage_prefix: str | None, + api_key: str | None, + provider_type: str | None, # The following are globals are currently not configurable max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE, ) -> None: - self.model_name = model_name + self.api_key = api_key + self.provider_type = provider_type self.max_seq_length = max_seq_length self.query_prefix = query_prefix self.passage_prefix = passage_prefix self.normalize = normalize + self.model_name = model_name model_server_url = build_model_server_url(server_host, server_port) self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed" @@ -111,10 +115,13 @@ class EmbeddingModel: prefixed_texts = texts embed_request = EmbedRequest( - texts=prefixed_texts, model_name=self.model_name, + texts=prefixed_texts, max_context_length=self.max_seq_length, normalize_embeddings=self.normalize, + api_key=self.api_key, + provider_type=self.provider_type, + text_type=text_type, ) response = requests.post(self.embed_server_endpoint, json=embed_request.dict()) @@ -187,6 +194,8 @@ def warm_up_encoders( passage_prefix=None, server_host=model_server_host, server_port=model_server_port, + api_key=None, + provider_type=None, ) # First time downloading the models it may take even longer, but just in case, diff --git a/backend/danswer/server/manage/embedding/api.py b/backend/danswer/server/manage/embedding/api.py new file mode 100644 index 000000000..5d35fda06 --- /dev/null 
+++ b/backend/danswer/server/manage/embedding/api.py @@ -0,0 +1,93 @@ +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from sqlalchemy.orm import Session + +from danswer.auth.users import current_admin_user +from danswer.db.embedding_model import get_current_db_embedding_provider +from danswer.db.engine import get_session +from danswer.db.llm import fetch_existing_embedding_providers +from danswer.db.llm import remove_embedding_provider +from danswer.db.llm import upsert_cloud_embedding_provider +from danswer.db.models import User +from danswer.search.enums import EmbedTextType +from danswer.search.search_nlp_models import EmbeddingModel +from danswer.server.manage.embedding.models import CloudEmbeddingProvider +from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest +from danswer.server.manage.embedding.models import TestEmbeddingRequest +from danswer.utils.logger import setup_logger +from shared_configs.configs import MODEL_SERVER_HOST +from shared_configs.configs import MODEL_SERVER_PORT + +logger = setup_logger() + + +admin_router = APIRouter(prefix="/admin/embedding") +basic_router = APIRouter(prefix="/embedding") + + +@admin_router.post("/test-embedding") +def test_embedding_configuration( + test_llm_request: TestEmbeddingRequest, + _: User | None = Depends(current_admin_user), +) -> None: # embeds one test string with the submitted provider and key; any failure is reported as HTTP 400 + try: + test_model = EmbeddingModel( + server_host=MODEL_SERVER_HOST, + server_port=MODEL_SERVER_PORT, + api_key=test_llm_request.api_key, + provider_type=test_llm_request.provider, + normalize=False, + query_prefix=None, + passage_prefix=None, + model_name=None, + ) + test_model.encode(["Test String"], text_type=EmbedTextType.QUERY) + + except ValueError as e: + error_msg = f"Not a valid embedding model. Exception thrown: {e}" + logger.error(error_msg) + raise HTTPException(status_code=400, detail=error_msg) from e # a bare ValueError would escape as an unhandled 500 + + except Exception as e: + error_msg = "An error occurred while testing your embedding model. 
Please check your configuration." + logger.error(f"{error_msg} Error message: {e}", exc_info=True) + raise HTTPException(status_code=400, detail=error_msg) + + +@admin_router.get("/embedding-provider") +def list_embedding_providers( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[CloudEmbeddingProvider]: + return [ + CloudEmbeddingProvider.from_request(embedding_provider_model) + for embedding_provider_model in fetch_existing_embedding_providers(db_session) + ] + + +@admin_router.delete("/embedding-provider/{embedding_provider_name}") +def delete_embedding_provider( + embedding_provider_name: str, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + embedding_provider = get_current_db_embedding_provider(db_session=db_session) + if ( + embedding_provider is not None + and embedding_provider_name == embedding_provider.name + ): + raise HTTPException( + status_code=400, detail="You can't delete a currently active model" + ) + + remove_embedding_provider(db_session, embedding_provider_name) + + +@admin_router.put("/embedding-provider") +def put_cloud_embedding_provider( + provider: CloudEmbeddingProviderCreationRequest, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> CloudEmbeddingProvider: + return upsert_cloud_embedding_provider(db_session, provider) diff --git a/backend/danswer/server/manage/embedding/models.py b/backend/danswer/server/manage/embedding/models.py new file mode 100644 index 000000000..4f1e72319 --- /dev/null +++ b/backend/danswer/server/manage/embedding/models.py @@ -0,0 +1,35 @@ +from typing import TYPE_CHECKING + +from pydantic import BaseModel + +if TYPE_CHECKING: + from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel + + +class TestEmbeddingRequest(BaseModel): + provider: str + api_key: str | None = None + + +class CloudEmbeddingProvider(BaseModel): + name: str + 
api_key: str | None = None + default_model_id: int | None = None + id: int + + @classmethod + def from_request( + cls, cloud_provider_model: "CloudEmbeddingProviderModel" + ) -> "CloudEmbeddingProvider": # builds the API model from an ORM row, despite the "from_request" name + return cls( + id=cloud_provider_model.id, + name=cloud_provider_model.name, + api_key=cloud_provider_model.api_key, + default_model_id=cloud_provider_model.default_model_id, + ) + + +class CloudEmbeddingProviderCreationRequest(BaseModel): + name: str + api_key: str | None = None + default_model_id: int | None = None diff --git a/backend/danswer/server/manage/llm/models.py b/backend/danswer/server/manage/llm/models.py index 05a596ffd..10765a18a 100644 --- a/backend/danswer/server/manage/llm/models.py +++ b/backend/danswer/server/manage/llm/models.py @@ -4,6 +4,7 @@ from pydantic import BaseModel from danswer.llm.llm_provider_options import fetch_models_for_provider + if TYPE_CHECKING: from danswer.db.models import LLMProvider as LLMProviderModel diff --git a/backend/danswer/server/manage/secondary_index.py b/backend/danswer/server/manage/secondary_index.py index 6f5adf752..a64c3422d 100644 --- a/backend/danswer/server/manage/secondary_index.py +++ b/backend/danswer/server/manage/secondary_index.py @@ -11,6 +11,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs from danswer.db.connector_credential_pair import resync_cc_pair from danswer.db.embedding_model import create_embedding_model from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.db.embedding_model import get_model_id_from_name from danswer.db.embedding_model import get_secondary_db_embedding_model from danswer.db.embedding_model import update_embedding_model_status from danswer.db.engine import get_session @@ -38,6 +39,19 @@ def set_new_embedding_model( """ current_model = get_current_db_embedding_model(db_session) + if embed_model_details.cloud_provider_name is not None: + cloud_id = get_model_id_from_name( + db_session, 
embed_model_details.cloud_provider_name + ) + + if cloud_id is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="No ID exists for given provider name", + ) + + embed_model_details.cloud_provider_id = cloud_id + if embed_model_details.model_name == current_model.model_name: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/backend/model_server/constants.py b/backend/model_server/constants.py index bc842f546..d2aa4cf36 100644 --- a/backend/model_server/constants.py +++ b/backend/model_server/constants.py @@ -1 +1,38 @@ +from enum import Enum + +from danswer.search.enums import EmbedTextType + + MODEL_WARM_UP_STRING = "hi " * 512 +DEFAULT_OPENAI_MODEL = "text-embedding-3-small" +DEFAULT_COHERE_MODEL = "embed-english-light-v3.0" +DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct" +DEFAULT_VERTEX_MODEL = "text-embedding-004" + + +class EmbeddingProvider(Enum): + OPENAI = "openai" + COHERE = "cohere" + VOYAGE = "voyage" + GOOGLE = "google" + + +class EmbeddingModelTextType: + PROVIDER_TEXT_TYPE_MAP = { + EmbeddingProvider.COHERE: { + EmbedTextType.QUERY: "search_query", + EmbedTextType.PASSAGE: "search_document", + }, + EmbeddingProvider.VOYAGE: { + EmbedTextType.QUERY: "query", + EmbedTextType.PASSAGE: "document", + }, + EmbeddingProvider.GOOGLE: { + EmbedTextType.QUERY: "RETRIEVAL_QUERY", + EmbedTextType.PASSAGE: "RETRIEVAL_DOCUMENT", + }, + } + + @staticmethod + def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str: # raises KeyError for a provider with no map entry (OPENAI has none) + return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type] diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index 705386a8c..1c82698e9 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -1,12 +1,28 @@ import gc +import json +from typing import Any from typing import Optional +import openai +import vertexai # type: ignore +import voyageai # type: ignore +from cohere import Client as CohereClient 
from fastapi import APIRouter from fastapi import HTTPException +from google.oauth2 import service_account from sentence_transformers import CrossEncoder # type: ignore from sentence_transformers import SentenceTransformer # type: ignore +from vertexai.language_models import TextEmbeddingInput # type: ignore +from vertexai.language_models import TextEmbeddingModel # type: ignore +from danswer.search.enums import EmbedTextType from danswer.utils.logger import setup_logger +from model_server.constants import DEFAULT_COHERE_MODEL +from model_server.constants import DEFAULT_OPENAI_MODEL +from model_server.constants import DEFAULT_VERTEX_MODEL +from model_server.constants import DEFAULT_VOYAGE_MODEL +from model_server.constants import EmbeddingModelTextType +from model_server.constants import EmbeddingProvider from model_server.constants import MODEL_WARM_UP_STRING from model_server.utils import simple_log_function_time from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE @@ -17,6 +33,7 @@ from shared_configs.model_server_models import EmbedResponse from shared_configs.model_server_models import RerankRequest from shared_configs.model_server_models import RerankResponse + logger = setup_logger() router = APIRouter(prefix="/encoder") @@ -25,6 +42,117 @@ _GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {} _RERANK_MODELS: Optional[list["CrossEncoder"]] = None +class CloudEmbedding: + def __init__(self, api_key: str, provider: str, model: str | None = None): + self.api_key = api_key + + # Only for Google as is needed on client setup + self.model = model + try: + self.provider = EmbeddingProvider(provider.lower()) + except ValueError: + raise ValueError(f"Unsupported provider: {provider}") + self.client = self._initialize_client() + + def _initialize_client(self) -> Any: + if self.provider == EmbeddingProvider.OPENAI: + return openai.OpenAI(api_key=self.api_key) + elif self.provider == EmbeddingProvider.COHERE: + return CohereClient(api_key=self.api_key) + elif 
self.provider == EmbeddingProvider.VOYAGE: + return voyageai.Client(api_key=self.api_key) + elif self.provider == EmbeddingProvider.GOOGLE: + credentials = service_account.Credentials.from_service_account_info( + json.loads(self.api_key) + ) + project_id = json.loads(self.api_key)["project_id"] + vertexai.init(project=project_id, credentials=credentials) + return TextEmbeddingModel.from_pretrained( + self.model or DEFAULT_VERTEX_MODEL + ) + + else: + raise ValueError(f"Unsupported provider: {self.provider}") + + def encode( + self, texts: list[str], model_name: str | None, text_type: EmbedTextType + ) -> list[list[float]]: + return [ + self.embed(text=text, text_type=text_type, model=model_name) + for text in texts + ] + + def embed( + self, *, text: str, text_type: EmbedTextType, model: str | None = None + ) -> list[float]: + logger.debug(f"Embedding text with provider: {self.provider}") + if self.provider == EmbeddingProvider.OPENAI: + return self._embed_openai(text, model) + + embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type) + + if self.provider == EmbeddingProvider.COHERE: + return self._embed_cohere(text, model, embedding_type) + elif self.provider == EmbeddingProvider.VOYAGE: + return self._embed_voyage(text, model, embedding_type) + elif self.provider == EmbeddingProvider.GOOGLE: + return self._embed_vertex(text, model, embedding_type) + else: + raise ValueError(f"Unsupported provider: {self.provider}") + + def _embed_openai(self, text: str, model: str | None) -> list[float]: + if model is None: + model = DEFAULT_OPENAI_MODEL + + response = self.client.embeddings.create(input=text, model=model) + return response.data[0].embedding + + def _embed_cohere( + self, text: str, model: str | None, embedding_type: str + ) -> list[float]: + if model is None: + model = DEFAULT_COHERE_MODEL + + response = self.client.embed( + texts=[text], + model=model, + input_type=embedding_type, + ) + return response.embeddings[0] + + def _embed_voyage( + 
self, text: str, model: str | None, embedding_type: str + ) -> list[float]: + if model is None: + model = DEFAULT_VOYAGE_MODEL + + response = self.client.embed(text, model=model, input_type=embedding_type) + return response.embeddings[0] + + def _embed_vertex( + self, text: str, model: str | None, embedding_type: str + ) -> list[float]: + if model is None: + model = DEFAULT_VERTEX_MODEL + + embedding = self.client.get_embeddings( + [ + TextEmbeddingInput( + text, + embedding_type, + ) + ] + ) + return embedding[0].values + + @staticmethod + def create( + api_key: str, provider: str, model: str | None = None + ) -> "CloudEmbedding": + logger.debug(f"Creating Embedding instance for provider: {provider}") + return CloudEmbedding(api_key, provider, model) + + def get_embedding_model( model_name: str, max_context_length: int, @@ -78,18 +206,35 @@ def warm_up_cross_encoders() -> None: @simple_log_function_time() def embed_text( texts: list[str], - model_name: str, + text_type: EmbedTextType, + model_name: str | None, max_context_length: int, normalize_embeddings: bool, + api_key: str | None, + provider_type: str | None, ) -> list[list[float]]: - model = get_embedding_model( - model_name=model_name, max_context_length=max_context_length - ) - embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings) + if provider_type is not None: + if api_key is None: + raise RuntimeError("API key not provided for cloud model") + + cloud_model = CloudEmbedding( + api_key=api_key, provider=provider_type, model=model_name + ) + embeddings = cloud_model.encode(texts, model_name, text_type) + + elif model_name is not None: + hosted_model = get_embedding_model( + model_name=model_name, max_context_length=max_context_length + ) + embeddings = hosted_model.encode( + texts, normalize_embeddings=normalize_embeddings + ) + + if embeddings is None: + raise RuntimeError("Embeddings were not created") if not isinstance(embeddings, list): embeddings = embeddings.tolist() - return 
embeddings @@ -113,6 +258,9 @@ async def process_embed_request( model_name=embed_request.model_name, max_context_length=embed_request.max_context_length, normalize_embeddings=embed_request.normalize_embeddings, + api_key=embed_request.api_key, + provider_type=embed_request.provider_type, + text_type=embed_request.text_type, ) return EmbedResponse(embeddings=embeddings) except Exception as e: diff --git a/backend/requirements/model_server.txt b/backend/requirements/model_server.txt index 4ef8ffa5b..3e47015f3 100644 --- a/backend/requirements/model_server.txt +++ b/backend/requirements/model_server.txt @@ -7,3 +7,7 @@ tensorflow==2.15.0 torch==2.0.1 transformers==4.39.2 uvicorn==0.21.1 +voyageai==0.2.3 +openai==1.14.3 +cohere==5.5.8 +google-cloud-aiplatform==1.58.0 \ No newline at end of file diff --git a/backend/shared_configs/model_server_models.py b/backend/shared_configs/model_server_models.py index 020a24a30..e31d24e53 100644 --- a/backend/shared_configs/model_server_models.py +++ b/backend/shared_configs/model_server_models.py @@ -1,12 +1,19 @@ from pydantic import BaseModel +from danswer.search.enums import EmbedTextType + class EmbedRequest(BaseModel): # This already includes any prefixes, the text is just passed directly to the model texts: list[str] - model_name: str + + # Can be none for cloud embedding model requests, error handling logic exists for other cases + model_name: str | None max_context_length: int normalize_embeddings: bool + api_key: str | None + provider_type: str | None + text_type: EmbedTextType class EmbedResponse(BaseModel): diff --git a/web/public/Cohere.svg b/web/public/Cohere.svg new file mode 100644 index 000000000..543bc2d6c --- /dev/null +++ b/web/public/Cohere.svg @@ -0,0 +1,30 @@ + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/public/Google.webp b/web/public/Google.webp new file mode 100644 index 000000000..7b903159b Binary files /dev/null and b/web/public/Google.webp differ diff --git 
a/web/public/Openai.svg b/web/public/Openai.svg index e04db75a5..c0bcb8bc1 100644 --- a/web/public/Openai.svg +++ b/web/public/Openai.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/web/public/Voyage.png b/web/public/Voyage.png new file mode 100644 index 000000000..63901a6e4 Binary files /dev/null and b/web/public/Voyage.png differ diff --git a/web/src/app/admin/models/embedding/CloudEmbeddingPage.tsx b/web/src/app/admin/models/embedding/CloudEmbeddingPage.tsx new file mode 100644 index 000000000..6b1feee97 --- /dev/null +++ b/web/src/app/admin/models/embedding/CloudEmbeddingPage.tsx @@ -0,0 +1,163 @@ +"use client"; + +import { Text, Title } from "@tremor/react"; + +import { + CloudEmbeddingProvider, + CloudEmbeddingModel, + AVAILABLE_CLOUD_PROVIDERS, + CloudEmbeddingProviderFull, + EmbeddingModelDescriptor, +} from "./components/types"; +import { EmbeddingDetails } from "./page"; +import { FiInfo } from "react-icons/fi"; +import { HoverPopup } from "@/components/HoverPopup"; +import { Dispatch, SetStateAction } from "react"; + +export default function CloudEmbeddingPage({ + currentModel, + embeddingProviderDetails, + newEnabledProviders, + newUnenabledProviders, + setShowTentativeProvider, + setChangeCredentialsProvider, + setAlreadySelectedModel, + setShowTentativeModel, + setShowModelInQueue, +}: { + setShowModelInQueue: Dispatch>; + setShowTentativeModel: Dispatch>; + currentModel: EmbeddingModelDescriptor | CloudEmbeddingModel; + setAlreadySelectedModel: Dispatch>; + newUnenabledProviders: string[]; + embeddingProviderDetails?: EmbeddingDetails[]; + newEnabledProviders: string[]; + selectedModel: CloudEmbeddingProvider; + + // create modal functions + + setShowTentativeProvider: React.Dispatch< + React.SetStateAction + >; + setChangeCredentialsProvider: React.Dispatch< + React.SetStateAction + >; +}) { + function hasNameInArray( + arr: Array<{ name: string }>, + searchName: string + ): boolean { + return arr.some( + (item) => 
item.name.toLowerCase() === searchName.toLowerCase() + ); + } + + let providers: CloudEmbeddingProviderFull[] = []; + AVAILABLE_CLOUD_PROVIDERS.forEach((model, ind) => { + let temporary_model: CloudEmbeddingProviderFull = { + ...model, + configured: + !newUnenabledProviders.includes(model.name) && + (newEnabledProviders.includes(model.name) || + (embeddingProviderDetails && + hasNameInArray(embeddingProviderDetails, model.name))!), + }; + providers.push(temporary_model); + }); + + return ( +
+ + Here are some cloud-based models to choose from. + + + They require API keys and run in the clouds of the respective providers. + + +
+ {providers.map((provider, ind) => ( +
+
+ {provider.icon({ size: 40 })} +

{provider.name}

+ +
+ +
+ {provider.embedding_models.map((model, index) => { + const enabled = model.model_name == currentModel.model_name; + + return ( +
{ + if (enabled) { + setAlreadySelectedModel(model); + } else if (provider.configured) { + setShowTentativeModel(model); + } else { + setShowModelInQueue(model); + setShowTentativeProvider(provider); + } + }} + > +
+
+ {model.model_name} +
+

+ ${model.pricePerMillion}/M tokens +

+
+
+ {model.description} +
+
+ ); + })} +
+ +
+ ))} +
+
+ ); +} diff --git a/web/src/app/admin/models/embedding/ModelSelectionConfirmation.tsx b/web/src/app/admin/models/embedding/ModelSelectionConfirmation.tsx deleted file mode 100644 index 7572ac2ce..000000000 --- a/web/src/app/admin/models/embedding/ModelSelectionConfirmation.tsx +++ /dev/null @@ -1,74 +0,0 @@ -import { Modal } from "@/components/Modal"; -import { Button, Text, Callout } from "@tremor/react"; -import { EmbeddingModelDescriptor } from "./embeddingModels"; - -export function ModelSelectionConfirmaion({ - selectedModel, - isCustom, - onConfirm, -}: { - selectedModel: EmbeddingModelDescriptor; - isCustom: boolean; - onConfirm: () => void; -}) { - return ( -
- - You have selected: {selectedModel.model_name}. Are you sure you - want to update to this new embedding model? - - - We will re-index all your documents in the background so you will be - able to continue to use Danswer as normal with the old model in the - meantime. Depending on how many documents you have indexed, this may - take a while. - - - NOTE: this re-indexing process will consume more resources than - normal. If you are self-hosting, we recommend that you allocate at least - 16GB of RAM to Danswer during this process. - - - {isCustom && ( - - We've detected that this is a custom-specified embedding model. - Since we have to download the model files before verifying the - configuration's correctness, we won't be able to let you - know if the configuration is valid until after we start - re-indexing your documents. If there is an issue, it will show up on - this page as an indexing error on this page after clicking Confirm. - - )} - -
- -
-
- ); -} - -export function ModelSelectionConfirmaionModal({ - selectedModel, - isCustom, - onConfirm, - onCancel, -}: { - selectedModel: EmbeddingModelDescriptor; - isCustom: boolean; - onConfirm: () => void; - onCancel: () => void; -}) { - return ( - -
- -
-
- ); -} diff --git a/web/src/app/admin/models/embedding/OpenEmbeddingPage.tsx b/web/src/app/admin/models/embedding/OpenEmbeddingPage.tsx new file mode 100644 index 000000000..4653c864b --- /dev/null +++ b/web/src/app/admin/models/embedding/OpenEmbeddingPage.tsx @@ -0,0 +1,55 @@ +"use client"; +import { Card, Text, Title } from "@tremor/react"; +import { ModelSelector } from "./components/ModelSelector"; +import { + AVAILABLE_MODELS, + EmbeddingModelDescriptor, + HostedEmbeddingModel, +} from "./components/types"; +import { CustomModelForm } from "./components/CustomModelForm"; + +export default function OpenEmbeddingPage({ + onSelectOpenSource, + currentModelName, +}: { + currentModelName: string; + onSelectOpenSource: (model: HostedEmbeddingModel) => Promise; +}) { + return ( +
+ modelOption.model_name !== currentModelName + )} + setSelectedModel={onSelectOpenSource} + /> + + + Alternatively, (if you know what you're doing) you can specify a{" "} + + SentenceTransformers + + -compatible model of your choice below. The rough list of supported + models can be found{" "} + + here + + . +
+ NOTE: not all models listed will work with Danswer, since some + have unique interfaces or special requirements. If in doubt, reach out + to the Danswer team. +
+ +
+ + + +
+
+ ); +} diff --git a/web/src/app/admin/models/embedding/CustomModelForm.tsx b/web/src/app/admin/models/embedding/components/CustomModelForm.tsx similarity index 89% rename from web/src/app/admin/models/embedding/CustomModelForm.tsx rename to web/src/app/admin/models/embedding/components/CustomModelForm.tsx index 23676bc61..2193f2835 100644 --- a/web/src/app/admin/models/embedding/CustomModelForm.tsx +++ b/web/src/app/admin/models/embedding/components/CustomModelForm.tsx @@ -2,16 +2,15 @@ import { BooleanFormField, TextFormField, } from "@/components/admin/connectors/Field"; -import { Button, Divider, Text } from "@tremor/react"; +import { Button } from "@tremor/react"; import { Form, Formik } from "formik"; - import * as Yup from "yup"; -import { EmbeddingModelDescriptor } from "./embeddingModels"; +import { EmbeddingModelDescriptor, HostedEmbeddingModel } from "./types"; export function CustomModelForm({ onSubmit, }: { - onSubmit: (model: EmbeddingModelDescriptor) => void; + onSubmit: (model: HostedEmbeddingModel) => void; }) { return (
@@ -21,6 +20,7 @@ export function CustomModelForm({ model_dim: "", query_prefix: "", passage_prefix: "", + description: "", normalize: true, }} validationSchema={Yup.object().shape({ @@ -62,6 +62,13 @@ export function CustomModelForm({ } }} /> + - +
{model.model_name}
+
+ {model.description + ? model.description + : "Custom model—no description is available."} +
+
+ ); +} export function ModelOption({ model, onSelect, }: { - model: FullEmbeddingModelDescriptor; - onSelect?: (model: EmbeddingModelDescriptor) => void; + model: HostedEmbeddingModel; + onSelect?: (model: HostedEmbeddingModel) => void; }) { return (
void; + modelOptions: HostedEmbeddingModel[]; + setSelectedModel: (model: HostedEmbeddingModel) => void; }) { return (
diff --git a/web/src/app/admin/models/embedding/ReindexingProgressTable.tsx b/web/src/app/admin/models/embedding/components/ReindexingProgressTable.tsx similarity index 100% rename from web/src/app/admin/models/embedding/ReindexingProgressTable.tsx rename to web/src/app/admin/models/embedding/components/ReindexingProgressTable.tsx diff --git a/web/src/app/admin/models/embedding/components/types.ts b/web/src/app/admin/models/embedding/components/types.ts new file mode 100644 index 000000000..5258f69ab --- /dev/null +++ b/web/src/app/admin/models/embedding/components/types.ts @@ -0,0 +1,286 @@ +import { + CohereIcon, + GoogleIcon, + IconProps, + OpenAIIcon, + VoyageIcon, +} from "@/components/icons/icons"; + +// Cloud Provider (not needed for hosted ones) + +export interface CloudEmbeddingProvider { + id: number; + name: string; + api_key?: string; + custom_config?: Record; + docsLink?: string; + + // Frontend-specific properties + website: string; + icon: ({ size, className }: IconProps) => JSX.Element; + description: string; + apiLink: string; + costslink?: string; + + // Relationships + embedding_models: CloudEmbeddingModel[]; + default_model?: CloudEmbeddingModel; +} + +// Embedding Models +export interface EmbeddingModelDescriptor { + model_name: string; + model_dim: number; + normalize: boolean; + query_prefix: string; + passage_prefix: string; + cloud_provider_name?: string | null; + description: string; +} + +export interface CloudEmbeddingModel extends EmbeddingModelDescriptor { + cloud_provider_name: string | null; + pricePerMillion: number; + enabled?: boolean; + mtebScore: number; + maxContext: number; +} + +export interface HostedEmbeddingModel extends EmbeddingModelDescriptor { + link?: string; + model_dim: number; + normalize: boolean; + query_prefix: string; + passage_prefix: string; + isDefault?: boolean; +} + +// Responses +export interface FullEmbeddingModelResponse { + current_model_name: string; + secondary_model_name: string | null; +} + +export 
interface CloudEmbeddingProviderFull extends CloudEmbeddingProvider { + configured: boolean; +} + +export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ + { + model_name: "intfloat/e5-base-v2", + model_dim: 768, + normalize: true, + description: + "The recommended default for most situations. If you aren't sure which model to use, this is probably the one.", + isDefault: true, + link: "https://huggingface.co/intfloat/e5-base-v2", + query_prefix: "query: ", + passage_prefix: "passage: ", + }, + { + model_name: "intfloat/e5-small-v2", + model_dim: 384, + normalize: true, + description: + "A smaller / faster version of the default model. If you're running Danswer on a resource constrained system, then this is a good choice.", + link: "https://huggingface.co/intfloat/e5-small-v2", + query_prefix: "query: ", + passage_prefix: "passage: ", + }, + { + model_name: "intfloat/multilingual-e5-base", + model_dim: 768, + normalize: true, + description: + "If you have many documents in other languages besides English, this is the one to go for.", + link: "https://huggingface.co/intfloat/multilingual-e5-base", + query_prefix: "query: ", + passage_prefix: "passage: ", + }, + { + model_name: "intfloat/multilingual-e5-small", + model_dim: 384, + normalize: true, + description: + "If you have many documents in other languages besides English, and you're running on a resource constrained system, then this is the one to go for.", + link: "https://huggingface.co/intfloat/multilingual-e5-base", + query_prefix: "query: ", + passage_prefix: "passage: ", + }, +]; + +export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ + { + id: 0, + name: "OpenAI", + website: "https://openai.com", + icon: OpenAIIcon, + description: "AI industry leader known for ChatGPT and DALL-E", + apiLink: "https://platform.openai.com/api-keys", + docsLink: + "https://docs.danswer.dev/guides/embedding_providers#openai-models", + costslink: "https://openai.com/pricing", + embedding_models: [ + { + 
model_name: "text-embedding-3-large", + cloud_provider_name: "OpenAI", + description: + "OpenAI's large embedding model. Best performance, but more expensive.", + pricePerMillion: 0.13, + model_dim: 3072, + normalize: false, + query_prefix: "", + passage_prefix: "", + mtebScore: 64.6, + maxContext: 8191, + enabled: false, + }, + { + model_name: "text-embedding-3-small", + cloud_provider_name: "OpenAI", + model_dim: 1536, + normalize: false, + query_prefix: "", + passage_prefix: "", + description: + "OpenAI's newer, more efficient embedding model. Good balance of performance and cost.", + pricePerMillion: 0.02, + enabled: false, + mtebScore: 62.3, + maxContext: 8191, + }, + ], + }, + { + id: 1, + name: "Cohere", + website: "https://cohere.ai", + icon: CohereIcon, + docsLink: + "https://docs.danswer.dev/guides/embedding_providers#cohere-models", + description: + "AI company specializing in NLP models for various text-based tasks", + apiLink: "https://dashboard.cohere.ai/api-keys", + costslink: "https://cohere.com/pricing", + embedding_models: [ + { + model_name: "embed-english-v3.0", + cloud_provider_name: "Cohere", + description: + "Cohere's English embedding model. Good performance for English-language tasks.", + pricePerMillion: 0.1, + mtebScore: 64.5, + maxContext: 512, + enabled: false, + model_dim: 1024, + normalize: false, + query_prefix: "", + passage_prefix: "", + }, + { + model_name: "embed-english-light-v3.0", + cloud_provider_name: "Cohere", + description: + "Cohere's lightweight English embedding model. 
Faster and more efficient for simpler tasks.", + pricePerMillion: 0.1, + mtebScore: 62, + maxContext: 512, + enabled: false, + model_dim: 384, + normalize: false, + query_prefix: "", + passage_prefix: "", + }, + ], + }, + + { + id: 2, + name: "Google", + website: "https://ai.google", + icon: GoogleIcon, + docsLink: + "https://docs.danswer.dev/guides/embedding_providers#vertex-ai-google-model", + description: + "Offers a wide range of AI services including language and vision models", + apiLink: "https://console.cloud.google.com/apis/credentials", + costslink: "https://cloud.google.com/vertex-ai/pricing", + embedding_models: [ + { + cloud_provider_name: "Google", + model_name: "text-embedding-004", + description: "Google's most recent text embedding model.", + pricePerMillion: 0.025, + mtebScore: 66.31, + maxContext: 2048, + enabled: false, + model_dim: 768, + normalize: false, + query_prefix: "", + passage_prefix: "", + }, + { + cloud_provider_name: "Google", + model_name: "textembedding-gecko@003", + description: "Google's Gecko embedding model. Powerful and efficient.", + pricePerMillion: 0.025, + mtebScore: 66.31, + maxContext: 2048, + enabled: false, + model_dim: 768, + normalize: false, + query_prefix: "", + passage_prefix: "", + }, + ], + }, + { + id: 3, + name: "Voyage", + website: "https://www.voyageai.com", + icon: VoyageIcon, + description: "Advanced NLP research startup born from Stanford AI Labs", + docsLink: + "https://docs.danswer.dev/guides/embedding_providers#voyage-models", + apiLink: "https://www.voyageai.com/dashboard", + costslink: "https://www.voyageai.com/pricing", + embedding_models: [ + { + cloud_provider_name: "Voyage", + model_name: "voyage-large-2-instruct", + description: + "Voyage's large embedding model. 
High performance with instruction fine-tuning.", + pricePerMillion: 0.12, + mtebScore: 68.28, + maxContext: 4000, + enabled: false, + model_dim: 1024, + normalize: false, + query_prefix: "", + passage_prefix: "", + }, + { + cloud_provider_name: "Voyage", + model_name: "voyage-light-2-instruct", + description: + "Voyage's lightweight embedding model. Good balance of performance and efficiency.", + pricePerMillion: 0.12, + mtebScore: 67.13, + maxContext: 16000, + enabled: false, + model_dim: 1024, + normalize: false, + query_prefix: "", + passage_prefix: "", + }, + ], + }, +]; + +export const INVALID_OLD_MODEL = "thenlper/gte-small"; + +export function checkModelNameIsValid( + modelName: string | undefined | null +): boolean { + return !!modelName && modelName !== INVALID_OLD_MODEL; +} diff --git a/web/src/app/admin/models/embedding/embeddingModels.ts b/web/src/app/admin/models/embedding/embeddingModels.ts deleted file mode 100644 index 7c5d09180..000000000 --- a/web/src/app/admin/models/embedding/embeddingModels.ts +++ /dev/null @@ -1,87 +0,0 @@ -export interface EmbeddingModelResponse { - model_name: string | null; -} - -export interface FullEmbeddingModelResponse { - current_model_name: string; - secondary_model_name: string | null; -} - -export interface EmbeddingModelDescriptor { - model_name: string; - model_dim: number; - normalize: boolean; - query_prefix?: string; - passage_prefix?: string; -} - -export interface FullEmbeddingModelDescriptor extends EmbeddingModelDescriptor { - description: string; - isDefault?: boolean; - link?: string; -} - -export const AVAILABLE_MODELS: FullEmbeddingModelDescriptor[] = [ - { - model_name: "intfloat/e5-base-v2", - model_dim: 768, - normalize: true, - description: - "The recommended default for most situations. 
If you aren't sure which model to use, this is probably the one.", - isDefault: true, - link: "https://huggingface.co/intfloat/e5-base-v2", - query_prefix: "query: ", - passage_prefix: "passage: ", - }, - { - model_name: "intfloat/e5-small-v2", - model_dim: 384, - normalize: true, - description: - "A smaller / faster version of the default model. If you're running Danswer on a resource constrained system, then this is a good choice.", - link: "https://huggingface.co/intfloat/e5-small-v2", - query_prefix: "query: ", - passage_prefix: "passage: ", - }, - { - model_name: "intfloat/multilingual-e5-base", - model_dim: 768, - normalize: true, - description: - "If you have many documents in other languages besides English, this is the one to go for.", - link: "https://huggingface.co/intfloat/multilingual-e5-base", - query_prefix: "query: ", - passage_prefix: "passage: ", - }, - { - model_name: "intfloat/multilingual-e5-small", - model_dim: 384, - normalize: true, - description: - "If you have many documents in other languages besides English, and you're running on a resource constrained system, then this is the one to go for.", - link: "https://huggingface.co/intfloat/multilingual-e5-base", - query_prefix: "query: ", - passage_prefix: "passage: ", - }, -]; - -export const INVALID_OLD_MODEL = "thenlper/gte-small"; - -export function checkModelNameIsValid(modelName: string | undefined | null) { - if (!modelName) { - return false; - } - if (modelName === INVALID_OLD_MODEL) { - return false; - } - return true; -} - -export function fillOutEmeddingModelDescriptor( - embeddingModel: EmbeddingModelDescriptor | FullEmbeddingModelDescriptor -): FullEmbeddingModelDescriptor { - return { - ...embeddingModel, - description: "", - }; -} diff --git a/web/src/app/admin/models/embedding/modals/AlreadyPickedModal.tsx b/web/src/app/admin/models/embedding/modals/AlreadyPickedModal.tsx new file mode 100644 index 000000000..c8d9430c9 --- /dev/null +++ 
b/web/src/app/admin/models/embedding/modals/AlreadyPickedModal.tsx @@ -0,0 +1,31 @@ +import React from "react"; +import { Modal } from "@/components/Modal"; +import { Button, Text } from "@tremor/react"; + +import { CloudEmbeddingModel } from "../components/types"; + +export function AlreadyPickedModal({ + model, + onClose, +}: { + model: CloudEmbeddingModel; + onClose: () => void; +}) { + return ( + +
+ + You can select a different one if you want! + +
+ +
+
+
+ ); +} diff --git a/web/src/app/admin/models/embedding/modals/ChangeCredentialsModal.tsx b/web/src/app/admin/models/embedding/modals/ChangeCredentialsModal.tsx new file mode 100644 index 000000000..fa94ba9cc --- /dev/null +++ b/web/src/app/admin/models/embedding/modals/ChangeCredentialsModal.tsx @@ -0,0 +1,244 @@ +import React, { useRef, useState } from "react"; +import { Modal } from "@/components/Modal"; +import { Button, Text, Callout, Subtitle, Divider } from "@tremor/react"; +import { Label, TextFormField } from "@/components/admin/connectors/Field"; +import { CloudEmbeddingProvider } from "../components/types"; +import { + EMBEDDING_PROVIDERS_ADMIN_URL, + LLM_PROVIDERS_ADMIN_URL, +} from "../../llm/constants"; +import { mutate } from "swr"; +import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; +import { Field } from "formik"; + +export function ChangeCredentialsModal({ + provider, + onConfirm, + onCancel, + onDeleted, + useFileUpload, +}: { + provider: CloudEmbeddingProvider; + onConfirm: () => void; + onCancel: () => void; + onDeleted: () => void; + useFileUpload: boolean; +}) { + const [apiKey, setApiKey] = useState(""); + const [testError, setTestError] = useState(""); + const [fileName, setFileName] = useState(""); + const fileInputRef = useRef(null); + const [isProcessing, setIsProcessing] = useState(false); + const [deletionError, setDeletionError] = useState(""); + + const clearFileInput = () => { + setFileName(""); + if (fileInputRef.current) { + fileInputRef.current.value = ""; + } + }; + + const handleFileUpload = async ( + event: React.ChangeEvent + ) => { + const file = event.target.files?.[0]; + setFileName(""); + + if (file) { + setFileName(file.name); + try { + setDeletionError(""); + const fileContent = await file.text(); + let jsonContent; + try { + jsonContent = JSON.parse(fileContent); + setApiKey(JSON.stringify(jsonContent)); + } catch (parseError) { + throw new Error( + "Failed to parse JSON file. 
Please ensure it's a valid JSON." + ); + } + } catch (error) { + setTestError( + error instanceof Error + ? error.message + : "An unknown error occurred while processing the file." + ); + setApiKey(""); + clearFileInput(); + } + } + }; + + const handleDelete = async () => { + setDeletionError(""); + setIsProcessing(true); + + try { + const response = await fetch( + `${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.name}`, + { + method: "DELETE", + } + ); + + if (!response.ok) { + const errorData = await response.json(); + setDeletionError(errorData.detail); + return; + } + + mutate(LLM_PROVIDERS_ADMIN_URL); + onDeleted(); + } catch (error) { + setDeletionError( + error instanceof Error ? error.message : "An unknown error occurred" + ); + } finally { + setIsProcessing(false); + } + }; + + const handleSubmit = async () => { + setTestError(""); + + try { + const body = JSON.stringify({ + api_key: apiKey, + provider: provider.name.toLowerCase().split(" ")[0], + default_model_id: provider.name, + }); + + const testResponse = await fetch("/api/admin/embedding/test-embedding", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + provider: provider.name.toLowerCase().split(" ")[0], + api_key: apiKey, + }), + }); + + if (!testResponse.ok) { + const errorMsg = (await testResponse.json()).detail; + throw new Error(errorMsg); + } + + const updateResponse = await fetch(EMBEDDING_PROVIDERS_ADMIN_URL, { + method: "PUT", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + name: provider.name, + api_key: apiKey, + is_default_provider: false, + is_configured: true, + }), + }); + + if (!updateResponse.ok) { + const errorData = await updateResponse.json(); + throw new Error( + errorData.detail || "Failed to update provider- check your API key" + ); + } + + onConfirm(); + } catch (error) { + setTestError( + error instanceof Error ? error.message : "An unknown error occurred" + ); + } + }; + + return ( + +
+ + Want to swap out your key? + + +
+ {useFileUpload ? ( + <> + + + {fileName &&

Uploaded file: {fileName}

} + + ) : ( + <> +
+ +
+ setApiKey(e.target.value)} + placeholder="Paste your API key here" + /> + + )} + + Visit API + +
+ + {testError && ( + + {testError} + + )} + +
+ +
+ + + + You can also delete your key. + + + This is only possible if you have already switched to a different + embedding type! + + + + {deletionError && ( + + {deletionError} + + )} +
+
+ ); +} diff --git a/web/src/app/admin/models/embedding/modals/DeleteCredentialsModal.tsx b/web/src/app/admin/models/embedding/modals/DeleteCredentialsModal.tsx new file mode 100644 index 000000000..b13d8051c --- /dev/null +++ b/web/src/app/admin/models/embedding/modals/DeleteCredentialsModal.tsx @@ -0,0 +1,41 @@ +import React from "react"; +import { Modal } from "@/components/Modal"; +import { Button, Text, Callout } from "@tremor/react"; +import { CloudEmbeddingProvider } from "../components/types"; + +export function DeleteCredentialsModal({ + modelProvider, + onConfirm, + onCancel, +}: { + modelProvider: CloudEmbeddingProvider; + onConfirm: () => void; + onCancel: () => void; +}) { + return ( + +
+ + You're about to delete your {modelProvider.name} credentials. Are + you sure? + + +
+ + +
+
+
+ ); +} diff --git a/web/src/app/admin/models/embedding/modals/ModelSelectionModal.tsx b/web/src/app/admin/models/embedding/modals/ModelSelectionModal.tsx new file mode 100644 index 000000000..59b44b7d5 --- /dev/null +++ b/web/src/app/admin/models/embedding/modals/ModelSelectionModal.tsx @@ -0,0 +1,61 @@ +import { Modal } from "@/components/Modal"; +import { Button, Text, Callout } from "@tremor/react"; +import { + EmbeddingModelDescriptor, + HostedEmbeddingModel, +} from "../components/types"; + +export function ModelSelectionConfirmationModal({ + selectedModel, + isCustom, + onConfirm, + onCancel, +}: { + selectedModel: HostedEmbeddingModel; + isCustom: boolean; + onConfirm: () => void; + onCancel: () => void; +}) { + return ( + +
+
+ + You have selected: {selectedModel.model_name}. Are you sure + you want to update to this new embedding model? + + + We will re-index all your documents in the background so you will be + able to continue to use Danswer as normal with the old model in the + meantime. Depending on how many documents you have indexed, this may + take a while. + + + NOTE: this re-indexing process will consume more resources + than normal. If you are self-hosting, we recommend that you allocate + at least 16GB of RAM to Danswer during this process. + + + {/* Only custom-specified models get this warning: their configuration cannot be verified before re-indexing starts */} + {isCustom && ( + + We've detected that this is a custom-specified embedding + model. Since we have to download the model files before verifying + the configuration's correctness, we won't be able to let + you know if the configuration is valid until after we start + re-indexing your documents. If there is an issue, it will show up + on this page as an indexing error after clicking + Confirm. + + )} + +
+ +
+
+
+
+ ); +} diff --git a/web/src/app/admin/models/embedding/modals/ProviderCreationModal.tsx b/web/src/app/admin/models/embedding/modals/ProviderCreationModal.tsx new file mode 100644 index 000000000..981dfa0ca --- /dev/null +++ b/web/src/app/admin/models/embedding/modals/ProviderCreationModal.tsx @@ -0,0 +1,232 @@ +import React, { useRef, useState } from "react"; +import { Text, Button, Callout } from "@tremor/react"; +import { Formik, Form, Field } from "formik"; +import * as Yup from "yup"; +import { Label, TextFormField } from "@/components/admin/connectors/Field"; +import { LoadingAnimation } from "@/components/Loading"; +import { CloudEmbeddingProvider } from "../components/types"; +import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../../llm/constants"; +import { Modal } from "@/components/Modal"; + +export function ProviderCreationModal({ + selectedProvider, + onConfirm, + onCancel, + existingProvider, +}: { + selectedProvider: CloudEmbeddingProvider; + onConfirm: () => void; + onCancel: () => void; + existingProvider?: CloudEmbeddingProvider; +}) { + const useFileUpload = selectedProvider.name == "Google"; + + const [isProcessing, setIsProcessing] = useState(false); + const [errorMsg, setErrorMsg] = useState(""); + const [fileName, setFileName] = useState(""); + + const initialValues = { + name: existingProvider?.name || selectedProvider.name, + api_key: existingProvider?.api_key || "", + custom_config: existingProvider?.custom_config + ? Object.entries(existingProvider.custom_config) + : [], + default_model_name: "", + model_id: 0, + }; + + const validationSchema = Yup.object({ + name: Yup.string().required("Name is required"), + api_key: useFileUpload + ? 
Yup.string() + : Yup.string().required("API Key is required"), + custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)), + }); + + const fileInputRef = useRef(null); + + // Reads the uploaded credentials file, validates that it is JSON and stores it (re-serialized) in the api_key form field. + const handleFileUpload = async ( + event: React.ChangeEvent, + setFieldValue: (field: string, value: any) => void + ) => { + const file = event.target.files?.[0]; + setFileName(""); + // Clear any message left over from a previous failed upload + setErrorMsg(""); + if (file) { + setFileName(file.name); + try { + const fileContent = await file.text(); + let jsonContent; + try { + jsonContent = JSON.parse(fileContent); + } catch (parseError) { + throw new Error( + "Failed to parse JSON file. Please ensure it's a valid JSON." + ); + } + setFieldValue("api_key", JSON.stringify(jsonContent)); + } catch (error) { + // Surface the failure instead of silently discarding the file: reset the field, drop the stale file name and show the reason. + setFieldValue("api_key", ""); + setFileName(""); + setErrorMsg( + error instanceof Error + ? error.message + : "An unknown error occurred while processing the file." + ); + } + } + }; + + // Tests the credentials against the test-embedding endpoint, then saves the provider; errors are shown via errorMsg. + const handleSubmit = async ( + values: any, + { setSubmitting }: { setSubmitting: (isSubmitting: boolean) => void } + ) => { + setIsProcessing(true); + setErrorMsg(""); + + try { + const customConfig = Object.fromEntries(values.custom_config); + + const initialResponse = await fetch( + "/api/admin/embedding/test-embedding", + { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + provider: values.name.toLowerCase().split(" ")[0], + api_key: values.api_key, + }), + } + ); + + if (!initialResponse.ok) { + const errorMsg = (await initialResponse.json()).detail; + setErrorMsg(errorMsg); + setIsProcessing(false); + setSubmitting(false); + return; + } + + const response = await fetch(EMBEDDING_PROVIDERS_ADMIN_URL, { + method: "PUT", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + ...values, + custom_config: customConfig, + is_default_provider: false, + is_configured: true, + }), + }); + + if (!response.ok) { + const errorData = await response.json(); + throw new Error( + errorData.detail || "Failed to update provider- check your API key" + ); + } + + onConfirm(); + } catch (error: unknown) { + if (error instanceof Error) { +
setErrorMsg(error.message); + } else { + setErrorMsg("An unknown error occurred"); + } + } finally { + setIsProcessing(false); + setSubmitting(false); + } + }; + + return ( + +
+ + {({ + values, + errors, + touched, + isSubmitting, + handleSubmit, + setFieldValue, + }) => ( +
+ + You are setting the credentials for this provider. To access + this information, follow the instructions{" "} + + here + {" "} + and gather your{" "} + + API KEY + + + +
+ {useFileUpload ? ( + <> + + handleFileUpload(e, setFieldValue)} + className="text-lg w-full p-1" + /> + {fileName &&

Uploaded file: {fileName}

} + + ) : ( + + )} + + + Learn more here + +
+ + {errorMsg && ( + + {errorMsg} + + )} + + +
+ )} +
+
+
+ ); +} diff --git a/web/src/app/admin/models/embedding/modals/SelectModelModal.tsx b/web/src/app/admin/models/embedding/modals/SelectModelModal.tsx new file mode 100644 index 000000000..a9cfa3597 --- /dev/null +++ b/web/src/app/admin/models/embedding/modals/SelectModelModal.tsx @@ -0,0 +1,37 @@ +import React from "react"; +import { Modal } from "@/components/Modal"; +import { Button, Text, Callout } from "@tremor/react"; +import { CloudEmbeddingModel } from "../components/types"; + +export function SelectModelModal({ + model, + onConfirm, + onCancel, +}: { + model: CloudEmbeddingModel; + onConfirm: () => void; + onCancel: () => void; +}) { + return ( + +
+ + You're about to set your embedding model to {model.model_name}. +
+ Are you sure? +
+
+ + +
+
+
+ ); +} diff --git a/web/src/app/admin/models/embedding/page.tsx b/web/src/app/admin/models/embedding/page.tsx index ccda9af19..92e7de6df 100644 --- a/web/src/app/admin/models/embedding/page.tsx +++ b/web/src/app/admin/models/embedding/page.tsx @@ -3,28 +3,75 @@ import { ThreeDotsLoader } from "@/components/Loading"; import { AdminPageTitle } from "@/components/admin/Title"; import { errorHandlingFetcher } from "@/lib/fetcher"; -import { Button, Card, Text, Title } from "@tremor/react"; +import { Button, Text, Title } from "@tremor/react"; import { FiPackage } from "react-icons/fi"; import useSWR, { mutate } from "swr"; -import { ModelOption, ModelSelector } from "./ModelSelector"; +import { ModelOption, ModelPreview } from "./components/ModelSelector"; import { useState } from "react"; -import { ModelSelectionConfirmaionModal } from "./ModelSelectionConfirmation"; -import { ReindexingProgressTable } from "./ReindexingProgressTable"; +import { ReindexingProgressTable } from "./components/ReindexingProgressTable"; import { Modal } from "@/components/Modal"; import { + CloudEmbeddingProvider, + CloudEmbeddingModel, + AVAILABLE_CLOUD_PROVIDERS, AVAILABLE_MODELS, - EmbeddingModelDescriptor, INVALID_OLD_MODEL, - fillOutEmeddingModelDescriptor, -} from "./embeddingModels"; + HostedEmbeddingModel, + EmbeddingModelDescriptor, +} from "./components/types"; import { ErrorCallout } from "@/components/ErrorCallout"; import { Connector, ConnectorIndexingStatus } from "@/lib/types"; import Link from "next/link"; -import { CustomModelForm } from "./CustomModelForm"; +import OpenEmbeddingPage from "./OpenEmbeddingPage"; +import CloudEmbeddingPage from "./CloudEmbeddingPage"; +import { ProviderCreationModal } from "./modals/ProviderCreationModal"; + +import { DeleteCredentialsModal } from "./modals/DeleteCredentialsModal"; +import { SelectModelModal } from "./modals/SelectModelModal"; +import { ChangeCredentialsModal } from "./modals/ChangeCredentialsModal"; +import { 
ModelSelectionConfirmationModal } from "./modals/ModelSelectionModal"; +import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../llm/constants"; +import { AlreadyPickedModal } from "./modals/AlreadyPickedModal"; + +export interface EmbeddingDetails { + api_key: string; + custom_config: any; + default_model_id?: number; + name: string; +} function Main() { - const [tentativeNewEmbeddingModel, setTentativeNewEmbeddingModel] = - useState(null); + const [openToggle, setOpenToggle] = useState(true); + + // Cloud Provider based modals + const [showTentativeProvider, setShowTentativeProvider] = + useState(null); + const [showUnconfiguredProvider, setShowUnconfiguredProvider] = + useState(null); + const [changeCredentialsProvider, setChangeCredentialsProvider] = + useState(null); + + // Cloud Model based modals + const [alreadySelectedModel, setAlreadySelectedModel] = + useState(null); + const [showTentativeModel, setShowTentativeModel] = + useState(null); + + const [showModelInQueue, setShowModelInQueue] = + useState(null); + + // Open Model based modals + const [showTentativeOpenProvider, setShowTentativeOpenProvider] = + useState(null); + + // Enabled / unenabled providers + const [newEnabledProviders, setNewEnabledProviders] = useState([]); + const [newUnenabledProviders, setNewUnenabledProviders] = useState( + [] + ); + + const [showDeleteCredentialsModal, setShowDeleteCredentialsModal] = + useState(false); const [isCancelling, setIsCancelling] = useState(false); const [showAddConnectorPopup, setShowAddConnectorPopup] = useState(false); @@ -33,16 +80,22 @@ function Main() { data: currentEmeddingModel, isLoading: isLoadingCurrentModel, error: currentEmeddingModelError, - } = useSWR( + } = useSWR( "/api/secondary-index/get-current-embedding-model", errorHandlingFetcher, { refreshInterval: 5000 } // 5 seconds ); + + const { data: embeddingProviderDetails } = useSWR( + EMBEDDING_PROVIDERS_ADMIN_URL, + errorHandlingFetcher + ); + const { data: futureEmbeddingModel, isLoading: 
isLoadingFutureModel, error: futureEmeddingModelError, - } = useSWR( + } = useSWR( "/api/secondary-index/get-secondary-embedding-model", errorHandlingFetcher, { refreshInterval: 5000 } // 5 seconds @@ -61,27 +114,41 @@ function Main() { { refreshInterval: 5000 } // 5 seconds ); - const onSelect = async (model: EmbeddingModelDescriptor) => { - if (currentEmeddingModel?.model_name === INVALID_OLD_MODEL) { - await onConfirm(model); - } else { - setTentativeNewEmbeddingModel(model); - } - }; + const onConfirm = async ( + model: CloudEmbeddingModel | HostedEmbeddingModel + ) => { + let newModel: EmbeddingModelDescriptor; + + if ("cloud_provider_name" in model) { + // This is a CloudEmbeddingModel + newModel = { + ...model, + model_name: model.model_name, + cloud_provider_name: model.cloud_provider_name, + // cloud_provider_id: model.cloud_provider_id || 0, + }; + } else { + // This is an EmbeddingModelDescriptor + newModel = { + ...model, + model_name: model.model_name!, + description: "", + cloud_provider_name: null, + }; + } - const onConfirm = async (model: EmbeddingModelDescriptor) => { const response = await fetch( "/api/secondary-index/set-new-embedding-model", { method: "POST", - body: JSON.stringify(model), + body: JSON.stringify(newModel), headers: { "Content-Type": "application/json", }, } ); if (response.ok) { - setTentativeNewEmbeddingModel(null); + setShowTentativeModel(null); mutate("/api/secondary-index/get-secondary-embedding-model"); if (!connectors || !connectors.length) { setShowAddConnectorPopup(true); @@ -96,14 +163,13 @@ function Main() { method: "POST", }); if (response.ok) { - setTentativeNewEmbeddingModel(null); + setShowTentativeModel(null); mutate("/api/secondary-index/get-secondary-embedding-model"); } else { alert( `Failed to cancel embedding model update - ${await response.text()}` ); } - setIsCancelling(false); }; @@ -119,216 +185,328 @@ function Main() { return ; } - const currentModelName = currentEmeddingModel.model_name; - const 
currentModel = - AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) || - fillOutEmeddingModelDescriptor(currentEmeddingModel); + const onConfirmSelection = async (model: EmbeddingModelDescriptor) => { + const response = await fetch( + "/api/secondary-index/set-new-embedding-model", + { + method: "POST", + body: JSON.stringify(model), + headers: { + "Content-Type": "application/json", + }, + } + ); + if (response.ok) { + setShowTentativeModel(null); + mutate("/api/secondary-index/get-secondary-embedding-model"); + if (!connectors || !connectors.length) { + setShowAddConnectorPopup(true); + } + } else { + alert(`Failed to update embedding model - ${await response.text()}`); + } + }; - const newModelSelection = futureEmbeddingModel - ? AVAILABLE_MODELS.find( - (model) => model.model_name === futureEmbeddingModel.model_name - ) || fillOutEmeddingModelDescriptor(futureEmbeddingModel) - : null; + const currentModelName = currentEmeddingModel?.model_name; + const AVAILABLE_CLOUD_PROVIDERS_FLATTENED = AVAILABLE_CLOUD_PROVIDERS.flatMap( + (provider) => + provider.embedding_models.map((model) => ({ + ...model, + cloud_provider_id: provider.id, + model_name: model.model_name, // Ensure model_name is set for consistency + })) + ); + + const currentModel: CloudEmbeddingModel | HostedEmbeddingModel = + AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) || + AVAILABLE_CLOUD_PROVIDERS_FLATTENED.find( + (model) => model.model_name === currentEmeddingModel.model_name + )!; + // || + // fillOutEmeddingModelDescriptor(currentEmeddingModel); + + const onSelectOpenSource = async (model: HostedEmbeddingModel) => { + if (currentEmeddingModel?.model_name === INVALID_OLD_MODEL) { + await onConfirmSelection(model); + } else { + setShowTentativeOpenProvider(model); + } + }; + + const selectedModel = AVAILABLE_CLOUD_PROVIDERS[0]; + const clientsideAddProvider = (provider: CloudEmbeddingProvider) => { + const providerName = provider.name; + 
setNewEnabledProviders((newEnabledProviders) => [ + ...newEnabledProviders, + providerName, + ]); + setNewUnenabledProviders((newUnenabledProviders) => + newUnenabledProviders.filter( + (givenProvidername) => givenProvidername != providerName + ) + ); + }; + + const clientsideRemoveProvider = (provider: CloudEmbeddingProvider) => { + const providerName = provider.name; + setNewEnabledProviders((newEnabledProviders) => + newEnabledProviders.filter( + (givenProvidername) => givenProvidername != providerName + ) + ); + setNewUnenabledProviders((newUnenabledProviders) => [ + ...newUnenabledProviders, + providerName, + ]); + }; return ( -
- {tentativeNewEmbeddingModel && ( - - model.model_name === tentativeNewEmbeddingModel.model_name - ) === undefined - } - onConfirm={() => onConfirm(tentativeNewEmbeddingModel)} - onCancel={() => setTentativeNewEmbeddingModel(null)} - /> - )} - - {showAddConnectorPopup && ( - -
-
- Embeding model successfully selected{" "} - 🙌 -
-
- To complete the initial setup, let's add a connector! -
-
- Connectors are the way that Danswer gets data from your - organization's various data sources. Once setup, we'll - automatically sync data from your apps and docs into Danswer, so - you can search all through all of them in one place. -
-
- - - -
-
-
- )} - - {isCancelling && ( - setIsCancelling(false)} - title="Cancel Embedding Model Switch" - > -
-
- Are you sure you want to cancel? -
-
- Cancelling will revert to the previous model and all progress will - be lost. -
-
- -
-
-
- )} - +
Embedding models are used to generate embeddings for your documents, which then power Danswer's search. + {alreadySelectedModel && ( + setAlreadySelectedModel(null)} + /> + )} + {showTentativeOpenProvider && ( + + model.model_name === showTentativeOpenProvider.model_name + ) === undefined + } + onConfirm={() => onConfirm(showTentativeOpenProvider)} + onCancel={() => setShowTentativeOpenProvider(null)} + /> + )} + + {showTentativeProvider && ( + { + setShowTentativeProvider(showUnconfiguredProvider); + clientsideAddProvider(showTentativeProvider); + if (showModelInQueue) { + setShowTentativeModel(showModelInQueue); + } + }} + onCancel={() => { + setShowModelInQueue(null); + setShowTentativeProvider(null); + }} + /> + )} + {changeCredentialsProvider && ( + { + clientsideRemoveProvider(changeCredentialsProvider); + setChangeCredentialsProvider(null); + }} + provider={changeCredentialsProvider} + onConfirm={() => setChangeCredentialsProvider(null)} + onCancel={() => setChangeCredentialsProvider(null)} + /> + )} + + {showTentativeModel && ( + { + setShowModelInQueue(null); + onConfirm(showTentativeModel); + }} + onCancel={() => { + setShowModelInQueue(null); + setShowTentativeModel(null); + }} + /> + )} + + {showDeleteCredentialsModal && ( + { + setShowDeleteCredentialsModal(false); + }} + onCancel={() => setShowDeleteCredentialsModal(false)} + /> + )} + {currentModel ? ( <> Current Embedding Model - - + ) : ( - newModelSelection && - (!connectors || !connectors.length) && ( - <> - Current Embedding Model + Choose your Embedding Model + )} - - - - - ) + {!(futureEmbeddingModel && connectors && connectors.length > 0) && ( + <> + Switch your Embedding Model + + If the current model is not working for you, you can update your + model choice below. Note that this will require a complete + re-indexing of all your documents across every connected source. 
We + will take care of this in the background, but depending on the size + of your corpus, this could take hours, days, or even weeks. You can + monitor the progress of the re-indexing on this page. + + 
+ +
+ +
+
+ )} {!showAddConnectorPopup && - (!newModelSelection ? ( -
- {currentModel ? ( - <> - Switch your Embedding Model - - - If the current model is not working for you, you can update - your model choice below. Note that this will require a - complete re-indexing of all your documents across every - connected source. We will take care of this in the background, - but depending on the size of your corpus, this could take - hours, day, or even weeks. You can monitor the progress of the - re-indexing on this page. - - - ) : ( - <> - Choose your Embedding Model - - )} - - - Below are a curated selection of quality models that we recommend - you choose from. - - - modelOption.model_name !== currentModelName - )} - setSelectedModel={onSelect} - /> - - - Alternatively, (if you know what you're doing) you can - specify a{" "} - - SentenceTransformers - - -compatible model of your choice below. The rough list of - supported models can be found{" "} - - here - - . -
- NOTE: not all models listed will work with Danswer, since - some have unique interfaces or special requirements. If in doubt, - reach out to the Danswer team. -
- -
- - - -
-
+ !futureEmbeddingModel && + (openToggle ? ( + ) : ( - connectors && - connectors.length > 0 && ( -
- Current Upgrade Status -
-
- Currently in the process of switching to: -
- - - - - - The table below shows the re-indexing progress of all existing - connectors. Once all connectors have been re-indexed - successfully, the new model will be used for all search - queries. Until then, we will use the old model so that no - downtime is necessary during this transition. - - - {isLoadingOngoingReIndexingStatus ? ( - - ) : ongoingReIndexingStatus ? ( - - ) : ( - - )} -
-
- ) + ))} + + {openToggle && ( + <> + {showAddConnectorPopup && ( + +
+
+ + Embedding model successfully selected + {" "} + 🙌 +
+
+ To complete the initial setup, let's add a connector! +
+
+ Connectors are the way that Danswer gets data from your + organization's various data sources. Once set up, + we'll automatically sync data from your apps and docs + into Danswer, so you can search through all of them in one + place. +
+
+ + + +
+
+
+ )} + + {isCancelling && ( + setIsCancelling(false)} + title="Cancel Embedding Model Switch" + > +
+
+ Are you sure you want to cancel? +
+
+ Cancelling will revert to the previous model and all progress + will be lost. +
+
+ +
+
+
+ )} + + )} + + {futureEmbeddingModel && connectors && connectors.length > 0 && ( +
+ Current Upgrade Status +
+
+ Currently in the process of switching to:{" "} + {futureEmbeddingModel.model_name} +
+ {/* */} + + + + + The table below shows the re-indexing progress of all existing + connectors. Once all connectors have been re-indexed successfully, + the new model will be used for all search queries. Until then, we + will use the old model so that no downtime is necessary during + this transition. + + + {isLoadingOngoingReIndexingStatus ? ( + + ) : ongoingReIndexingStatus ? ( + + ) : ( + + )} +
+
+ )}
); } diff --git a/web/src/app/admin/models/llm/constants.ts b/web/src/app/admin/models/llm/constants.ts index 2db434ee9..a265f4a2b 100644 --- a/web/src/app/admin/models/llm/constants.ts +++ b/web/src/app/admin/models/llm/constants.ts @@ -1 +1,4 @@ export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider"; + +export const EMBEDDING_PROVIDERS_ADMIN_URL = + "/api/admin/embedding/embedding-provider"; diff --git a/web/src/app/search/page.tsx b/web/src/app/search/page.tsx index 891014b07..364e00ed1 100644 --- a/web/src/app/search/page.tsx +++ b/web/src/app/search/page.tsx @@ -20,7 +20,7 @@ import { import { unstable_noStore as noStore } from "next/cache"; import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh"; import { personaComparator } from "../admin/assistants/lib"; -import { FullEmbeddingModelResponse } from "../admin/models/embedding/embeddingModels"; +import { FullEmbeddingModelResponse } from "../admin/models/embedding/components/types"; import { NoSourcesModal } from "@/components/initialSetup/search/NoSourcesModal"; import { NoCompleteSourcesModal } from "@/components/initialSetup/search/NoCompleteSourceModal"; import { ChatPopup } from "../chat/ChatPopup"; diff --git a/web/src/components/Modal.tsx b/web/src/components/Modal.tsx index 9f9893afd..a9e1f5cbd 100644 --- a/web/src/components/Modal.tsx +++ b/web/src/components/Modal.tsx @@ -1,7 +1,9 @@ import { Divider } from "@tremor/react"; import { FiX } from "react-icons/fi"; +import { IconProps } from "./icons/icons"; interface ModalProps { + icon?: ({ size, className }: IconProps) => JSX.Element; children: JSX.Element | string; title?: JSX.Element | string; onOutsideClick?: () => void; @@ -13,6 +15,7 @@ interface ModalProps { } export function Modal({ + icon, children, title, onOutsideClick, @@ -44,10 +47,15 @@ export function Modal({ <>

{title} + {icon && icon({ size: 30 })}

+ {onOutsideClick && (
{ - return ( -
- Logo -
- ); -}; - export const OpenSourceIcon = ({ size = 16, className = defaultTailwindCSS, @@ -528,6 +519,62 @@ export const ZulipIcon = ({ ); }; +export const OpenAIIcon = ({ + size = 16, + className = defaultTailwindCSS, +}: IconProps) => { + return ( +
+ Logo +
+ ); +}; + +export const VoyageIcon = ({ + size = 16, + className = defaultTailwindCSS, +}: IconProps) => { + return ( +
+ Logo +
+ ); +}; + +export const GoogleIcon = ({ + size = 16, + className = defaultTailwindCSS, +}: IconProps) => { + return ( +
+ Logo +
+ ); +}; + +export const CohereIcon = ({ + size = 16, + className = defaultTailwindCSS, +}: IconProps) => { + return ( +
+ Logo +
+ ); +}; + export const GoogleStorageIcon = ({ size = 16, className = defaultTailwindCSS, diff --git a/web/src/lib/chat/fetchChatData.ts b/web/src/lib/chat/fetchChatData.ts index 8c1c3fdc9..de78dec8e 100644 --- a/web/src/lib/chat/fetchChatData.ts +++ b/web/src/lib/chat/fetchChatData.ts @@ -13,7 +13,7 @@ import { } from "@/lib/types"; import { ChatSession } from "@/app/chat/interfaces"; import { Persona } from "@/app/admin/assistants/interfaces"; -import { FullEmbeddingModelResponse } from "@/app/admin/models/embedding/embeddingModels"; +import { FullEmbeddingModelResponse } from "@/app/admin/models/embedding/components/types"; import { Settings } from "@/app/admin/settings/interfaces"; import { fetchLLMProvidersSS } from "@/lib/llm/fetchLLMs"; import { LLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";