diff --git a/backend/alembic/versions/bceb1e139447_add_base_url_to_cloudembeddingprovider.py b/backend/alembic/versions/bceb1e139447_add_base_url_to_cloudembeddingprovider.py new file mode 100644 index 0000000000..3fc01931fe --- /dev/null +++ b/backend/alembic/versions/bceb1e139447_add_base_url_to_cloudembeddingprovider.py @@ -0,0 +1,26 @@ +"""Add base_url to CloudEmbeddingProvider + +Revision ID: bceb1e139447 +Revises: 1f60f60c3401 +Create Date: 2024-08-28 17:00:52.554580 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "bceb1e139447" +down_revision = "1f60f60c3401" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "embedding_provider", sa.Column("api_url", sa.String(), nullable=True) + ) + + +def downgrade() -> None: + op.drop_column("embedding_provider", "api_url") diff --git a/backend/danswer/db/llm.py b/backend/danswer/db/llm.py index 152cb13057..18ad22e50b 100644 --- a/backend/danswer/db/llm.py +++ b/backend/danswer/db/llm.py @@ -6,6 +6,7 @@ from sqlalchemy.orm import Session from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel from danswer.db.models import LLMProvider as LLMProviderModel from danswer.db.models import LLMProvider__UserGroup +from danswer.db.models import SearchSettings from danswer.db.models import User from danswer.db.models import User__UserGroup from danswer.server.manage.embedding.models import CloudEmbeddingProvider @@ -50,6 +51,7 @@ def upsert_cloud_embedding_provider( setattr(existing_provider, key, value) else: new_provider = CloudEmbeddingProviderModel(**provider.model_dump()) + db_session.add(new_provider) existing_provider = new_provider db_session.commit() @@ -157,12 +159,19 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | def remove_embedding_provider( db_session: Session, provider_type: EmbeddingProvider ) -> None: + db_session.execute( + delete(SearchSettings).where(SearchSettings.provider_type == provider_type) + ) + + # Delete the embedding provider db_session.execute( delete(CloudEmbeddingProviderModel).where( CloudEmbeddingProviderModel.provider_type == provider_type ) ) + db_session.commit() + def remove_llm_provider(db_session: Session, provider_id: int) -> None: # Remove LLMProvider's dependent relationships diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 3cdec32396..6d2b92b197 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -607,6 +607,10 @@ class SearchSettings(Base): return f"" + @property + def api_url(self) -> str | None: + return self.cloud_provider.api_url if self.cloud_provider is not None else None + @property def api_key(self) -> str | None: return self.cloud_provider.api_key if self.cloud_provider is not None else None @@ -1085,6 +1089,7 @@ class CloudEmbeddingProvider(Base): provider_type: Mapped[EmbeddingProvider] = mapped_column( Enum(EmbeddingProvider), primary_key=True ) + api_url: Mapped[str | None] = mapped_column(String, nullable=True) api_key: Mapped[str | None] = mapped_column(EncryptedString()) search_settings: Mapped[list["SearchSettings"]] = relationship( "SearchSettings", diff --git a/backend/danswer/db/search_settings.py b/backend/danswer/db/search_settings.py index 1d0c218e10..0cb5029533 100644 --- a/backend/danswer/db/search_settings.py +++ b/backend/danswer/db/search_settings.py @@ -115,6 +115,13 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None: return latest_settings +def get_all_search_settings(db_session: Session) -> list[SearchSettings]: + query = select(SearchSettings).order_by(SearchSettings.id.desc()) + result = db_session.execute(query) + all_settings = result.scalars().all() + return list(all_settings) + + def get_multilingual_expansion(db_session: Session | None = None) -> list[str]: if db_session is None: with Session(get_sqlalchemy_engine()) as db_session: @@ -234,6 +241,7 @@ def get_old_default_embedding_model() -> IndexingSetting: passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""), index_name="danswer_chunk", multipass_indexing=False, + api_url=None, ) @@ -246,4 +254,5 @@ def get_new_default_embedding_model() -> IndexingSetting: passage_prefix=ASYM_PASSAGE_PREFIX, index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}", multipass_indexing=False, + api_url=None, ) diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py index f7d8f4e740..d25a0659c6 100644 --- a/backend/danswer/indexing/embedder.py +++ b/backend/danswer/indexing/embedder.py @@ -32,6 +32,7 @@ class IndexingEmbedder(ABC): passage_prefix: str | None, provider_type: EmbeddingProvider | None, api_key: str | None, + api_url: str | None, ): self.model_name = model_name self.normalize = normalize @@ -39,6 +40,7 @@ class IndexingEmbedder(ABC): self.passage_prefix = passage_prefix self.provider_type = provider_type self.api_key = api_key + self.api_url = api_url self.embedding_model = EmbeddingModel( model_name=model_name, @@ -47,6 +49,7 @@ class IndexingEmbedder(ABC): normalize=normalize, api_key=api_key, provider_type=provider_type, + api_url=api_url, # The below are globally set, this flow always uses the indexing one server_host=INDEXING_MODEL_SERVER_HOST, server_port=INDEXING_MODEL_SERVER_PORT, @@ -70,9 +73,16 @@ class DefaultIndexingEmbedder(IndexingEmbedder): passage_prefix: str | None, provider_type: EmbeddingProvider | None = None, api_key: str | None = None, + api_url: str | None = None, ): super().__init__( - model_name, normalize, query_prefix, passage_prefix, provider_type, api_key + model_name, + normalize, + query_prefix, + passage_prefix, + provider_type, + api_key, + api_url, ) @log_function_time() @@ -156,7 +166,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder): title_embed_dict[title] = title_embedding new_embedded_chunk = IndexChunk( - **chunk.model_dump(), + **chunk.dict(), embeddings=ChunkEmbedding( full_embedding=chunk_embeddings[0], mini_chunk_embeddings=chunk_embeddings[1:], @@ -179,6 +189,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder): passage_prefix=search_settings.passage_prefix, provider_type=search_settings.provider_type, api_key=search_settings.api_key, + api_url=search_settings.api_url, ) @@ -202,4 +213,5 @@ def get_embedding_model_from_search_settings( passage_prefix=search_settings.passage_prefix, provider_type=search_settings.provider_type, api_key=search_settings.api_key, + api_url=search_settings.api_url, ) diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py index b23de0eb47..c468b9fb18 100644 --- a/backend/danswer/indexing/models.py +++ b/backend/danswer/indexing/models.py @@ -99,6 +99,7 @@ class EmbeddingModelDetail(BaseModel): normalize: bool query_prefix: str | None passage_prefix: str | None + api_url: str | None = None provider_type: EmbeddingProvider | None = None api_key: str | None = None @@ -117,6 +118,7 @@ class EmbeddingModelDetail(BaseModel): passage_prefix=search_settings.passage_prefix, provider_type=search_settings.provider_type, api_key=search_settings.api_key, + api_url=search_settings.api_url, ) diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py index b7835c4e90..d2ab3a582b 100644 --- a/backend/danswer/natural_language_processing/search_nlp_models.py +++ b/backend/danswer/natural_language_processing/search_nlp_models.py @@ -90,6 +90,7 @@ class EmbeddingModel: query_prefix: str | None, passage_prefix: str | None, api_key: str | None, + api_url: str | None, provider_type: EmbeddingProvider | None, retrim_content: bool = False, ) -> None: @@ -100,6 +101,7 @@ class EmbeddingModel: self.normalize = normalize self.model_name = model_name self.retrim_content = retrim_content + self.api_url = api_url self.tokenizer = get_tokenizer( model_name=model_name, provider_type=provider_type ) @@ -157,6 +159,7 @@ class EmbeddingModel: text_type=text_type, manual_query_prefix=self.query_prefix, manual_passage_prefix=self.passage_prefix, + api_url=self.api_url, ) response = self._make_model_server_request(embed_request) @@ -226,6 +229,7 @@ class EmbeddingModel: passage_prefix=search_settings.passage_prefix, api_key=search_settings.api_key, provider_type=search_settings.provider_type, + api_url=search_settings.api_url, retrim_content=retrim_content, ) diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index 15387e6c63..e9201c9705 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -81,6 +81,7 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting): num_rerank=search_settings.num_rerank, # Multilingual Expansion multilingual_expansion=search_settings.multilingual_expansion, + api_url=search_settings.api_url, ) diff --git a/backend/danswer/server/manage/embedding/api.py b/backend/danswer/server/manage/embedding/api.py index 90fa69401c..2cee962ee6 100644 --- a/backend/danswer/server/manage/embedding/api.py +++ b/backend/danswer/server/manage/embedding/api.py @@ -9,7 +9,9 @@ from danswer.db.llm import fetch_existing_embedding_providers from danswer.db.llm import remove_embedding_provider from danswer.db.llm import upsert_cloud_embedding_provider from danswer.db.models import User +from danswer.db.search_settings import get_all_search_settings from danswer.db.search_settings import get_current_db_embedding_provider +from danswer.indexing.models import EmbeddingModelDetail from danswer.natural_language_processing.search_nlp_models import EmbeddingModel from danswer.server.manage.embedding.models import CloudEmbeddingProvider from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest @@ -20,6 +22,7 @@ from shared_configs.configs import MODEL_SERVER_PORT from shared_configs.enums import EmbeddingProvider from shared_configs.enums import EmbedTextType + logger = setup_logger() @@ -37,6 +40,7 @@ def test_embedding_configuration( server_host=MODEL_SERVER_HOST, server_port=MODEL_SERVER_PORT, api_key=test_llm_request.api_key, + api_url=test_llm_request.api_url, provider_type=test_llm_request.provider_type, normalize=False, query_prefix=None, @@ -56,6 +60,15 @@ def test_embedding_configuration( raise HTTPException(status_code=400, detail=error_msg) +@admin_router.get("", response_model=list[EmbeddingModelDetail]) +def list_embedding_models( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[EmbeddingModelDetail]: + search_settings = get_all_search_settings(db_session) + return [EmbeddingModelDetail.from_db_model(setting) for setting in search_settings] + + @admin_router.get("/embedding-provider") def list_embedding_providers( _: User | None = Depends(current_admin_user), diff --git a/backend/danswer/server/manage/embedding/models.py b/backend/danswer/server/manage/embedding/models.py index 132d311413..50518e6ec0 100644 --- a/backend/danswer/server/manage/embedding/models.py +++ b/backend/danswer/server/manage/embedding/models.py @@ -11,11 +11,13 @@ if TYPE_CHECKING: class TestEmbeddingRequest(BaseModel): provider_type: EmbeddingProvider api_key: str | None = None + api_url: str | None = None class CloudEmbeddingProvider(BaseModel): provider_type: EmbeddingProvider api_key: str | None = None + api_url: str | None = None @classmethod def from_request( @@ -24,9 +26,11 @@ class CloudEmbeddingProvider(BaseModel): return cls( provider_type=cloud_provider_model.provider_type, api_key=cloud_provider_model.api_key, + api_url=cloud_provider_model.api_url, ) class CloudEmbeddingProviderCreationRequest(BaseModel): provider_type: EmbeddingProvider api_key: str | None = None + api_url: str | None = None diff --git a/backend/danswer/server/manage/search_settings.py b/backend/danswer/server/manage/search_settings.py index db483eff5d..831528b815 100644 --- a/backend/danswer/server/manage/search_settings.py +++ b/backend/danswer/server/manage/search_settings.py @@ -45,7 +45,7 @@ def set_new_search_settings( if search_settings_new.index_name: logger.warning("Index name was specified by request, this is not suggested") - # Validate cloud provider exists + # Validate cloud provider exists or create new LiteLLM provider if search_settings_new.provider_type is not None: cloud_provider = get_embedding_provider_from_provider_type( db_session, provider_type=search_settings_new.provider_type @@ -133,7 +133,7 @@ def cancel_new_embedding( @router.get("/get-current-search-settings") -def get_curr_search_settings( +def get_current_search_settings_endpoint( _: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> SavedSearchSettings: @@ -142,7 +142,7 @@ def get_curr_search_settings( @router.get("/get-secondary-search-settings") -def get_sec_search_settings( +def get_secondary_search_settings_endpoint( _: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> SavedSearchSettings | None: diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index 4e97bd00f2..ad9d8582be 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -2,6 +2,7 @@ import json from typing import Any from typing import Optional +import httpx import openai import vertexai # type: ignore import voyageai # type: ignore @@ -235,6 +236,22 @@ def get_local_reranking_model( return _RERANK_MODEL +def embed_with_litellm_proxy( + texts: list[str], api_url: str, model: str +) -> list[Embedding]: + with httpx.Client() as client: + response = client.post( + api_url, + json={ + "model": model, + "input": texts, + }, + ) + response.raise_for_status() + result = response.json() + return [embedding["embedding"] for embedding in result["data"]] + + @simple_log_function_time() def embed_text( texts: list[str], @@ -245,21 +262,37 @@ def embed_text( api_key: str | None, provider_type: EmbeddingProvider | None, prefix: str | None, + api_url: str | None, ) -> list[Embedding]: + logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}") + if not all(texts): + logger.error("Empty strings provided for embedding") raise ValueError("Empty strings are not allowed for embedding.") - # Third party API based embedding model if not texts: + logger.error("No texts provided for embedding") raise ValueError("No texts provided for embedding.") + + if provider_type == EmbeddingProvider.LITELLM: + logger.debug(f"Using LiteLLM proxy for embedding with URL: {api_url}") + if not api_url: + logger.error("API URL not provided for LiteLLM proxy") + raise ValueError("API URL is required for LiteLLM proxy embedding.") + try: + return embed_with_litellm_proxy(texts, api_url, model_name or "") + except Exception as e: + logger.exception(f"Error during LiteLLM proxy embedding: {str(e)}") + raise + elif provider_type is not None: - logger.debug(f"Embedding text with provider: {provider_type}") + logger.debug(f"Using cloud provider {provider_type} for embedding") if api_key is None: + logger.error("API key not provided for cloud model") raise RuntimeError("API key not provided for cloud model") if prefix: - # This may change in the future if some providers require the user - # to manually append a prefix but this is not the case currently + logger.warning("Prefix provided for cloud model, which is not supported") raise ValueError( "Prefix string is not valid for cloud models. " "Cloud models take an explicit text type instead." @@ -274,14 +307,15 @@ def embed_text( text_type=text_type, ) - # Check for None values in embeddings if any(embedding is None for embedding in embeddings): error_message = "Embeddings contain None values\n" error_message += "Corresponding texts:\n" error_message += "\n".join(texts) + logger.error(error_message) raise ValueError(error_message) elif model_name is not None: + logger.debug(f"Using local model {model_name} for embedding") prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts local_model = get_embedding_model( @@ -296,10 +330,12 @@ def embed_text( ] else: + logger.error("Neither model name nor provider specified for embedding") raise ValueError( "Either model name or provider must be provided to run embeddings." ) + logger.info(f"Successfully embedded {len(texts)} texts") return embeddings @@ -344,6 +380,7 @@ async def process_embed_request( api_key=embed_request.api_key, provider_type=embed_request.provider_type, text_type=embed_request.text_type, + api_url=embed_request.api_url, prefix=prefix, ) return EmbedResponse(embeddings=embeddings) diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index 5ad36cc93c..2357d96d95 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -61,6 +61,7 @@ PRESERVED_SEARCH_FIELDS = [ "provider_type", "api_key", "model_name", + "api_url", "index_name", "multipass_indexing", "model_dim", diff --git a/backend/shared_configs/enums.py b/backend/shared_configs/enums.py index 918872d44b..4dccd43e0a 100644 --- a/backend/shared_configs/enums.py +++ b/backend/shared_configs/enums.py @@ -6,6 +6,7 @@ class EmbeddingProvider(str, Enum): COHERE = "cohere" VOYAGE = "voyage" GOOGLE = "google" + LITELLM = "litellm" class RerankerProvider(str, Enum): diff --git a/backend/shared_configs/model_server_models.py b/backend/shared_configs/model_server_models.py index 3014616c62..4be72308e7 100644 --- a/backend/shared_configs/model_server_models.py +++ b/backend/shared_configs/model_server_models.py @@ -18,6 +18,7 @@ class EmbedRequest(BaseModel): text_type: EmbedTextType manual_query_prefix: str | None = None manual_passage_prefix: str | None = None + api_url: str | None = None # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} diff --git a/backend/tests/daily/embedding/test_embeddings.py b/backend/tests/daily/embedding/test_embeddings.py index a9c12b236c..b736f37474 100644 --- a/backend/tests/daily/embedding/test_embeddings.py +++ b/backend/tests/daily/embedding/test_embeddings.py @@ -32,6 +32,7 @@ def openai_embedding_model() -> EmbeddingModel: passage_prefix=None, api_key=os.getenv("OPENAI_API_KEY"), provider_type=EmbeddingProvider.OPENAI, + api_url=None, ) @@ -51,6 +52,7 @@ def cohere_embedding_model() -> EmbeddingModel: passage_prefix=None, api_key=os.getenv("COHERE_API_KEY"), provider_type=EmbeddingProvider.COHERE, + api_url=None, ) @@ -70,6 +72,7 @@ def local_nomic_embedding_model() -> EmbeddingModel: passage_prefix="search_document: ", api_key=None, provider_type=None, + api_url=None, ) diff --git a/web/public/LiteLLM.jpg b/web/public/LiteLLM.jpg new file mode 100644 index 0000000000..d6a77b2d10 Binary files /dev/null and b/web/public/LiteLLM.jpg differ diff --git a/web/src/app/admin/configuration/llm/constants.ts b/web/src/app/admin/configuration/llm/constants.ts index a265f4a2b2..d7e3449b34 100644 --- a/web/src/app/admin/configuration/llm/constants.ts +++ b/web/src/app/admin/configuration/llm/constants.ts @@ -2,3 +2,5 @@ export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider"; export const EMBEDDING_PROVIDERS_ADMIN_URL = "/api/admin/embedding/embedding-provider"; + +export const EMBEDDING_MODELS_ADMIN_URL = "/api/admin/embedding"; diff --git a/web/src/app/admin/embeddings/EmbeddingModelSelectionForm.tsx b/web/src/app/admin/embeddings/EmbeddingModelSelectionForm.tsx index 1b9fffda42..bd6d5760ef 100644 --- a/web/src/app/admin/embeddings/EmbeddingModelSelectionForm.tsx +++ b/web/src/app/admin/embeddings/EmbeddingModelSelectionForm.tsx @@ -24,10 +24,14 @@ import { ChangeCredentialsModal } from "./modals/ChangeCredentialsModal"; import { ModelSelectionConfirmationModal } from "./modals/ModelSelectionModal"; import { AlreadyPickedModal } from "./modals/AlreadyPickedModal"; import { ModelOption } from "../../../components/embedding/ModelSelector"; -import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../configuration/llm/constants"; +import { + EMBEDDING_MODELS_ADMIN_URL, + EMBEDDING_PROVIDERS_ADMIN_URL, +} from "../configuration/llm/constants"; export interface EmbeddingDetails { - api_key: string; + api_key?: string; + api_url?: string; custom_config: any; provider_type: EmbeddingProvider; } @@ -77,12 +81,20 @@ export function EmbeddingModelSelection({ const [showDeleteCredentialsModal, setShowDeleteCredentialsModal] = useState(false); + const [showAddConnectorPopup, setShowAddConnectorPopup] = useState(false); + const { data: embeddingModelDetails } = useSWR( + EMBEDDING_MODELS_ADMIN_URL, + errorHandlingFetcher, + { refreshInterval: 5000 } // 5 seconds + ); + const { data: embeddingProviderDetails } = useSWR( EMBEDDING_PROVIDERS_ADMIN_URL, - errorHandlingFetcher + errorHandlingFetcher, + { refreshInterval: 5000 } // 5 seconds ); const { data: connectors } = useSWR[]>( @@ -175,6 +187,7 @@ export function EmbeddingModelSelection({ {showTentativeProvider && ( { setShowTentativeProvider(showUnconfiguredProvider); @@ -189,8 +202,10 @@ export function EmbeddingModelSelection({ }} /> )} + {changeCredentialsProvider && ( { clientsideRemoveProvider(changeCredentialsProvider); @@ -277,6 +292,7 @@ export function EmbeddingModelSelection({ {modelTab == "cloud" && ( void; onCancel: () => void; onDeleted: () => void; useFileUpload: boolean; + isProxy?: boolean; }) { - const [apiKey, setApiKey] = useState(""); + const [apiKeyOrUrl, setApiKeyOrUrl] = useState(""); const [testError, setTestError] = useState(""); const [fileName, setFileName] = useState(""); const fileInputRef = useRef(null); @@ -50,7 +52,7 @@ export function ChangeCredentialsModal({ let jsonContent; try { jsonContent = JSON.parse(fileContent); - setApiKey(JSON.stringify(jsonContent)); + setApiKeyOrUrl(JSON.stringify(jsonContent)); } catch (parseError) { throw new Error( "Failed to parse JSON file. Please ensure it's a valid JSON." @@ -62,7 +64,7 @@ export function ChangeCredentialsModal({ ? error.message : "An unknown error occurred while processing the file." ); - setApiKey(""); + setApiKeyOrUrl(""); clearFileInput(); } } @@ -74,7 +76,7 @@ export function ChangeCredentialsModal({ try { const response = await fetch( - `${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type}`, + `${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type.toLowerCase()}`, { method: "DELETE", } @@ -105,7 +107,10 @@ export function ChangeCredentialsModal({ headers: { "Content-Type": "application/json" }, body: JSON.stringify({ provider_type: provider.provider_type.toLowerCase().split(" ")[0], - api_key: apiKey, + [isProxy ? "api_url" : "api_key"]: apiKeyOrUrl, + [isProxy ? "api_key" : "api_url"]: isProxy + ? provider.api_key + : provider.api_url, }), }); @@ -119,7 +124,7 @@ export function ChangeCredentialsModal({ headers: { "Content-Type": "application/json" }, body: JSON.stringify({ provider_type: provider.provider_type.toLowerCase().split(" ")[0], - api_key: apiKey, + [isProxy ? "api_url" : "api_key"]: apiKeyOrUrl, is_default_provider: false, is_configured: true, }), @@ -128,7 +133,8 @@ export function ChangeCredentialsModal({ if (!updateResponse.ok) { const errorData = await updateResponse.json(); throw new Error( - errorData.detail || "Failed to update provider- check your API key" + errorData.detail || + `Failed to update provider- check your ${isProxy ? "API URL" : "API key"}` ); } @@ -144,12 +150,12 @@ export function ChangeCredentialsModal({ - You can also delete your key. + You can also delete your {isProxy ? "URL" : "key"}. This is only possible if you have already switched to a different @@ -219,7 +225,7 @@ export function ChangeCredentialsModal({ {deletionError && ( diff --git a/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx b/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx index ab9fa663ee..b4aa909aea 100644 --- a/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx +++ b/web/src/app/admin/embeddings/modals/ProviderCreationModal.tsx @@ -13,11 +13,13 @@ export function ProviderCreationModal({ onConfirm, onCancel, existingProvider, + isProxy, }: { selectedProvider: CloudEmbeddingProvider; onConfirm: () => void; onCancel: () => void; existingProvider?: CloudEmbeddingProvider; + isProxy?: boolean; }) { const useFileUpload = selectedProvider.provider_type == "Google"; @@ -29,6 +31,7 @@ export function ProviderCreationModal({ provider_type: existingProvider?.provider_type || selectedProvider.provider_type, api_key: existingProvider?.api_key || "", + api_url: existingProvider?.api_url || "", custom_config: existingProvider?.custom_config ? Object.entries(existingProvider.custom_config) : [], @@ -37,9 +40,14 @@ export function ProviderCreationModal({ const validationSchema = Yup.object({ provider_type: Yup.string().required("Provider type is required"), - api_key: useFileUpload + api_key: isProxy ? Yup.string() - : Yup.string().required("API Key is required"), + : useFileUpload + ? Yup.string() + : Yup.string().required("API Key is required"), + api_url: isProxy + ? Yup.string().required("API URL is required") + : Yup.string(), custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)), }); @@ -87,6 +95,7 @@ export function ProviderCreationModal({ body: JSON.stringify({ provider_type: values.provider_type.toLowerCase().split(" ")[0], api_key: values.api_key, + api_url: values.api_url, }), } ); @@ -169,12 +178,19 @@ export function ProviderCreationModal({ target="_blank" href={selectedProvider.apiLink} > - API KEY + {isProxy ? "API URL" : "API KEY"}
- {useFileUpload ? ( + {isProxy ? ( + + ) : useFileUpload ? ( <> >; newUnenabledProviders: string[]; + embeddingModelDetails?: CloudEmbeddingModel[]; embeddingProviderDetails?: EmbeddingDetails[]; newEnabledProviders: string[]; setShowTentativeProvider: React.Dispatch< @@ -61,6 +66,17 @@ export default function CloudEmbeddingPage({ ))!), }) ); + const [liteLLMProvider, setLiteLLMProvider] = useState< + EmbeddingDetails | undefined + >(undefined); + + useEffect(() => { + const foundProvider = embeddingProviderDetails?.find( + (provider) => + provider.provider_type === EmbeddingProvider.LITELLM.toLowerCase() + ); + setLiteLLMProvider(foundProvider); + }, [embeddingProviderDetails]); return (
@@ -122,6 +138,127 @@ export default function CloudEmbeddingPage({
))} + + + Alternatively, you can use a self-hosted model using the LiteLLM + proxy. This allows you to leverage various LLM providers through a + unified interface that you control.{" "} + + Learn more about LiteLLM + + + +
+
+ {LITELLM_CLOUD_PROVIDER.icon({ size: 40 })} +

+ {LITELLM_CLOUD_PROVIDER.provider_type}{" "} + {LITELLM_CLOUD_PROVIDER.provider_type == "Cohere" && + "(recommended)"} +

+ + } + popupContent={ +
+
+ {LITELLM_CLOUD_PROVIDER.description} +
+
+ } + style="dark" + /> +
+
+ {!liteLLMProvider ? ( + + ) : ( + + )} + + {!liteLLMProvider && ( + +
+ + API URL Required + + + Before you can add models, you need to provide an API URL + for your LiteLLM proxy. Click the "Provide API + URL" button above to set up your LiteLLM configuration. + +
+ + + Once configured, you'll be able to add and manage + your LiteLLM models here. + +
+
+
+ )} + {liteLLMProvider && ( + <> +
+ {embeddingModelDetails + ?.filter( + (model) => + model.provider_type === + EmbeddingProvider.LITELLM.toLowerCase() + ) + .map((model) => ( + + ))} +
+ + + + + + )} +
+
); @@ -146,7 +283,9 @@ export function CloudModelCard({ React.SetStateAction >; }) { - const enabled = model.model_name === currentModel.model_name; + const enabled = + model.model_name === currentModel.model_name && + model.provider_type == currentModel.provider_type; return (

{model.description}

-
- ${model.pricePerMillion}/M tokens -
+ {model?.provider_type?.toLowerCase() != + EmbeddingProvider.LITELLM.toLowerCase() && ( +
+ ${model.pricePerMillion}/M tokens +
+ )}
-
+ )} diff --git a/web/src/components/embedding/LiteLLMModelForm.tsx b/web/src/components/embedding/LiteLLMModelForm.tsx new file mode 100644 index 0000000000..b84db4f906 --- /dev/null +++ b/web/src/components/embedding/LiteLLMModelForm.tsx @@ -0,0 +1,116 @@ +import { CloudEmbeddingModel, CloudEmbeddingProvider } from "./interfaces"; +import { Formik, Form } from "formik"; +import * as Yup from "yup"; +import { TextFormField, BooleanFormField } from "../admin/connectors/Field"; +import { Dispatch, SetStateAction } from "react"; +import { Button, Text } from "@tremor/react"; +import { EmbeddingDetails } from "@/app/admin/embeddings/EmbeddingModelSelectionForm"; + +export function LiteLLMModelForm({ + setShowTentativeModel, + currentValues, + provider, +}: { + setShowTentativeModel: Dispatch>; + currentValues: CloudEmbeddingModel | null; + provider: EmbeddingDetails; +}) { + return ( +
+ { + setShowTentativeModel(values as CloudEmbeddingModel); + }} + > + {({ isSubmitting }) => ( +
+ + Add a new model to LiteLLM proxy at {provider.api_url} + + + + + + + + + + + + + + )} +
+
+ ); +} diff --git a/web/src/components/embedding/interfaces.tsx b/web/src/components/embedding/interfaces.tsx index c719b7dc7b..0fafaa840c 100644 --- a/web/src/components/embedding/interfaces.tsx +++ b/web/src/components/embedding/interfaces.tsx @@ -2,6 +2,7 @@ import { CohereIcon, GoogleIcon, IconProps, + LiteLLMIcon, MicrosoftIcon, NomicIcon, OpenAIIcon, @@ -14,11 +15,13 @@ export enum EmbeddingProvider { COHERE = "Cohere", VOYAGE = "Voyage", GOOGLE = "Google", + LITELLM = "LiteLLM", } export interface CloudEmbeddingProvider { provider_type: EmbeddingProvider; api_key?: string; + api_url?: string; custom_config?: Record; docsLink?: string; @@ -44,6 +47,7 @@ export interface EmbeddingModelDescriptor { provider_type: string | null; description: string; api_key: string | null; + api_url: string | null; index_name: string | null; } @@ -70,7 +74,7 @@ export interface FullEmbeddingModelResponse { } export interface CloudEmbeddingProviderFull extends CloudEmbeddingProvider { - configured: boolean; + configured?: boolean; } export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ @@ -87,6 +91,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ index_name: "", provider_type: null, api_key: null, + api_url: null, }, { model_name: "intfloat/e5-base-v2", @@ -99,6 +104,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ passage_prefix: "passage: ", index_name: "", provider_type: null, + api_url: null, api_key: null, }, { @@ -113,6 +119,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ index_name: "", provider_type: null, api_key: null, + api_url: null, }, { model_name: "intfloat/multilingual-e5-base", @@ -126,6 +133,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ index_name: "", provider_type: null, api_key: null, + api_url: null, }, { model_name: "intfloat/multilingual-e5-small", @@ -139,9 +147,19 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [ index_name: "", provider_type: null, api_key: null, + api_url: null, }, ]; +export const LITELLM_CLOUD_PROVIDER: CloudEmbeddingProvider = { + provider_type: EmbeddingProvider.LITELLM, + website: "https://github.com/BerriAI/litellm", + icon: LiteLLMIcon, + description: "Open-source library to call LLM APIs using OpenAI format", + apiLink: "https://docs.litellm.ai/docs/proxy/quick_start", + embedding_models: [], // No default embedding models +}; + export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ { provider_type: EmbeddingProvider.COHERE, @@ -169,6 +187,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, { model_name: "embed-english-light-v3.0", @@ -185,6 +204,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, ], }, @@ -213,6 +233,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ enabled: false, index_name: "", api_key: null, + api_url: null, }, { provider_type: EmbeddingProvider.OPENAI, @@ -229,6 +250,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ maxContext: 8191, index_name: "", api_key: null, + api_url: null, }, ], }, @@ -258,6 +280,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, { provider_type: EmbeddingProvider.GOOGLE, @@ -273,6 +296,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, ], }, @@ -301,6 +325,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, { provider_type: EmbeddingProvider.VOYAGE, @@ -317,6 +342,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ passage_prefix: "", index_name: "", api_key: null, + api_url: null, }, ], }, diff --git a/web/src/components/icons/icons.tsx b/web/src/components/icons/icons.tsx index 377307caef..b5e735b0e6 100644 --- a/web/src/components/icons/icons.tsx +++ b/web/src/components/icons/icons.tsx @@ -48,6 +48,7 @@ import jiraSVG from "../../../public/Jira.svg"; import confluenceSVG from "../../../public/Confluence.svg"; import openAISVG from "../../../public/Openai.svg"; import openSourceIcon from "../../../public/OpenSource.png"; +import litellmIcon from "../../../public/LiteLLM.jpg"; import awsWEBP from "../../../public/Amazon.webp"; import azureIcon from "../../../public/Azure.png"; @@ -267,6 +268,20 @@ export const ColorSlackIcon = ({ ); }; +export const LiteLLMIcon = ({ + size = 16, + className = defaultTailwindCSS, +}: IconProps) => { + return ( +
+ Logo +
+ ); +}; + export const OpenSourceIcon = ({ size = 16, className = defaultTailwindCSS,