From 299cb5035cdf4bd63ea3bbdc7678b32cc9849a0a Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Mon, 2 Sep 2024 09:08:35 -0700 Subject: [PATCH] Add litellm proxy embeddings (#2291) * add litellm proxy * formatting * move `api_url` to cloud provider + nits * remove log * typing * quick tuyping fix * update LiteLLM selection logic * remove logs + validate functionality * rename proxy var * update path casing * remove pricing for custom models * functional values --- ..._add_base_url_to_cloudembeddingprovider.py | 26 +++ backend/danswer/db/llm.py | 9 + backend/danswer/db/models.py | 5 + backend/danswer/db/search_settings.py | 9 + backend/danswer/indexing/embedder.py | 16 +- backend/danswer/indexing/models.py | 2 + .../search_nlp_models.py | 4 + backend/danswer/search/models.py | 1 + .../danswer/server/manage/embedding/api.py | 13 ++ .../danswer/server/manage/embedding/models.py | 4 + .../danswer/server/manage/search_settings.py | 6 +- backend/model_server/encoders.py | 47 +++++- backend/shared_configs/configs.py | 1 + backend/shared_configs/enums.py | 1 + backend/shared_configs/model_server_models.py | 1 + .../tests/daily/embedding/test_embeddings.py | 3 + web/public/LiteLLM.jpg | Bin 0 -> 12575 bytes .../app/admin/configuration/llm/constants.ts | 2 + .../EmbeddingModelSelectionForm.tsx | 22 ++- web/src/app/admin/embeddings/interfaces.ts | 2 + .../modals/ChangeCredentialsModal.tsx | 38 +++-- .../modals/ProviderCreationModal.tsx | 24 ++- .../embeddings/pages/CloudEmbeddingPage.tsx | 159 +++++++++++++++++- .../embeddings/pages/EmbeddingFormPage.tsx | 2 + .../components/embedding/CustomModelForm.tsx | 18 +- .../components/embedding/LiteLLMModelForm.tsx | 116 +++++++++++++ web/src/components/embedding/interfaces.tsx | 28 ++- web/src/components/icons/icons.tsx | 15 ++ 28 files changed, 524 insertions(+), 50 deletions(-) create mode 100644 backend/alembic/versions/bceb1e139447_add_base_url_to_cloudembeddingprovider.py create mode 100644 web/public/LiteLLM.jpg create mode 100644 web/src/components/embedding/LiteLLMModelForm.tsx diff --git a/backend/alembic/versions/bceb1e139447_add_base_url_to_cloudembeddingprovider.py b/backend/alembic/versions/bceb1e139447_add_base_url_to_cloudembeddingprovider.py new file mode 100644 index 0000000000..3fc01931fe --- /dev/null +++ b/backend/alembic/versions/bceb1e139447_add_base_url_to_cloudembeddingprovider.py @@ -0,0 +1,26 @@ +"""Add base_url to CloudEmbeddingProvider + +Revision ID: bceb1e139447 +Revises: 1f60f60c3401 +Create Date: 2024-08-28 17:00:52.554580 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "bceb1e139447" +down_revision = "1f60f60c3401" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "embedding_provider", sa.Column("api_url", sa.String(), nullable=True) + ) + + +def downgrade() -> None: + op.drop_column("embedding_provider", "api_url") diff --git a/backend/danswer/db/llm.py b/backend/danswer/db/llm.py index 152cb13057..18ad22e50b 100644 --- a/backend/danswer/db/llm.py +++ b/backend/danswer/db/llm.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Session from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel from danswer.db.models import LLMProvider as LLMProviderModel from danswer.db.models import LLMProvider__UserGroup +from danswer.db.models import SearchSettings from danswer.db.models import User from danswer.db.models import User__UserGroup from danswer.server.manage.embedding.models import CloudEmbeddingProvider @@ -50,6 +51,7 @@ def upsert_cloud_embedding_provider( setattr(existing_provider, key, value) else: new_provider = CloudEmbeddingProviderModel(**provider.model_dump()) + db_session.add(new_provider) existing_provider = new_provider db_session.commit() @@ -157,12 +159,19 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | def remove_embedding_provider( db_session: Session, provider_type: EmbeddingProvider ) -> None: + db_session.execute( + delete(SearchSettings).where(SearchSettings.provider_type == provider_type) + ) + + # Delete the embedding provider db_session.execute( delete(CloudEmbeddingProviderModel).where( CloudEmbeddingProviderModel.provider_type == provider_type ) ) + db_session.commit() + def remove_llm_provider(db_session: Session, provider_id: int) -> None: # Remove LLMProvider's dependent relationships diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 3cdec32396..6d2b92b197 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -607,6 +607,10 @@ class SearchSettings(Base): return f"" + @property + def api_url(self) -> str | None: + return self.cloud_provider.api_url if self.cloud_provider is not None else None + @property def api_key(self) -> str | None: return self.cloud_provider.api_key if self.cloud_provider is not None else None @@ -1085,6 +1089,7 @@ class CloudEmbeddingProvider(Base): provider_type: Mapped[EmbeddingProvider] = mapped_column( Enum(EmbeddingProvider), primary_key=True ) + api_url: Mapped[str | None] = mapped_column(String, nullable=True) api_key: Mapped[str | None] = mapped_column(EncryptedString()) search_settings: Mapped[list["SearchSettings"]] = relationship( "SearchSettings", diff --git a/backend/danswer/db/search_settings.py b/backend/danswer/db/search_settings.py index 1d0c218e10..0cb5029533 100644 --- a/backend/danswer/db/search_settings.py +++ b/backend/danswer/db/search_settings.py @@ -115,6 +115,13 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None: return latest_settings +def get_all_search_settings(db_session: Session) -> list[SearchSettings]: + query = select(SearchSettings).order_by(SearchSettings.id.desc()) + result = db_session.execute(query) + all_settings = result.scalars().all() + return list(all_settings) + + def get_multilingual_expansion(db_session: Session | None = None) -> list[str]: if db_session is None: with Session(get_sqlalchemy_engine()) as db_session: @@ -234,6 +241,7 @@ def get_old_default_embedding_model() -> IndexingSetting: passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""), index_name="danswer_chunk", multipass_indexing=False, + api_url=None, ) @@ -246,4 +254,5 @@ def get_new_default_embedding_model() -> IndexingSetting: passage_prefix=ASYM_PASSAGE_PREFIX, index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}", multipass_indexing=False, + api_url=None, ) diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py index f7d8f4e740..d25a0659c6 100644 --- a/backend/danswer/indexing/embedder.py +++ b/backend/danswer/indexing/embedder.py @@ -32,6 +32,7 @@ class IndexingEmbedder(ABC): passage_prefix: str | None, provider_type: EmbeddingProvider | None, api_key: str | None, + api_url: str | None, ): self.model_name = model_name self.normalize = normalize @@ -39,6 +40,7 @@ class IndexingEmbedder(ABC): self.passage_prefix = passage_prefix self.provider_type = provider_type self.api_key = api_key + self.api_url = api_url self.embedding_model = EmbeddingModel( model_name=model_name, @@ -47,6 +49,7 @@ class IndexingEmbedder(ABC): normalize=normalize, api_key=api_key, provider_type=provider_type, + api_url=api_url, # The below are globally set, this flow always uses the indexing one server_host=INDEXING_MODEL_SERVER_HOST, server_port=INDEXING_MODEL_SERVER_PORT, @@ -70,9 +73,16 @@ class DefaultIndexingEmbedder(IndexingEmbedder): passage_prefix: str | None, provider_type: EmbeddingProvider | None = None, api_key: str | None = None, + api_url: str | None = None, ): super().__init__( - model_name, normalize, query_prefix, passage_prefix, provider_type, api_key + model_name, + normalize, + query_prefix, + passage_prefix, + provider_type, + api_key, + api_url, ) @log_function_time() @@ -156,7 +166,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder): title_embed_dict[title] = title_embedding new_embedded_chunk = IndexChunk( - **chunk.model_dump(), + **chunk.dict(), embeddings=ChunkEmbedding( full_embedding=chunk_embeddings[0], mini_chunk_embeddings=chunk_embeddings[1:], @@ -179,6 +189,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder): passage_prefix=search_settings.passage_prefix, provider_type=search_settings.provider_type, api_key=search_settings.api_key, + api_url=search_settings.api_url, ) @@ -202,4 +213,5 @@ def get_embedding_model_from_search_settings( passage_prefix=search_settings.passage_prefix, provider_type=search_settings.provider_type, api_key=search_settings.api_key, + api_url=search_settings.api_url, ) diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py index b23de0eb47..c468b9fb18 100644 --- a/backend/danswer/indexing/models.py +++ b/backend/danswer/indexing/models.py @@ -99,6 +99,7 @@ class EmbeddingModelDetail(BaseModel): normalize: bool query_prefix: str | None passage_prefix: str | None + api_url: str | None = None provider_type: EmbeddingProvider | None = None api_key: str | None = None @@ -117,6 +118,7 @@ class EmbeddingModelDetail(BaseModel): passage_prefix=search_settings.passage_prefix, provider_type=search_settings.provider_type, api_key=search_settings.api_key, + api_url=search_settings.api_url, ) diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py index b7835c4e90..d2ab3a582b 100644 --- a/backend/danswer/natural_language_processing/search_nlp_models.py +++ b/backend/danswer/natural_language_processing/search_nlp_models.py @@ -90,6 +90,7 @@ class EmbeddingModel: query_prefix: str | None, passage_prefix: str | None, api_key: str | None, + api_url: str | None, provider_type: EmbeddingProvider | None, retrim_content: bool = False, ) -> None: @@ -100,6 +101,7 @@ class EmbeddingModel: self.normalize = normalize self.model_name = model_name self.retrim_content = retrim_content + self.api_url = api_url self.tokenizer = get_tokenizer( model_name=model_name, provider_type=provider_type ) @@ -157,6 +159,7 @@ class EmbeddingModel: text_type=text_type, manual_query_prefix=self.query_prefix, manual_passage_prefix=self.passage_prefix, + api_url=self.api_url, ) response = self._make_model_server_request(embed_request) @@ -226,6 +229,7 @@ class EmbeddingModel: passage_prefix=search_settings.passage_prefix, api_key=search_settings.api_key, provider_type=search_settings.provider_type, + api_url=search_settings.api_url, retrim_content=retrim_content, ) diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index 15387e6c63..e9201c9705 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -81,6 +81,7 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting): num_rerank=search_settings.num_rerank, # Multilingual Expansion multilingual_expansion=search_settings.multilingual_expansion, + api_url=search_settings.api_url, ) diff --git a/backend/danswer/server/manage/embedding/api.py b/backend/danswer/server/manage/embedding/api.py index 90fa69401c..2cee962ee6 100644 --- a/backend/danswer/server/manage/embedding/api.py +++ b/backend/danswer/server/manage/embedding/api.py @@ -9,7 +9,9 @@ from danswer.db.llm import fetch_existing_embedding_providers from danswer.db.llm import remove_embedding_provider from danswer.db.llm import upsert_cloud_embedding_provider from danswer.db.models import User +from danswer.db.search_settings import get_all_search_settings from danswer.db.search_settings import get_current_db_embedding_provider +from danswer.indexing.models import EmbeddingModelDetail from danswer.natural_language_processing.search_nlp_models import EmbeddingModel from danswer.server.manage.embedding.models import CloudEmbeddingProvider from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest @@ -20,6 +22,7 @@ from shared_configs.configs import MODEL_SERVER_PORT from shared_configs.enums import EmbeddingProvider from shared_configs.enums import EmbedTextType + logger = setup_logger() @@ -37,6 +40,7 @@ def test_embedding_configuration( server_host=MODEL_SERVER_HOST, server_port=MODEL_SERVER_PORT, api_key=test_llm_request.api_key, + api_url=test_llm_request.api_url, provider_type=test_llm_request.provider_type, normalize=False, query_prefix=None, @@ -56,6 +60,15 @@ def test_embedding_configuration( raise HTTPException(status_code=400, detail=error_msg) +@admin_router.get("", response_model=list[EmbeddingModelDetail]) +def list_embedding_models( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[EmbeddingModelDetail]: + search_settings = get_all_search_settings(db_session) + return [EmbeddingModelDetail.from_db_model(setting) for setting in search_settings] + + @admin_router.get("/embedding-provider") def list_embedding_providers( _: User | None = Depends(current_admin_user), diff --git a/backend/danswer/server/manage/embedding/models.py b/backend/danswer/server/manage/embedding/models.py index 132d311413..50518e6ec0 100644 --- a/backend/danswer/server/manage/embedding/models.py +++ b/backend/danswer/server/manage/embedding/models.py @@ -11,11 +11,13 @@ if TYPE_CHECKING: class TestEmbeddingRequest(BaseModel): provider_type: EmbeddingProvider api_key: str | None = None + api_url: str | None = None class CloudEmbeddingProvider(BaseModel): provider_type: EmbeddingProvider api_key: str | None = None + api_url: str | None = None @classmethod def from_request( @@ -24,9 +26,11 @@ class CloudEmbeddingProvider(BaseModel): return cls( provider_type=cloud_provider_model.provider_type, api_key=cloud_provider_model.api_key, + api_url=cloud_provider_model.api_url, ) class CloudEmbeddingProviderCreationRequest(BaseModel): provider_type: EmbeddingProvider api_key: str | None = None + api_url: str | None = None diff --git a/backend/danswer/server/manage/search_settings.py b/backend/danswer/server/manage/search_settings.py index db483eff5d..831528b815 100644 --- a/backend/danswer/server/manage/search_settings.py +++ b/backend/danswer/server/manage/search_settings.py @@ -45,7 +45,7 @@ def set_new_search_settings( if search_settings_new.index_name: logger.warning("Index name was specified by request, this is not suggested") - # Validate cloud provider exists + # Validate cloud provider exists or create new LiteLLM provider if search_settings_new.provider_type is not None: cloud_provider = get_embedding_provider_from_provider_type( db_session, provider_type=search_settings_new.provider_type @@ -133,7 +133,7 @@ def cancel_new_embedding( @router.get("/get-current-search-settings") -def get_curr_search_settings( +def get_current_search_settings_endpoint( _: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> SavedSearchSettings: @@ -142,7 +142,7 @@ def get_curr_search_settings( @router.get("/get-secondary-search-settings") -def get_sec_search_settings( +def get_secondary_search_settings_endpoint( _: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> SavedSearchSettings | None: diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index 4e97bd00f2..ad9d8582be 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -2,6 +2,7 @@ import json from typing import Any from typing import Optional +import httpx import openai import vertexai # type: ignore import voyageai # type: ignore @@ -235,6 +236,22 @@ def get_local_reranking_model( return _RERANK_MODEL +def embed_with_litellm_proxy( + texts: list[str], api_url: str, model: str +) -> list[Embedding]: + with httpx.Client() as client: + response = client.post( + api_url, + json={ + "model": model, + "input": texts, + }, + ) + response.raise_for_status() + result = response.json() + return [embedding["embedding"] for embedding in result["data"]] + + @simple_log_function_time() def embed_text( texts: list[str], @@ -245,21 +262,37 @@ def embed_text( api_key: str | None, provider_type: EmbeddingProvider | None, prefix: str | None, + api_url: str | None, ) -> list[Embedding]: + logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}") + if not all(texts): + logger.error("Empty strings provided for embedding") raise ValueError("Empty strings are not allowed for embedding.") - # Third party API based embedding model if not texts: + logger.error("No texts provided for embedding") raise ValueError("No texts provided for embedding.") + + if provider_type == EmbeddingProvider.LITELLM: + logger.debug(f"Using LiteLLM proxy for embedding with URL: {api_url}") + if not api_url: + logger.error("API URL not provided for LiteLLM proxy") + raise ValueError("API URL is required for LiteLLM proxy embedding.") + try: + return embed_with_litellm_proxy(texts, api_url, model_name or "") + except Exception as e: + logger.exception(f"Error during LiteLLM proxy embedding: {str(e)}") + raise + elif provider_type is not None: - logger.debug(f"Embedding text with provider: {provider_type}") + logger.debug(f"Using cloud provider {provider_type} for embedding") if api_key is None: + logger.error("API key not provided for cloud model") raise RuntimeError("API key not provided for cloud model") if prefix: - # This may change in the future if some providers require the user - # to manually append a prefix but this is not the case currently + logger.warning("Prefix provided for cloud model, which is not supported") raise ValueError( "Prefix string is not valid for cloud models. " "Cloud models take an explicit text type instead." @@ -274,14 +307,15 @@ def embed_text( text_type=text_type, ) - # Check for None values in embeddings if any(embedding is None for embedding in embeddings): error_message = "Embeddings contain None values\n" error_message += "Corresponding texts:\n" error_message += "\n".join(texts) + logger.error(error_message) raise ValueError(error_message) elif model_name is not None: + logger.debug(f"Using local model {model_name} for embedding") prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts local_model = get_embedding_model( @@ -296,10 +330,12 @@ def embed_text( ] else: + logger.error("Neither model name nor provider specified for embedding") raise ValueError( "Either model name or provider must be provided to run embeddings." ) + logger.info(f"Successfully embedded {len(texts)} texts") return embeddings @@ -344,6 +380,7 @@ async def process_embed_request( api_key=embed_request.api_key, provider_type=embed_request.provider_type, text_type=embed_request.text_type, + api_url=embed_request.api_url, prefix=prefix, ) return EmbedResponse(embeddings=embeddings) diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index 5ad36cc93c..2357d96d95 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -61,6 +61,7 @@ PRESERVED_SEARCH_FIELDS = [ "provider_type", "api_key", "model_name", + "api_url", "index_name", "multipass_indexing", "model_dim", diff --git a/backend/shared_configs/enums.py b/backend/shared_configs/enums.py index 918872d44b..4dccd43e0a 100644 --- a/backend/shared_configs/enums.py +++ b/backend/shared_configs/enums.py @@ -6,6 +6,7 @@ class EmbeddingProvider(str, Enum): COHERE = "cohere" VOYAGE = "voyage" GOOGLE = "google" + LITELLM = "litellm" class RerankerProvider(str, Enum): diff --git a/backend/shared_configs/model_server_models.py b/backend/shared_configs/model_server_models.py index 3014616c62..4be72308e7 100644 --- a/backend/shared_configs/model_server_models.py +++ b/backend/shared_configs/model_server_models.py @@ -18,6 +18,7 @@ class EmbedRequest(BaseModel): text_type: EmbedTextType manual_query_prefix: str | None = None manual_passage_prefix: str | None = None + api_url: str | None = None # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} diff --git a/backend/tests/daily/embedding/test_embeddings.py b/backend/tests/daily/embedding/test_embeddings.py index a9c12b236c..b736f37474 100644 --- a/backend/tests/daily/embedding/test_embeddings.py +++ b/backend/tests/daily/embedding/test_embeddings.py @@ -32,6 +32,7 @@ def openai_embedding_model() -> EmbeddingModel: passage_prefix=None, api_key=os.getenv("OPENAI_API_KEY"), provider_type=EmbeddingProvider.OPENAI, + api_url=None, ) @@ -51,6 +52,7 @@ def cohere_embedding_model() -> EmbeddingModel: passage_prefix=None, api_key=os.getenv("COHERE_API_KEY"), provider_type=EmbeddingProvider.COHERE, + api_url=None, ) @@ -70,6 +72,7 @@ def local_nomic_embedding_model() -> EmbeddingModel: passage_prefix="search_document: ", api_key=None, provider_type=None, + api_url=None, ) diff --git a/web/public/LiteLLM.jpg b/web/public/LiteLLM.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d6a77b2d105b212542a43116f4379641d0739c82 GIT binary patch literal 12575 zcmbul1ymi)(k?u>%ig%VySuwP32wpN2^vUnCxJk44<6jzJ-7vTClClh!rdh2yyyJi z`tQ2yu3Iy^r{?LZ+C9}ZUEMSDwD`0EU@FQg$O2GM000HqfTu0!90e&U6Ln2BSp{X8 zKNZjk4i=8?u;`Edg#Y@CoB!o^ zo_XuP{OvRErmZOj0MMuqp4|E`&-~1r|K*>rh1|l%%@HEw0O6?}9Xugz_(RWEf@brW z?(?3W9`gYJW)T3oKKw`D`2)nAXAnH$KNxi;0ARcYfQF9$VCIDY&#Fqn`|yr=)8ifXNlURpOO z1YQ=*B<^4?E?5vhj=qK>!Cf;W@my(YZ(CXvD7iLqm})7zB>Cshpxy}zaT(3N*w!U8 zR;t~Rf3DyXgo(S;doV1&7cB3}Om~v@`j>d9uzW!%nMq~_At&sczu3Ub@G`;$k?WOa z^l-95VXe=9U=FQ0KY<5}pbxm3f5jjrk~N~q@^^1NzEx|~)*4!INQ8_bZNCu7aK40I ze)EqQ07SMmZQTh6#>Tm3+(VHA?6HP*rH>siZN{E&=zqxXBJ!B}a#bw*_zq_cKt6mZ zT(XE=Q~z@T5Fz+Aq1j|A(Iz6RyV8Te!u`0dJx#$#_^oVJVKkpNgAx)AWv%e-SwUZr z@MTQISm|=NL+jrBHBNjp#&~QVZOi_QI8@Ic; zSRKz!2ctI6>SEZS(uee8p4~+ONDrg1#-|;A@on36Om^xNW`=~5uHvsjegBccvL69V zdE1q3Xbrx{t}nEzVKB2JpzYw1Zi7R7J32p??NZ(TbQ`)N!o6zMZP&SGazOj093wYe zb^HVn+;hKLCV+%syByoB>GkWovQi^{$s+(jS4utP*IT!)1Uue|tdqHJ+Si8Xs~Vi@ z*g4dEhhk#$4D__B>$d{gN*$+m(3!UqG1cf&SeD{`R$f;*PHrrXt>qHU{Qd+~*>u<0 z^g%}39~n=8#Qxo!|4*Hf^#pE68liegx$Ew~ z;r*2Z8Ac5hLL)~iD(0Ljmlfl(cRi2V+>Kir^RFg;$;s~}iZG=~6}cNQ9_MQePtL^& zy9Wq5xUmWzyZ3Kd9JYH9{LW5H40Hjk!MC*q%aiFRr4^pnlf>8i)89j$^9EGRRO?Rj zV{^9DWL4m9A=5Wy>hDm(FHYC#;F6yJ?W)#Rl>H|_{JW(I?)Cgv(`)-Zy`MKzT(s+a zl%@clcDs{f4gQ3fJi2dH*{=8Nmw3B>Y`!lZ3$Z9bJcJe$k4xZ4BBuoq6H+0mfi_49 z**Wv%m8%Q>ehZ1n%us{<4}W(8NY#H4?5&7#fSW8EB!)}(?^WH^vW?=)&juqyYwez^|BZIP0=WzE?sP8jqv*tb`b0Tn>keDS>ew{yplxSH_r&v{uFlIYKGGH%f)mQE8zGK2x&WYa)4^oai`wkAp-BS z9E1ksub{hyFRT6JeK{_V^&bg5(T|=-hUc5P{N!j##?>-oMUSx!{73q^fSJsr=xrL* zxteMzOm-6$!ZUQO9Of8uJj%G9RA_c}APCHOIXo%B`9~oLk{MVq0Vr4~Xh>p% z0}2`j77W6I!=_*t$0VoZ;N((=$D!h;HicvhM2HL&9Q1cwS#ARkit#@8>WpJ_&k(~1 zB@X%}{PP)+Q=9u3iJ#OO(?e%7ul!^UlHrZaV-n*ZVH@-aXqt(clGmVF-{}_E;-s$c zsR&Wge))O(b>eo&p+=43Otvsuu|pd-qep=7$9KBS>VldUnal0Osh0Y8$-+#$V2+VW z6COO7GHW*Y)tDd|&B4H(Te^cowJNXKVIsrQ-A5XMMxy29g8?PVp(kLvdSp<7Qhs85 zxHjnUXtbV^wLLm%lPopXEfQ-yJRRAAhy|Cw|L)#K>3zUR3r!>+&7})bU5NZFXqq8r z#9_qkT)2^>07KqL3P$?t?%ZTF?az}D)hSmwUPFminfkM zV8S2^0vuWy z9!NkXBE37i7HAf>V|61tAsUV36W|_E61{xW*gPKetWSgtQK_lO-Rs^RcovKSNH1`G z(9D}sUfT2ev1N`+@)x>{fScmOyQcXqY>SJc%!}aZZI9kFrq`s`q9z{g2ax>qViKZ9 zDJu-TBwt*f)~grrNO0swJnESOVM=po;8UrfiA>zvmpa+uE#Xl+Ypgfrp3ZV7)vj zzPk70V;t2lKX1oDofPU%-?9f6=fB#+*JDi?Mp!Y8#i~4$!mj*Qz2GEoUdiX(L*zfY z{lDaH@3F-W_~Z9aT~Qp%urls3?A5jpGU7fvp-?3LLxYY-rz$Ek8l%&A{+yg3*{Tt? z!NvS1K%p;4%&M;!Y z=LkFLJMqlpAL+mknb|w6Mj#KIy<)h)4_9QY%u!03jYLI!Dy(q^jh{L71Iu=Vj z`Fmo%OJWJ0@LFM&Ppov3iqnJ{0(TOS3;D+rf4GFogK-d#ZJD$$dHm62V|B0AmcDY% zsr|y6Y-_6G@CH`EmU0cA^j?{|O3TppnYX!A}nnA5d?NFVDdrZXr7HQ?~JYhPC zC(bM9YUK(sfgjUje#4a{Av!7_?y8QeAli=e|3@{05UzYYS1*giKY=I8x0?cDoQe*$e)+VU2&(_fAm&hjbsKbA=_0uF8?{`s#-tc*H z#Etx%MW<>+Y5lhR_K9&g?eDC2BCTIG&ZlR?Js%r+h*KzEX5aBs>6s9LV$T%&5^MVo zGVk#7J>tV?`zC(GIK)4DyO1dWi1(JpsK>#!q z3=Aw3EcBl>7$nC*12C{4FgXV%76rQmHl?@*rzRDb4%e@lzd2My;Nz4-u%WW zDinW9t5a%676e|~7bWuhqVb+_jx`8*xg&6iJQGj;!Ss*!C> z(Jn(07RI)m&ez-8%c#M%A$SDI@CisbD^%L)-wx0njCeqiFgat)Vqba`sp}%$UBRnmX1gXIwWI7uFt(jg1Ybi6-u8sG!W( zi`XL$Zt2g=x;GrfKVtM2g-%eCO9a=&jQl{(i0@Wc1BYV1eNB*jI`WJ3hvZ;izBu`{ zZavzr1r?j33L=S?@X#mI){a+IrNH9T`#voWQIXfm6HU$SIr?`ZMnfaekf_@IAMN}G8 zp@iPhA-f$Ok5f6c+Za|ds6;&4o~asrj1MPbTjcRVrdXSD0LjF0Dse$Xqh$pvcYX}J zKayGH_Dft_*{<%3qPW&r^+CzsBn$)}4nou7jtBC zDOUAD;PnD)8h7II!=A?j${tZ&VqecueshmF-?80q2bbw@T=8r+5buo$jEFwCT>cng z7$LJNFDQ}wT)Js@yOBHQj$_*oT_L2rwu|@0*4_~rN;$w%zus=H2>MEqoXU}T7*cDT z_BeudFij%gu5_-+b9R$F`tcfG>3#r9_SC&|dhZD+c=fx|`P#j+6whQWt$h^T&z8Qf zBw!^~gpGGMoZDV6w2+ty`H--MwlcCTA0^-JJ^K7Rs!qf{gI!CBl-vk*o7cJBVX?Vi zxJO8Tda3t+@K&=jQ~26MpChW8A=UZWK`?jnJ8kU6SPK=l#ll1ek5tEBIpJ)`<;jACO=XWbi;{yU@@Gh|fdK-*E;K86XN6OfV%o2PZk#b6mJ4s$+5I z!cqleQ%fe9nuq*r)(bBNy-B-KA3FYOw{TSHBTtD!?`DzCVzWF@bQMaM%r9UJ%T$T5 z!4-K@xvIsnVeR9s*Y$IoD_iEP%QOcGo>ZT7GQ7Mir@BW^v>M8KUN!-4^4AM513vmJ zK?=P@+tY?&_Q8U%cQ6I+C=Ue>`kR>&+|#h zGTi3kXAI8L99KQ))1CIG0er7vVV>G@||kbjf^ zs6uFS`u0J4=G}N$ zfaJ2v`as>P5x!wk=PgFP4Z>@bfBt_0}*Ez=-Moh25aqz8DJUN2(JaI(4p`93KBp znzUG-wwQJ|b4IPEZq}zoL)+-6H{I?xD2n8lgDknWG^^N)=Sm9ty2&XMLGOr{5ipD8 zOg$OrOQj}SSwV=MKPsPqVf8z7@forr%=|nD_P1V+*Bi706)OQO|C`;#b%!_Nc4FUY zap*z&H7fLUv|67y`MM8pdg^U_>G?MLircMCCKVViXyIdz9`aCsjW$NN9x@Cmx={x- znmREmaa+ix)xp73%Z55#^7eh zMwWYGdzs~HGsUV`!Zijr(wigA+G0NrcgU^boUI($zka-AjdqDF0^fMgu3;x-RKAZwLQPavJf z5}>DmteYK}k0FutrXWkA1AL0wqVTg)jkZyZ%A##o zVZQm$CYmw+7q|#63C7f78EoGT^I*Sh6FzNC{n{V)JRj}4TB3a^g8?fg*n#dZ4n$Ns zn<$PCX-i$n9l-^bxaK~+FNs(K8zo6Z$YictPo`f{D}w})K7g~VhBIT6F68Sblh(O6 zs#USL6m~{MU#>sI%&_mG)*UGaE4XKC_Slv;)h6U0S5Lt2;gkfpgwF7?alG8(Lv4^8 zebi@xzBJ@nfg}4MNkA>zgBzo)Fldw2S1$$j6pjn)3QIGqIV5r%Jt5K1(lpI)BZ4ja z);_+JoC&9BRZejcB2ANHU}g(KT`@P7R6!)cZj>|{XDiA5Np+_uFkQ>)@|uJ!D%D-l zbI0W!=S2i&uh|QR%UACY+=af3f{9N#>oqy;5_Q;%?eXBr}c+?-~ zfOmfvvtqRD5g2lD`~E0DgmEB$f!+}MHn;d*G{l~Vl_1E?3i$Ys*O#B z>VuHrlRXhFyXE6UwlL#lhRPz~_P$qR%H%ODH?Lk+c%bPtPLASCfH+P{N4i{A&qpyY z!7PdxSyc_=*VH3fYLJ}ww7)Gi>9yJ$8JCa*A<3$#^R7XzkHw@Ext>$rGzY^LXw~)o zq`3Z)AK&-?_6U;OL((cRLKMxL^&<#Yb)Gus`i%u0^#n}45yr~QDkvC6!Z(*e7-c8i zkpJ?Ay`hgz>a6&7^9ev_aQM`ca4=j^U^or(u*ynhaIbvNRn0(1OPWL>(5d>0gi-a9 z`QAJjun3OUZ!j1tNJ@Mqvd>J@tM?WhztnI~8h^4F|06sXw}xC&F5&kA^>0^hp?X9) z*aGSz#WoM!2E&GSp!kw<-yd2%^{3&H|>!5tb>&WR$+|XR`{f=}Lqc2&9)f z#xFv99cw}igz4ncntFHXRIYK+6}JdYU=lH>vZ;1xXKlv5$sYCeky2wn6EU$ps9ZLe zFjf15FrKVlI$y{rt)ef9f(8WxlKg8s)Ktf6^hh)giB`I)`a9kyKz3*;oy*)or4px( z^qr5KFeTwcxiLjVVZlqSnVh7ErCAXqvL*~z-)T%^ichrfTMS^fw26-VGDtYgf zDSf;wf3Qe~()jPgg(n4gSfnv_N^7n@kQsR?TO*-biDycGerN7ZQzD{Zr?F=#KC_4U zDY-3_uM+EQRoz%dU2{}6sXSTo%X$}C2=z;CBSN%IUmP1Am^0e&g`8=kp~x&Ll-q9m zjxdLHx!bL;Z@s&lH`aobeb>KCs8?X*rj3IY_Q_b;rI2HElnc$B+@h9U4(zwhiO6kU zyzYonsY_n0DD6n4*6z&eq++oBK0Uw@V}akJK(9cvBwaIp&BQSz+jp=~-!+kc{Yo9p zIG~gzkxgGcpHuH9lxNfnbM$u$0k+U@X+L3P>dR4`d+W$(7)4Z&;BCLYW$9&(z)8fut-#U!-&8; zX)-}QqLrIRL$l}#L{YQQqR@z>L;@>0R6+?p{fzob;-TPr!PQu$D@t?Kn?+@0*(rA= zbL*5?U%&SD9aurqXh$KrN?VWSn=6T_HyNP4N};<8o2#{`o-l3sOKUN34{_|rEwg-3 zLQn>+!4Yft#>j;2TXQ@UI3tCokdT|t(iGHkxn9|j-1X-L;h#67m{*g+SCjO`6-g+F z3S!UvpM8W)$wdMdQv#OVZ00Z0ZLSPCecr9f-T+z-0Jjp!{fG%0nhSqppb?9Pp170 zu$XdP8-FbLuMYjn@{y0K!<0}(42H-6iyy=DDVYN~FS{gQiMOATYHZ)lxICPY%du`) zOxtJUlRvh$^4@4RJAcrH2wdbyz=991Fm7u(L+YSTu_5hmfk=AJ)y(K{oeS5z{s=Ke z%@2`;B^F;%RlR2zJG+O1#ca)iC}K)L2lnu1ni?tl!(u8@jyxl900%6#1uSMTET*i% z2nk%$Gx*QWq9$DH@t;!&PM8gGvd}*Q5C*05+3javhW`PHb^dh3gthp;Oi2b+{Xr&$ z_ZBA*pTXg4lYb!)T^eg-soCoFznuk{%8-K~%OI%dHvs>BkpKmNILIm3H83U2#9jYP zV4k13U~Y_(e`(49on`xoA$_z%vd^~Tv#!r$Sly7E{edV_DKtgaqttfE;N&HT8Et+# z!3ttv;wsT*brM6Ip}#lF6^}i}7>F9;6PQ{;y&t+JNR`-q^XkoAeDZf!n9tCYC^CaD z7D&NR9|sBtA;y)KVYW)M-Hu7SRYBb=5=#G2Q}dRweqY%FKF0|Bsn;k;Z4;$H6N9(*$A}3vt!QvN*}`Kqau)zf)Kfn$?+pl@6-wk?v$n{ zQYkZxS##HX4%P$P>X`+{x=DBMQ_ceDrrFko4w3V8vmTO2w-|$HkGqS%{^p=0m%w3zW;RFiQK_`3JX z0;$fgcp8TJ`ob|8oea{WWPf@T@&M$&dz1o`oSg$QJr)1cq0dh~FgK2dCccX(i~P{e zmcGoRCH4O(eeMnj@?-S)Db?|>e!fG^IgZ=l#Io_5`Gy zyzktW)FQiFO7;~{tVfu_n!=6pW34xTjEEJS57<_~K1hprC|pw_1@E_~U9B=A;`AhG zp0TrXpGNis_A-3X^u+Mv6#ZoB18*6<04tY*b>so7wby>`cznct9)>Q6^Xml%jRHd&rs zdA;w7u9G_yQ@Ed;)=%O3%!Tw4x$yH(%U0q>hD|+x9LhsiEAqZ^uW&ae9k*^6OTi_+ z3_dpf4SiO@qp&XAAXed4_j&}SoU#8W--<2 zhH=Z;%*T4(O+s8bWBndM81miS57`-N(k9g>>^?z%))vn$qGURM>pz zpNO0FeFADa_M7}OH*k@r4Mhnbn0dl;&f_1VaxYsljkcX5FC=7E4a9Z|s5t9S`l3M} z*)|&AnQyZ=F*!?2x4M!r2$cuom7AkT`LNEe`|ROp6~A_gp&5VJV}y%F;HXvF{CXQ; z%NR43*fNdH{|2Yc7y*f0<=D|rne##Uc+XIOtO*ytGrioOb>u-#?_)1cRkXJ~OcFjU znzA?%L8wl`6{}tfND>w831}LROg60Qx9)iYLZD%%s8}mkP2*mF>}#^!Wu=v_Ua*+- znTV)2Y<~jsB}L!1@Sh{d^Gd1FZYRS*-s{-z+M|i(i%X+@n3G_i%Zl^sAq?V)9UKS_ z*pdkiM*zo(6*h5o?1Q*j%YxrNT=Y~&CVo1*qL*XqU%x2+l}Z#qD%iEBg7s?&6-Iu$ zyk5ucL001&^$>PSBw|9XV5+yxDef4pTJ+RX@P0~jIl327Hh3^Ng-bnBhIlmmhHN+n z6R1z2DU(C!3xgG==oIW{5d~=R@ygsK$95^EH%jXV%u0W|o$7hP%9y6hQZ0S2uxi@k zD*-DStEy6FTD;=^&FrX468oFl{FKpBcfx`uTr2f$yks7}vl~9tNkaQlJw6IfDd+XX z7bk_(>p?GQJ=?KglM8*83;va~?@h=o9!SV>5AnU$4Pt*2iL=>!@$~LRJ4!EbWNqO1 z5yhzN^))9qJ@+Q1g3P9p%*MC+$9LsE45QpkpVaTn(xYv6PmoWzcvmi*d=j& zYBFp|^GSRk&@A*0H;d8SHNOv`m{$?nb54Y((6Aal7JT*R%X@42jiT{z2v_eYbRj2v z%-c%-m7-k@b!(fx>ia#&*9)mEMs`$0O1%n!xaq*X-1^2j%GXHZ#k3&MCilhunj!T5 zYmJX}dxI%h0o|W1GxecNvCo4kTZ&b-K1hLz=(H)Y*3M#b+fLz7&Jw%|OmTGpJRFuf z?f?^$-4d6nX(+0mP<`>UEeb}@KVRD z%g+gJqa4%g5QY}6>EkF0uj}-7!9;dxCcB;TuT+Ebe3g0PHHo$6iT7BtlbNaV>-Ja0 zsuKE6*wz5l$Ux|v6*q|>*4H0w5n$7s&{QUU?iAbRd7&9-IML3as_XPCNnXvVxU+J2#hzrElh5>&7XPXq5NVe1)i<}sotYudhcKvW_UT@?Kmisy9d!#zSE@8sH6OoP7CK632Y2L_loM};3qlazB z7rg{!oXmu2obk*xofs=`#_RGCH-ntMW8E_ zNyZB?{%fgwDGO%a{i(=JJaAol$G1rri=kM3IgQac{nNY2^@~5|H74Y0p8)@Y&v|nC zVJ@b9tJNG8Ez*xxdA-~$U28UHv1(qXE%LI(bxI~aDO8w*U*h$tZ1|h%t0q_+&o`O% z-?;TCconv0(y+A!(gYm@(jKpgVLiHk#}AnK&gRQdZoLcNU!fW3evQ$Ve^2^7p}M=c zjvOgP@@m=IQEwb7P-Q{nj_kF@9o@I8!4K#~h^&=1SOUt}j$TUo$=4bU2WbV`+|wi)~tyc9ialid8h)O$uhFH##~7Q>m-m(nZq+e7WIq9w+l);9viHBfE@J4(?n<^O zossaSiD~Wb%x^=BpxobD*gFWNiCU)!ffLDJ!kI5KslM`{7BOgK^KrmJcD;VfZR*QO z;u+JFWK~??tFV}n3;skf>o-xn7MSw9c~0%g`LtUyB#FYxd%EjZx~JW%3a1-r-!FU> zMxBdPR*}w{Z>wlg;hywqRx+ToLBh8Tf7~+Z*r694rAvu*1Gw5BuN2ux$Cak+A5oX6MNHRQcuNAKs~{ zQ2--3iZd?f#%Kl2U4pK@VrmTO&JSh68>l3wv8gN z6Y*H{q34udShQ6&P7}53)-`%LHqCN6ZdROVm(nqBR_Mbq;B;UA zYoQBOoaKdrOu-ii|M8+MG^3Z{vX*VN1$tuTp-aCk^vqpgspQj|Wx6HTRETs&By&Sw zGFTL$)rCR7Z5F}w;W_rP(jhsM1D~L1;D0Xy>KaB4u0!B(h~W`m<3N>d3rZ%Lh4V1UAc6_h;b2c6 zS{E8qQs%9Ayh3$?P?DJk){#(!up=Mg{d2UT7rAw^?xk)y>l8&J9JdJC)DHEQSA5$P z;o!4J?LNApdc1zC$r`#|E;AKY&szkUnko2(i|=WM{cvs_GGk)Z=J8gJ5uh!kRp=iJ zlAHJV`Z3OICxoi4Dt{Rs9tEkHNzY^|O1=4N^JkbTki$#6reTdkg}~f?LPA@0Fo@l$ z9B7H2NtSB6+oCiM-x^P?X*=#imQCz=`x9=KodRp3H40yN8#zosFlk_k>v-O_?zmu` zx_BG;D<3MN{@%V$cveDpQXP(pKK z@eLZXz?RB#X1KQJMxW4&vld^OQ?{n3iZOVO#MgK6WE_S-gP7@N}6AsT=dA8AFk zMy-k38h{p6xyWvnEEC0Z&b&X4pEmvFLxxZp9?=;Yy)=Lp!r38Mjf?0BHl*k4F%Usw zrST^y(9bG~pQbsBqMUK|e*)Zj(false); + const [showAddConnectorPopup, setShowAddConnectorPopup] = useState(false); + const { data: embeddingModelDetails } = useSWR( + EMBEDDING_MODELS_ADMIN_URL, + errorHandlingFetcher, + { refreshInterval: 5000 } // 5 seconds + ); + const { data: embeddingProviderDetails } = useSWR( EMBEDDING_PROVIDERS_ADMIN_URL, - errorHandlingFetcher + errorHandlingFetcher, + { refreshInterval: 5000 } // 5 seconds ); const { data: connectors } = useSWR[]>( @@ -175,6 +187,7 @@ export function EmbeddingModelSelection({ {showTentativeProvider && ( { setShowTentativeProvider(showUnconfiguredProvider); @@ -189,8 +202,10 @@ export function EmbeddingModelSelection({ }} /> )} + {changeCredentialsProvider && ( { clientsideRemoveProvider(changeCredentialsProvider); @@ -277,6 +292,7 @@ export function EmbeddingModelSelection({ {modelTab == "cloud" && ( void; onCancel: () => void; onDeleted: () => void; useFileUpload: boolean; + isProxy?: boolean; }) { - const [apiKey, setApiKey] = useState(""); + const [apiKeyOrUrl, setApiKeyOrUrl] = useState(""); const [testError, setTestError] = useState(""); const [fileName, setFileName] = useState(""); const fileInputRef = useRef(null); @@ -50,7 +52,7 @@ export function ChangeCredentialsModal({ let jsonContent; try { jsonContent = JSON.parse(fileContent); - setApiKey(JSON.stringify(jsonContent)); + setApiKeyOrUrl(JSON.stringify(jsonContent)); } catch (parseError) { throw new Error( "Failed to parse JSON file. Please ensure it's a valid JSON." @@ -62,7 +64,7 @@ export function ChangeCredentialsModal({ ? error.message : "An unknown error occurred while processing the file." ); - setApiKey(""); + setApiKeyOrUrl(""); clearFileInput(); } } @@ -74,7 +76,7 @@ export function ChangeCredentialsModal({ try { const response = await fetch( - `${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type}`, + `${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type.toLowerCase()}`, { method: "DELETE", } @@ -105,7 +107,10 @@ export function ChangeCredentialsModal({ headers: { "Content-Type": "application/json" }, body: JSON.stringify({ provider_type: provider.provider_type.toLowerCase().split(" ")[0], - api_key: apiKey, + [isProxy ? "api_url" : "api_key"]: apiKeyOrUrl, + [isProxy ? "api_key" : "api_url"]: isProxy + ? provider.api_key + : provider.api_url, }), }); @@ -119,7 +124,7 @@ export function ChangeCredentialsModal({ headers: { "Content-Type": "application/json" }, body: JSON.stringify({ provider_type: provider.provider_type.toLowerCase().split(" ")[0], - api_key: apiKey, + [isProxy ? "api_url" : "api_key"]: apiKeyOrUrl, is_default_provider: false, is_configured: true, }), @@ -128,7 +133,8 @@ export function ChangeCredentialsModal({ if (!updateResponse.ok) { const errorData = await updateResponse.json(); throw new Error( - errorData.detail || "Failed to update provider- check your API key" + errorData.detail || + `Failed to update provider- check your ${isProxy ? "API URL" : "API key"}` ); } @@ -144,12 +150,12 @@ export function ChangeCredentialsModal({ - You can also delete your key. + You can also delete your {isProxy ? "URL" : "key"}. This is only possible if you have already switched to a different @@ -219,7 +225,7 @@ export function ChangeCredentialsModal({ {deletionError && ( diff --git a/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx b/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx index ab9fa663ee..b4aa909aea 100644 --- a/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx +++ b/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx @@ -13,11 +13,13 @@ export function ProviderCreationModal({ onConfirm, onCancel, existingProvider, + isProxy, }: { selectedProvider: CloudEmbeddingProvider; onConfirm: () => void; onCancel: () => void; existingProvider?: CloudEmbeddingProvider; + isProxy?: boolean; }) { const useFileUpload = selectedProvider.provider_type == "Google"; @@ -29,6 +31,7 @@ export function ProviderCreationModal({ provider_type: existingProvider?.provider_type || selectedProvider.provider_type, api_key: existingProvider?.api_key || "", + api_url: existingProvider?.api_url || "", custom_config: existingProvider?.custom_config ? Object.entries(existingProvider.custom_config) : [], @@ -37,9 +40,14 @@ export function ProviderCreationModal({ const validationSchema = Yup.object({ provider_type: Yup.string().required("Provider type is required"), - api_key: useFileUpload + api_key: isProxy ? Yup.string() - : Yup.string().required("API Key is required"), + : useFileUpload + ? Yup.string() + : Yup.string().required("API Key is required"), + api_url: isProxy + ? Yup.string().required("API URL is required") + : Yup.string(), custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)), }); @@ -87,6 +95,7 @@ export function ProviderCreationModal({ body: JSON.stringify({ provider_type: values.provider_type.toLowerCase().split(" ")[0], api_key: values.api_key, + api_url: values.api_url, }), } ); @@ -169,12 +178,19 @@ export function ProviderCreationModal({ target="_blank" href={selectedProvider.apiLink} > - API KEY + {isProxy ? "API URL" : "API KEY"}
- {useFileUpload ? ( + {isProxy ? ( + + ) : useFileUpload ? ( <> >; newUnenabledProviders: string[]; + embeddingModelDetails?: CloudEmbeddingModel[]; embeddingProviderDetails?: EmbeddingDetails[]; newEnabledProviders: string[]; setShowTentativeProvider: React.Dispatch< @@ -61,6 +66,17 @@ export default function CloudEmbeddingPage({ ))!), }) ); + const [liteLLMProvider, setLiteLLMProvider] = useState< + EmbeddingDetails | undefined + >(undefined); + + useEffect(() => { + const foundProvider = embeddingProviderDetails?.find( + (provider) => + provider.provider_type === EmbeddingProvider.LITELLM.toLowerCase() + ); + setLiteLLMProvider(foundProvider); + }, [embeddingProviderDetails]); return (
@@ -122,6 +138,127 @@ export default function CloudEmbeddingPage({
))} + + + Alternatively, you can use a self-hosted model using the LiteLLM + proxy. This allows you to leverage various LLM providers through a + unified interface that you control.{" "} + + Learn more about LiteLLM + + + +
+
+ {LITELLM_CLOUD_PROVIDER.icon({ size: 40 })} +

+ {LITELLM_CLOUD_PROVIDER.provider_type}{" "} + {LITELLM_CLOUD_PROVIDER.provider_type == "Cohere" && + "(recommended)"} +

+ + } + popupContent={ +
+
+ {LITELLM_CLOUD_PROVIDER.description} +
+
+ } + style="dark" + /> +
+
+ {!liteLLMProvider ? ( + + ) : ( + + )} + + {!liteLLMProvider && ( + +
+ + API URL Required + + + Before you can add models, you need to provide an API URL + for your LiteLLM proxy. Click the "Provide API + URL" button above to set up your LiteLLM configuration. + +
+ + + Once configured, you'll be able to add and manage + your LiteLLM models here. + +
+
+
+ )} + {liteLLMProvider && ( + <> +
+ {embeddingModelDetails + ?.filter( + (model) => + model.provider_type === + EmbeddingProvider.LITELLM.toLowerCase() + ) + .map((model) => ( + + ))} +
+ + + + + + )} +
+
); @@ -146,7 +283,9 @@ export function CloudModelCard({ React.SetStateAction >; }) { - const enabled = model.model_name === currentModel.model_name; + const enabled = + model.model_name === currentModel.model_name && + model.provider_type == currentModel.provider_type; return (

{model.description}

-
- ${model.pricePerMillion}/M tokens -
+ {model?.provider_type?.toLowerCase() != + EmbeddingProvider.LITELLM.toLowerCase() && ( +
+ ${model.pricePerMillion}/M tokens +
+ )}
-
+ )} diff --git a/web/src/components/embedding/LiteLLMModelForm.tsx b/web/src/components/embedding/LiteLLMModelForm.tsx new file mode 100644 index 0000000000..b84db4f906 --- /dev/null +++ b/web/src/components/embedding/LiteLLMModelForm.tsx @@ -0,0 +1,116 @@ +import { CloudEmbeddingModel, CloudEmbeddingProvider } from "./interfaces"; +import { Formik, Form } from "formik"; +import * as Yup from "yup"; +import { TextFormField, BooleanFormField } from "../admin/connectors/Field"; +import { Dispatch, SetStateAction } from "react"; +import { Button, Text } from "@tremor/react"; +import { EmbeddingDetails } from "@/app/admin/embeddings/EmbeddingModelSelectionForm"; + +export function LiteLLMModelForm({ + setShowTentativeModel, + currentValues, + provider, +}: { + setShowTentativeModel: Dispatch>; + currentValues: CloudEmbeddingModel | null; + provider: EmbeddingDetails; +}) { + return ( +
+ { + setShowTentativeModel(values as CloudEmbeddingModel); + }} + > + {({ isSubmitting }) => ( +
+ + Add a new model to LiteLLM proxy at {provider.api_url} + + + + + + + + + + + + + + )} +
+
+ ); +} diff --git a/web/src/components/embedding/interfaces.tsx b/web/src/components/embedding/interfaces.tsx index c719b7dc7b..0fafaa840c 100644 --- a/web/src/components/embedding/interfaces.tsx +++ b/web/src/components/embedding/interfaces.tsx @@ -2,6 +2,7 @@ import { CohereIcon, GoogleIcon, IconProps, + LiteLLMIcon, MicrosoftIcon, NomicIcon, OpenAIIcon, @@ -14,11 +15,13 @@ export enum EmbeddingProvider { COHERE = "Cohere", VOYAGE = "Voyage", GOOGLE = "Google", + LITELLM = "LiteLLM", } export interface CloudEmbeddingProvider { provider_type: EmbeddingProvider; api_key?: string; + api_url?: string; custom_config?: Record; docsLink?: string; @@ -44,6 +47,7 @@ export interface EmbeddingModelDescriptor { provider_type: string | null; description: string; api_key: string | null; + api_url: string | null; index_name: string | null; } @@ -70,7 +74,7 @@ export interface FullEmbeddingModelResponse { } export interface CloudEmbeddingProviderFull extends CloudEmbeddingProvider { - configured: boolean; + configured?: boolean; } export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ @@ -87,6 +91,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ index_name: "", provider_type: null, api_key: null, + api_url: null, }, { model_name: "intfloat/e5-base-v2", @@ -99,6 +104,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ passage_prefix: "passage: ", index_name: "", provider_type: null, + api_url: null, api_key: null, }, { @@ -113,6 +119,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ index_name: "", provider_type: null, api_key: null, + api_url: null, }, { model_name: "intfloat/multilingual-e5-base", @@ -126,6 +133,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ index_name: "", provider_type: null, api_key: null, + api_url: null, }, { model_name: "intfloat/multilingual-e5-small", @@ -139,9 +147,19 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ index_name: "", provider_type: null, api_key: null, + api_url: null, }, ]; +export const LITELLM_CLOUD_PROVIDER: CloudEmbeddingProvider = { + provider_type: EmbeddingProvider.LITELLM, + website: "https://github.com/BerriAI/litellm", + icon: LiteLLMIcon, + description: "Open-source library to call LLM APIs using OpenAI format", + apiLink: "https://docs.litellm.ai/docs/proxy/quick_start", + embedding_models: [], // No default embedding models +}; + export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ { provider_type: EmbeddingProvider.COHERE, @@ -169,6 +187,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, { model_name: "embed-english-light-v3.0", @@ -185,6 +204,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, ], }, @@ -213,6 +233,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ enabled: false, index_name: "", api_key: null, + api_url: null, }, { provider_type: EmbeddingProvider.OPENAI, @@ -229,6 +250,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ maxContext: 8191, index_name: "", api_key: null, + api_url: null, }, ], }, @@ -258,6 +280,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, { provider_type: EmbeddingProvider.GOOGLE, @@ -273,6 +296,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, ], }, @@ -301,6 +325,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, { provider_type: EmbeddingProvider.VOYAGE, @@ -317,6 +342,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, ], }, diff --git a/web/src/components/icons/icons.tsx b/web/src/components/icons/icons.tsx index 377307caef..b5e735b0e6 100644 --- a/web/src/components/icons/icons.tsx +++ b/web/src/components/icons/icons.tsx @@ -48,6 +48,7 @@ import jiraSVG from "../../../public/Jira.svg"; import confluenceSVG from "../../../public/Confluence.svg"; import openAISVG from "../../../public/Openai.svg"; import openSourceIcon from "../../../public/OpenSource.png"; +import litellmIcon from "../../../public/LiteLLM.jpg"; import awsWEBP from "../../../public/Amazon.webp"; import azureIcon from "../../../public/Azure.png"; @@ -267,6 +268,20 @@ export const ColorSlackIcon = ({ ); }; +export const LiteLLMIcon = ({ + size = 16, + className = defaultTailwindCSS, +}: IconProps) => { + return ( +
+ Logo +
+ ); +}; + export const OpenSourceIcon = ({ size = 16, className = defaultTailwindCSS,