Add litellm proxy embeddings (#2291)

* add litellm proxy

* formatting

* move `api_url` to cloud provider + nits

* remove log

* typing

* quick typing fix

* update LiteLLM selection logic

* remove logs + validate functionality

* rename proxy var

* update path casing

* remove pricing for custom models

* functional values
pablodanswer 2024-09-02 09:08:35 -07:00 committed by GitHub
parent 910821c723
commit 299cb5035c
28 changed files with 524 additions and 50 deletions

View File

@@ -0,0 +1,26 @@
"""Add base_url to CloudEmbeddingProvider

Revision ID: bceb1e139447
Revises: 1f60f60c3401
Create Date: 2024-08-28 17:00:52.554580

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "bceb1e139447"
down_revision = "1f60f60c3401"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "embedding_provider", sa.Column("api_url", sa.String(), nullable=True)
    )


def downgrade() -> None:
    op.drop_column("embedding_provider", "api_url")
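If this migration is applied by hand rather than by the normal Alembic upgrade on startup, a minimal sketch using Alembic's Python API (the alembic.ini path is an assumption; the plain "alembic upgrade head" CLI does the same thing):

# Sketch only: assumes it is run from the directory containing alembic.ini.
from alembic import command
from alembic.config import Config

command.upgrade(Config("alembic.ini"), "bceb1e139447")  # or "head"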

View File

@@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
+from danswer.db.models import SearchSettings
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
@@ -50,6 +51,7 @@ def upsert_cloud_embedding_provider(
            setattr(existing_provider, key, value)
    else:
        new_provider = CloudEmbeddingProviderModel(**provider.model_dump())
        db_session.add(new_provider)
        existing_provider = new_provider
    db_session.commit()
@@ -157,12 +159,19 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
def remove_embedding_provider(
    db_session: Session, provider_type: EmbeddingProvider
) -> None:
+    db_session.execute(
+        delete(SearchSettings).where(SearchSettings.provider_type == provider_type)
+    )
+
+    # Delete the embedding provider
    db_session.execute(
        delete(CloudEmbeddingProviderModel).where(
            CloudEmbeddingProviderModel.provider_type == provider_type
        )
    )
+    db_session.commit()


def remove_llm_provider(db_session: Session, provider_id: int) -> None:
    # Remove LLMProvider's dependent relationships

View File

@@ -607,6 +607,10 @@ class SearchSettings(Base):
        return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
            cloud_provider='{self.cloud_provider.provider_type if self.cloud_provider else 'None'}')>"

+    @property
+    def api_url(self) -> str | None:
+        return self.cloud_provider.api_url if self.cloud_provider is not None else None
+
    @property
    def api_key(self) -> str | None:
        return self.cloud_provider.api_key if self.cloud_provider is not None else None
@@ -1085,6 +1089,7 @@ class CloudEmbeddingProvider(Base):
    provider_type: Mapped[EmbeddingProvider] = mapped_column(
        Enum(EmbeddingProvider), primary_key=True
    )
+    api_url: Mapped[str | None] = mapped_column(String, nullable=True)
    api_key: Mapped[str | None] = mapped_column(EncryptedString())
    search_settings: Mapped[list["SearchSettings"]] = relationship(
        "SearchSettings",

View File

@@ -115,6 +115,13 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
    return latest_settings


+def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
+    query = select(SearchSettings).order_by(SearchSettings.id.desc())
+    result = db_session.execute(query)
+    all_settings = result.scalars().all()
+    return list(all_settings)
+
+
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
    if db_session is None:
        with Session(get_sqlalchemy_engine()) as db_session:
@@ -234,6 +241,7 @@ def get_old_default_embedding_model() -> IndexingSetting:
        passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
        index_name="danswer_chunk",
        multipass_indexing=False,
+        api_url=None,
    )
@@ -246,4 +254,5 @@ def get_new_default_embedding_model() -> IndexingSetting:
        passage_prefix=ASYM_PASSAGE_PREFIX,
        index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
        multipass_indexing=False,
+        api_url=None,
    )

View File

@@ -32,6 +32,7 @@ class IndexingEmbedder(ABC):
        passage_prefix: str | None,
        provider_type: EmbeddingProvider | None,
        api_key: str | None,
+        api_url: str | None,
    ):
        self.model_name = model_name
        self.normalize = normalize
@@ -39,6 +40,7 @@ class IndexingEmbedder(ABC):
        self.passage_prefix = passage_prefix
        self.provider_type = provider_type
        self.api_key = api_key
+        self.api_url = api_url

        self.embedding_model = EmbeddingModel(
            model_name=model_name,
@@ -47,6 +49,7 @@ class IndexingEmbedder(ABC):
            normalize=normalize,
            api_key=api_key,
            provider_type=provider_type,
+            api_url=api_url,
            # The below are globally set, this flow always uses the indexing one
            server_host=INDEXING_MODEL_SERVER_HOST,
            server_port=INDEXING_MODEL_SERVER_PORT,
@@ -70,9 +73,16 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
        passage_prefix: str | None,
        provider_type: EmbeddingProvider | None = None,
        api_key: str | None = None,
+        api_url: str | None = None,
    ):
        super().__init__(
-            model_name, normalize, query_prefix, passage_prefix, provider_type, api_key
+            model_name,
+            normalize,
+            query_prefix,
+            passage_prefix,
+            provider_type,
+            api_key,
+            api_url,
        )

    @log_function_time()
@@ -156,7 +166,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
                title_embed_dict[title] = title_embedding

            new_embedded_chunk = IndexChunk(
-                **chunk.model_dump(),
+                **chunk.dict(),
                embeddings=ChunkEmbedding(
                    full_embedding=chunk_embeddings[0],
                    mini_chunk_embeddings=chunk_embeddings[1:],
@@ -179,6 +189,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
            passage_prefix=search_settings.passage_prefix,
            provider_type=search_settings.provider_type,
            api_key=search_settings.api_key,
+            api_url=search_settings.api_url,
        )
@@ -202,4 +213,5 @@ def get_embedding_model_from_search_settings(
        passage_prefix=search_settings.passage_prefix,
        provider_type=search_settings.provider_type,
        api_key=search_settings.api_key,
+        api_url=search_settings.api_url,
    )

View File

@@ -99,6 +99,7 @@ class EmbeddingModelDetail(BaseModel):
    normalize: bool
    query_prefix: str | None
    passage_prefix: str | None
+    api_url: str | None = None
    provider_type: EmbeddingProvider | None = None
    api_key: str | None = None
@@ -117,6 +118,7 @@ class EmbeddingModelDetail(BaseModel):
            passage_prefix=search_settings.passage_prefix,
            provider_type=search_settings.provider_type,
            api_key=search_settings.api_key,
+            api_url=search_settings.api_url,
        )

View File

@@ -90,6 +90,7 @@ class EmbeddingModel:
        query_prefix: str | None,
        passage_prefix: str | None,
        api_key: str | None,
+        api_url: str | None,
        provider_type: EmbeddingProvider | None,
        retrim_content: bool = False,
    ) -> None:
@@ -100,6 +101,7 @@ class EmbeddingModel:
        self.normalize = normalize
        self.model_name = model_name
        self.retrim_content = retrim_content
+        self.api_url = api_url
        self.tokenizer = get_tokenizer(
            model_name=model_name, provider_type=provider_type
        )
@@ -157,6 +159,7 @@ class EmbeddingModel:
                text_type=text_type,
                manual_query_prefix=self.query_prefix,
                manual_passage_prefix=self.passage_prefix,
+                api_url=self.api_url,
            )

            response = self._make_model_server_request(embed_request)
@@ -226,6 +229,7 @@ class EmbeddingModel:
            passage_prefix=search_settings.passage_prefix,
            api_key=search_settings.api_key,
            provider_type=search_settings.provider_type,
+            api_url=search_settings.api_url,
            retrim_content=retrim_content,
        )

View File

@@ -81,6 +81,7 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
            num_rerank=search_settings.num_rerank,
            # Multilingual Expansion
            multilingual_expansion=search_settings.multilingual_expansion,
+            api_url=search_settings.api_url,
        )

View File

@@ -9,7 +9,9 @@ from danswer.db.llm import fetch_existing_embedding_providers
from danswer.db.llm import remove_embedding_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.models import User
+from danswer.db.search_settings import get_all_search_settings
from danswer.db.search_settings import get_current_db_embedding_provider
+from danswer.indexing.models import EmbeddingModelDetail
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
@@ -20,6 +22,7 @@ from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType

logger = setup_logger()
@@ -37,6 +40,7 @@ def test_embedding_configuration(
            server_host=MODEL_SERVER_HOST,
            server_port=MODEL_SERVER_PORT,
            api_key=test_llm_request.api_key,
+            api_url=test_llm_request.api_url,
            provider_type=test_llm_request.provider_type,
            normalize=False,
            query_prefix=None,
@@ -56,6 +60,15 @@ def test_embedding_configuration(
        raise HTTPException(status_code=400, detail=error_msg)


+@admin_router.get("", response_model=list[EmbeddingModelDetail])
+def list_embedding_models(
+    _: User | None = Depends(current_admin_user),
+    db_session: Session = Depends(get_session),
+) -> list[EmbeddingModelDetail]:
+    search_settings = get_all_search_settings(db_session)
+    return [EmbeddingModelDetail.from_db_model(setting) for setting in search_settings]
+
+
@admin_router.get("/embedding-provider")
def list_embedding_providers(
    _: User | None = Depends(current_admin_user),
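The new GET route exposes every configured embedding model (one per SearchSettings row) to admins. A rough sketch of calling it, assuming a deployment reachable at http://localhost:8080 and an already-authenticated admin session cookie (both assumptions, not part of this diff); the path matches the EMBEDDING_MODELS_ADMIN_URL constant added on the web side below:

# Sketch only: host, port, and cookie name are illustrative assumptions.
import httpx

resp = httpx.get(
    "http://localhost:8080/api/admin/embedding",
    cookies={"fastapiusersauth": "<admin-session-cookie>"},  # hypothetical auth detail
)
resp.raise_for_status()
for model in resp.json():  # serialized list[EmbeddingModelDetail]
    print(model["model_name"], model["provider_type"], model["api_url"])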

View File

@@ -11,11 +11,13 @@ if TYPE_CHECKING:
class TestEmbeddingRequest(BaseModel):
    provider_type: EmbeddingProvider
    api_key: str | None = None
+    api_url: str | None = None


class CloudEmbeddingProvider(BaseModel):
    provider_type: EmbeddingProvider
    api_key: str | None = None
+    api_url: str | None = None

    @classmethod
    def from_request(
@@ -24,9 +26,11 @@ class CloudEmbeddingProvider(BaseModel):
        return cls(
            provider_type=cloud_provider_model.provider_type,
            api_key=cloud_provider_model.api_key,
+            api_url=cloud_provider_model.api_url,
        )


class CloudEmbeddingProviderCreationRequest(BaseModel):
    provider_type: EmbeddingProvider
    api_key: str | None = None
+    api_url: str | None = None

View File

@@ -45,7 +45,7 @@ def set_new_search_settings(
    if search_settings_new.index_name:
        logger.warning("Index name was specified by request, this is not suggested")

-    # Validate cloud provider exists
+    # Validate cloud provider exists or create new LiteLLM provider
    if search_settings_new.provider_type is not None:
        cloud_provider = get_embedding_provider_from_provider_type(
            db_session, provider_type=search_settings_new.provider_type
@@ -133,7 +133,7 @@ def cancel_new_embedding(
@router.get("/get-current-search-settings")
-def get_curr_search_settings(
+def get_current_search_settings_endpoint(
    _: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> SavedSearchSettings:
@@ -142,7 +142,7 @@ def get_curr_search_settings(
@router.get("/get-secondary-search-settings")
-def get_sec_search_settings(
+def get_secondary_search_settings_endpoint(
    _: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> SavedSearchSettings | None:

View File

@@ -2,6 +2,7 @@ import json
from typing import Any
from typing import Optional

+import httpx
import openai
import vertexai  # type: ignore
import voyageai  # type: ignore
@@ -235,6 +236,22 @@ def get_local_reranking_model(
    return _RERANK_MODEL


+def embed_with_litellm_proxy(
+    texts: list[str], api_url: str, model: str
+) -> list[Embedding]:
+    with httpx.Client() as client:
+        response = client.post(
+            api_url,
+            json={
+                "model": model,
+                "input": texts,
+            },
+        )
+        response.raise_for_status()
+        result = response.json()
+        return [embedding["embedding"] for embedding in result["data"]]
+
+
@simple_log_function_time()
def embed_text(
    texts: list[str],
@@ -245,21 +262,37 @@ def embed_text(
    api_key: str | None,
    provider_type: EmbeddingProvider | None,
    prefix: str | None,
+    api_url: str | None,
) -> list[Embedding]:
+    logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")
    if not all(texts):
+        logger.error("Empty strings provided for embedding")
        raise ValueError("Empty strings are not allowed for embedding.")

-    # Third party API based embedding model
    if not texts:
+        logger.error("No texts provided for embedding")
        raise ValueError("No texts provided for embedding.")

+    if provider_type == EmbeddingProvider.LITELLM:
+        logger.debug(f"Using LiteLLM proxy for embedding with URL: {api_url}")
+        if not api_url:
+            logger.error("API URL not provided for LiteLLM proxy")
+            raise ValueError("API URL is required for LiteLLM proxy embedding.")
+        try:
+            return embed_with_litellm_proxy(texts, api_url, model_name or "")
+        except Exception as e:
+            logger.exception(f"Error during LiteLLM proxy embedding: {str(e)}")
+            raise
+
    elif provider_type is not None:
-        logger.debug(f"Embedding text with provider: {provider_type}")
+        logger.debug(f"Using cloud provider {provider_type} for embedding")
        if api_key is None:
+            logger.error("API key not provided for cloud model")
            raise RuntimeError("API key not provided for cloud model")

        if prefix:
-            # This may change in the future if some providers require the user
-            # to manually append a prefix but this is not the case currently
+            logger.warning("Prefix provided for cloud model, which is not supported")
            raise ValueError(
                "Prefix string is not valid for cloud models. "
                "Cloud models take an explicit text type instead."
@@ -274,14 +307,15 @@ def embed_text(
            text_type=text_type,
        )

-        # Check for None values in embeddings
        if any(embedding is None for embedding in embeddings):
            error_message = "Embeddings contain None values\n"
            error_message += "Corresponding texts:\n"
            error_message += "\n".join(texts)
+            logger.error(error_message)
            raise ValueError(error_message)

    elif model_name is not None:
+        logger.debug(f"Using local model {model_name} for embedding")
        prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts

        local_model = get_embedding_model(
@@ -296,10 +330,12 @@ def embed_text(
        ]

    else:
+        logger.error("Neither model name nor provider specified for embedding")
        raise ValueError(
            "Either model name or provider must be provided to run embeddings."
        )

+    logger.info(f"Successfully embedded {len(texts)} texts")
    return embeddings
@@ -344,6 +380,7 @@ async def process_embed_request(
            api_key=embed_request.api_key,
            provider_type=embed_request.provider_type,
            text_type=embed_request.text_type,
+            api_url=embed_request.api_url,
            prefix=prefix,
        )

        return EmbedResponse(embeddings=embeddings)
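embed_with_litellm_proxy above assumes the LiteLLM proxy exposes an OpenAI-compatible embeddings route: it POSTs a {"model", "input"} payload and reads the vectors back from data[*].embedding. A minimal sketch of that exchange outside the model server, assuming a proxy running at http://localhost:4000/embeddings with a model alias my-embedding-model registered on it (both illustrative, not taken from this change):

# Sketch only: the URL and model alias are assumptions about a local LiteLLM proxy.
import httpx

payload = {"model": "my-embedding-model", "input": ["a passage to embed", "another one"]}
with httpx.Client() as client:
    result = client.post("http://localhost:4000/embeddings", json=payload).json()

# OpenAI-style response shape: {"data": [{"embedding": [...], "index": 0}, ...], ...}
embeddings = [item["embedding"] for item in result["data"]]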

View File

@@ -61,6 +61,7 @@ PRESERVED_SEARCH_FIELDS = [
    "provider_type",
    "api_key",
    "model_name",
+    "api_url",
    "index_name",
    "multipass_indexing",
    "model_dim",

View File

@@ -6,6 +6,7 @@ class EmbeddingProvider(str, Enum):
    COHERE = "cohere"
    VOYAGE = "voyage"
    GOOGLE = "google"
+    LITELLM = "litellm"


class RerankerProvider(str, Enum):

View File

@@ -18,6 +18,7 @@ class EmbedRequest(BaseModel):
    text_type: EmbedTextType
    manual_query_prefix: str | None = None
    manual_passage_prefix: str | None = None
+    api_url: str | None = None

    # This disables the "model_" protected namespace for pydantic
    model_config = {"protected_namespaces": ()}

View File

@@ -32,6 +32,7 @@ def openai_embedding_model() -> EmbeddingModel:
        passage_prefix=None,
        api_key=os.getenv("OPENAI_API_KEY"),
        provider_type=EmbeddingProvider.OPENAI,
+        api_url=None,
    )
@@ -51,6 +52,7 @@ def cohere_embedding_model() -> EmbeddingModel:
        passage_prefix=None,
        api_key=os.getenv("COHERE_API_KEY"),
        provider_type=EmbeddingProvider.COHERE,
+        api_url=None,
    )
@@ -70,6 +72,7 @@ def local_nomic_embedding_model() -> EmbeddingModel:
        passage_prefix="search_document: ",
        api_key=None,
        provider_type=None,
+        api_url=None,
    )

BIN
web/public/LiteLLM.jpg (new binary file, 12 KiB; not shown)

View File

@@ -2,3 +2,5 @@ export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
export const EMBEDDING_PROVIDERS_ADMIN_URL =
  "/api/admin/embedding/embedding-provider";
+
+export const EMBEDDING_MODELS_ADMIN_URL = "/api/admin/embedding";

View File

@@ -24,10 +24,14 @@ import { ChangeCredentialsModal } from "./modals/ChangeCredentialsModal";
import { ModelSelectionConfirmationModal } from "./modals/ModelSelectionModal";
import { AlreadyPickedModal } from "./modals/AlreadyPickedModal";
import { ModelOption } from "../../../components/embedding/ModelSelector";
-import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../configuration/llm/constants";
+import {
+  EMBEDDING_MODELS_ADMIN_URL,
+  EMBEDDING_PROVIDERS_ADMIN_URL,
+} from "../configuration/llm/constants";

export interface EmbeddingDetails {
-  api_key: string;
+  api_key?: string;
+  api_url?: string;
  custom_config: any;
  provider_type: EmbeddingProvider;
}
@@ -77,12 +81,20 @@ export function EmbeddingModelSelection({
  const [showDeleteCredentialsModal, setShowDeleteCredentialsModal] =
    useState<boolean>(false);

  const [showAddConnectorPopup, setShowAddConnectorPopup] =
    useState<boolean>(false);

+  const { data: embeddingModelDetails } = useSWR<CloudEmbeddingModel[]>(
+    EMBEDDING_MODELS_ADMIN_URL,
+    errorHandlingFetcher,
+    { refreshInterval: 5000 } // 5 seconds
+  );
+
  const { data: embeddingProviderDetails } = useSWR<EmbeddingDetails[]>(
    EMBEDDING_PROVIDERS_ADMIN_URL,
-    errorHandlingFetcher
+    errorHandlingFetcher,
+    { refreshInterval: 5000 } // 5 seconds
  );

  const { data: connectors } = useSWR<Connector<any>[]>(
@@ -175,6 +187,7 @@ export function EmbeddingModelSelection({
      {showTentativeProvider && (
        <ProviderCreationModal
+          isProxy={showTentativeProvider.provider_type == "LiteLLM"}
          selectedProvider={showTentativeProvider}
          onConfirm={() => {
            setShowTentativeProvider(showUnconfiguredProvider);
@@ -189,8 +202,10 @@ export function EmbeddingModelSelection({
          }}
        />
      )}

      {changeCredentialsProvider && (
        <ChangeCredentialsModal
+          isProxy={changeCredentialsProvider.provider_type == "LiteLLM"}
          useFileUpload={changeCredentialsProvider.provider_type == "Google"}
          onDeleted={() => {
            clientsideRemoveProvider(changeCredentialsProvider);
@@ -277,6 +292,7 @@ export function EmbeddingModelSelection({
      {modelTab == "cloud" && (
        <CloudEmbeddingPage
+          embeddingModelDetails={embeddingModelDetails}
          setShowModelInQueue={setShowModelInQueue}
          setShowTentativeModel={setShowTentativeModel}
          currentModel={selectedProvider}

View File

@@ -21,6 +21,7 @@ export interface AdvancedSearchConfiguration {
  multipass_indexing: boolean;
  multilingual_expansion: string[];
  disable_rerank_for_streaming: boolean;
+  api_url: string | null;
}

export interface SavedSearchSettings extends RerankingDetails {
@@ -33,6 +34,7 @@ export interface SavedSearchSettings extends RerankingDetails {
  multipass_indexing: boolean;
  multilingual_expansion: string[];
  disable_rerank_for_streaming: boolean;
+  api_url: string | null;
  provider_type: EmbeddingProvider | null;
}

View File

@@ -15,14 +15,16 @@ export function ChangeCredentialsModal({
  onCancel,
  onDeleted,
  useFileUpload,
+  isProxy = false,
}: {
  provider: CloudEmbeddingProvider;
  onConfirm: () => void;
  onCancel: () => void;
  onDeleted: () => void;
  useFileUpload: boolean;
+  isProxy?: boolean;
}) {
-  const [apiKey, setApiKey] = useState("");
+  const [apiKeyOrUrl, setApiKeyOrUrl] = useState("");
  const [testError, setTestError] = useState<string>("");
  const [fileName, setFileName] = useState<string>("");
  const fileInputRef = useRef<HTMLInputElement>(null);
@@ -50,7 +52,7 @@ export function ChangeCredentialsModal({
      let jsonContent;
      try {
        jsonContent = JSON.parse(fileContent);
-        setApiKey(JSON.stringify(jsonContent));
+        setApiKeyOrUrl(JSON.stringify(jsonContent));
      } catch (parseError) {
        throw new Error(
          "Failed to parse JSON file. Please ensure it's a valid JSON."
@@ -62,7 +64,7 @@ export function ChangeCredentialsModal({
          ? error.message
          : "An unknown error occurred while processing the file."
      );
-      setApiKey("");
+      setApiKeyOrUrl("");
      clearFileInput();
    }
  }
@@ -74,7 +76,7 @@ export function ChangeCredentialsModal({
    try {
      const response = await fetch(
-        `${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type}`,
+        `${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type.toLowerCase()}`,
        {
          method: "DELETE",
        }
@@ -105,7 +107,10 @@ export function ChangeCredentialsModal({
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({
          provider_type: provider.provider_type.toLowerCase().split(" ")[0],
-          api_key: apiKey,
+          [isProxy ? "api_url" : "api_key"]: apiKeyOrUrl,
+          [isProxy ? "api_key" : "api_url"]: isProxy
+            ? provider.api_key
+            : provider.api_url,
        }),
      });
@@ -119,7 +124,7 @@ export function ChangeCredentialsModal({
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({
          provider_type: provider.provider_type.toLowerCase().split(" ")[0],
-          api_key: apiKey,
+          [isProxy ? "api_url" : "api_key"]: apiKeyOrUrl,
          is_default_provider: false,
          is_configured: true,
        }),
@@ -128,7 +133,8 @@ export function ChangeCredentialsModal({
    if (!updateResponse.ok) {
      const errorData = await updateResponse.json();
      throw new Error(
-        errorData.detail || "Failed to update provider- check your API key"
+        errorData.detail ||
+          `Failed to update provider- check your ${isProxy ? "API URL" : "API key"}`
      );
    }
@@ -144,12 +150,12 @@ export function ChangeCredentialsModal({
    <Modal
      width="max-w-3xl"
      icon={provider.icon}
-      title={`Modify your ${provider.provider_type} key`}
+      title={`Modify your ${provider.provider_type} ${isProxy ? "URL" : "key"}`}
      onOutsideClick={onCancel}
    >
      <div className="mb-4">
        <Subtitle className="font-bold text-lg">
-          Want to swap out your key?
+          Want to swap out your {isProxy ? "URL" : "key"}?
        </Subtitle>
        <a
          href={provider.apiLink}
@@ -185,9 +191,9 @@ export function ChangeCredentialsModal({
              px-3
              bg-background-emphasis
            `}
-            value={apiKey}
-            onChange={(e: any) => setApiKey(e.target.value)}
-            placeholder="Paste your API key here"
+            value={apiKeyOrUrl}
+            onChange={(e: any) => setApiKeyOrUrl(e.target.value)}
+            placeholder={`Paste your ${isProxy ? "API URL" : "API key"} here`}
          />
        </>
      )}
@@ -203,15 +209,15 @@ export function ChangeCredentialsModal({
        <Button
          color="blue"
          onClick={() => handleSubmit()}
-          disabled={!apiKey}
+          disabled={!apiKeyOrUrl}
        >
-          Swap Key
+          Swap {isProxy ? "URL" : "Key"}
        </Button>
      </div>
      <Divider />

      <Subtitle className="mt-4 font-bold text-lg mb-2">
-        You can also delete your key.
+        You can also delete your {isProxy ? "URL" : "key"}.
      </Subtitle>
      <Text className="mb-2">
        This is only possible if you have already switched to a different
@@ -219,7 +225,7 @@ export function ChangeCredentialsModal({
      </Text>
      <Button onClick={handleDelete} color="red">
-        Delete key
+        Delete {isProxy ? "URL" : "key"}
      </Button>
      {deletionError && (
        <Callout title="Error" color="red" className="mt-4">

View File

@@ -13,11 +13,13 @@ export function ProviderCreationModal({
  onConfirm,
  onCancel,
  existingProvider,
+  isProxy,
}: {
  selectedProvider: CloudEmbeddingProvider;
  onConfirm: () => void;
  onCancel: () => void;
  existingProvider?: CloudEmbeddingProvider;
+  isProxy?: boolean;
}) {
  const useFileUpload = selectedProvider.provider_type == "Google";
@@ -29,6 +31,7 @@ export function ProviderCreationModal({
    provider_type:
      existingProvider?.provider_type || selectedProvider.provider_type,
    api_key: existingProvider?.api_key || "",
+    api_url: existingProvider?.api_url || "",
    custom_config: existingProvider?.custom_config
      ? Object.entries(existingProvider.custom_config)
      : [],
@@ -37,9 +40,14 @@ export function ProviderCreationModal({
  const validationSchema = Yup.object({
    provider_type: Yup.string().required("Provider type is required"),
-    api_key: useFileUpload
+    api_key: isProxy
      ? Yup.string()
-      : Yup.string().required("API Key is required"),
+      : useFileUpload
+        ? Yup.string()
+        : Yup.string().required("API Key is required"),
+    api_url: isProxy
+      ? Yup.string().required("API URL is required")
+      : Yup.string(),
    custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)),
  });
@@ -87,6 +95,7 @@ export function ProviderCreationModal({
        body: JSON.stringify({
          provider_type: values.provider_type.toLowerCase().split(" ")[0],
          api_key: values.api_key,
+          api_url: values.api_url,
        }),
      }
    );
@@ -169,12 +178,19 @@ export function ProviderCreationModal({
            target="_blank"
            href={selectedProvider.apiLink}
          >
-            API KEY
+            {isProxy ? "API URL" : "API KEY"}
          </a>
        </Text>

        <div className="flex w-full flex-col gap-y-2">
-          {useFileUpload ? (
+          {isProxy ? (
+            <TextFormField
+              name="api_url"
+              label="API URL"
+              placeholder="API URL"
+              type="text"
+            />
+          ) : useFileUpload ? (
            <>
              <Label>Upload JSON File</Label>
              <input

View File

@@ -1,6 +1,6 @@
"use client";

-import { Text, Title } from "@tremor/react";
+import { Button, Card, Text, Title } from "@tremor/react";

import {
  CloudEmbeddingProvider,
@@ -8,15 +8,19 @@ import {
  AVAILABLE_CLOUD_PROVIDERS,
  CloudEmbeddingProviderFull,
  EmbeddingModelDescriptor,
+  EmbeddingProvider,
+  LITELLM_CLOUD_PROVIDER,
} from "../../../../components/embedding/interfaces";
import { EmbeddingDetails } from "../EmbeddingModelSelectionForm";
import { FiExternalLink, FiInfo } from "react-icons/fi";
import { HoverPopup } from "@/components/HoverPopup";
-import { Dispatch, SetStateAction } from "react";
+import { Dispatch, SetStateAction, useEffect, useState } from "react";
+import { LiteLLMModelForm } from "@/components/embedding/LiteLLMModelForm";

export default function CloudEmbeddingPage({
  currentModel,
  embeddingProviderDetails,
+  embeddingModelDetails,
  newEnabledProviders,
  newUnenabledProviders,
  setShowTentativeProvider,
@@ -30,6 +34,7 @@ export default function CloudEmbeddingPage({
  currentModel: EmbeddingModelDescriptor | CloudEmbeddingModel;
  setAlreadySelectedModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
  newUnenabledProviders: string[];
+  embeddingModelDetails?: CloudEmbeddingModel[];
  embeddingProviderDetails?: EmbeddingDetails[];
  newEnabledProviders: string[];
  setShowTentativeProvider: React.Dispatch<
@@ -61,6 +66,17 @@ export default function CloudEmbeddingPage({
        ))!),
    })
  );

+  const [liteLLMProvider, setLiteLLMProvider] = useState<
+    EmbeddingDetails | undefined
+  >(undefined);
+
+  useEffect(() => {
+    const foundProvider = embeddingProviderDetails?.find(
+      (provider) =>
+        provider.provider_type === EmbeddingProvider.LITELLM.toLowerCase()
+    );
+    setLiteLLMProvider(foundProvider);
+  }, [embeddingProviderDetails]);

  return (
    <div>
@@ -122,6 +138,127 @@ export default function CloudEmbeddingPage({
            </div>
          </div>
        ))}
<Text className="mt-6">
Alternatively, you can use a self-hosted model using the LiteLLM
proxy. This allows you to leverage various LLM providers through a
unified interface that you control.{" "}
<a
href="https://docs.litellm.ai/"
target="_blank"
rel="noopener noreferrer"
className="text-blue-500 hover:underline"
>
Learn more about LiteLLM
</a>
</Text>
<div key={LITELLM_CLOUD_PROVIDER.provider_type} className="mt-4 w-full">
<div className="flex items-center mb-2">
{LITELLM_CLOUD_PROVIDER.icon({ size: 40 })}
<h2 className="ml-2 mt-2 text-xl font-bold">
{LITELLM_CLOUD_PROVIDER.provider_type}{" "}
{LITELLM_CLOUD_PROVIDER.provider_type == "Cohere" &&
"(recommended)"}
</h2>
<HoverPopup
mainContent={
<FiInfo className="ml-2 mt-2 cursor-pointer" size={18} />
}
popupContent={
<div className="text-sm text-text-800 w-52">
<div className="my-auto">
{LITELLM_CLOUD_PROVIDER.description}
</div>
</div>
}
style="dark"
/>
</div>
<div className="w-full flex flex-col items-start">
{!liteLLMProvider ? (
<button
onClick={() => setShowTentativeProvider(LITELLM_CLOUD_PROVIDER)}
className="mb-2 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 text-sm cursor-pointer"
>
Provide API URL
</button>
) : (
<button
onClick={() =>
setChangeCredentialsProvider(LITELLM_CLOUD_PROVIDER)
}
className="mb-2 hover:underline text-sm cursor-pointer"
>
Modify API URL
</button>
)}
{!liteLLMProvider && (
<Card className="mt-2 w-full max-w-4xl bg-gray-50 border border-gray-200">
<div className="p-4">
<Text className="text-lg font-semibold mb-2">
API URL Required
</Text>
<Text className="text-sm text-gray-600 mb-4">
Before you can add models, you need to provide an API URL
for your LiteLLM proxy. Click the &quot;Provide API
URL&quot; button above to set up your LiteLLM configuration.
</Text>
<div className="flex items-center">
<FiInfo className="text-blue-500 mr-2" size={18} />
<Text className="text-sm text-blue-500">
Once configured, you&apos;ll be able to add and manage
your LiteLLM models here.
</Text>
</div>
</div>
</Card>
)}
{liteLLMProvider && (
<>
<div className="flex mb-4 flex-wrap gap-4">
{embeddingModelDetails
?.filter(
(model) =>
model.provider_type ===
EmbeddingProvider.LITELLM.toLowerCase()
)
.map((model) => (
<CloudModelCard
key={model.model_name}
model={model}
provider={LITELLM_CLOUD_PROVIDER}
currentModel={currentModel}
setAlreadySelectedModel={setAlreadySelectedModel}
setShowTentativeModel={setShowTentativeModel}
setShowModelInQueue={setShowModelInQueue}
setShowTentativeProvider={setShowTentativeProvider}
/>
))}
</div>
<Card
className={`mt-2 w-full max-w-4xl ${
currentModel.provider_type === EmbeddingProvider.LITELLM
? "border-2 border-blue-500"
: ""
}`}
>
<LiteLLMModelForm
provider={liteLLMProvider}
currentValues={
currentModel.provider_type === EmbeddingProvider.LITELLM
? (currentModel as CloudEmbeddingModel)
: null
}
setShowTentativeModel={setShowTentativeModel}
/>
</Card>
</>
)}
</div>
</div>
      </div>
    </div>
  );
@@ -146,7 +283,9 @@ export function CloudModelCard({
    React.SetStateAction<CloudEmbeddingProvider | null>
  >;
}) {
-  const enabled = model.model_name === currentModel.model_name;
+  const enabled =
+    model.model_name === currentModel.model_name &&
+    model.provider_type == currentModel.provider_type;

  return (
    <div
@@ -169,9 +308,12 @@ export function CloudModelCard({
        </a>
      </div>
      <p className="text-sm text-gray-600 mb-2">{model.description}</p>
-      <div className="text-xs text-gray-500 mb-2">
-        ${model.pricePerMillion}/M tokens
-      </div>
+      {model?.provider_type?.toLowerCase() !=
+        EmbeddingProvider.LITELLM.toLowerCase() && (
+        <div className="text-xs text-gray-500 mb-2">
+          ${model.pricePerMillion}/M tokens
+        </div>
+      )}
      <div className="mt-3">
        <button
          className={`w-full p-2 rounded-lg text-sm ${
@@ -182,7 +324,10 @@ export function CloudModelCard({
          onClick={() => {
            if (enabled) {
              setAlreadySelectedModel(model);
-            } else if (provider.configured) {
+            } else if (
+              provider.configured ||
+              provider.provider_type === EmbeddingProvider.LITELLM
+            ) {
              setShowTentativeModel(model);
            } else {
              setShowModelInQueue(model);

View File

@@ -41,6 +41,7 @@ export default function EmbeddingForm() {
    multipass_indexing: true,
    multilingual_expansion: [],
    disable_rerank_for_streaming: false,
+    api_url: null,
  });

  const [rerankingDetails, setRerankingDetails] = useState<RerankingDetails>({
@@ -116,6 +117,7 @@ export default function EmbeddingForm() {
      multilingual_expansion: searchSettings.multilingual_expansion,
      disable_rerank_for_streaming:
        searchSettings.disable_rerank_for_streaming,
+      api_url: null,
    });
    setRerankingDetails({
      rerank_api_key: searchSettings.rerank_api_key,

View File

@@ -41,6 +41,7 @@ export function CustomModelForm({
            api_key: null,
            provider_type: null,
            index_name: null,
+            api_url: null,
          });
        }}
      >
@@ -106,20 +107,19 @@ export function CustomModelForm({
          />

          <BooleanFormField
+            removeIndent
            name="normalize"
            label="Normalize Embeddings"
            subtext="Whether or not to normalize the embeddings generated by the model. When in doubt, leave this checked."
          />

-          <div className="flex mt-6">
-            <Button
-              type="submit"
-              disabled={isSubmitting}
-              className="w-64 mx-auto"
-            >
-              Choose
-            </Button>
-          </div>
+          <Button
+            type="submit"
+            disabled={isSubmitting}
+            className="w-64 mx-auto"
+          >
+            Choose
+          </Button>
        </Form>
      )}
    </Formik>

View File

@@ -0,0 +1,116 @@
import { CloudEmbeddingModel, CloudEmbeddingProvider } from "./interfaces";
import { Formik, Form } from "formik";
import * as Yup from "yup";
import { TextFormField, BooleanFormField } from "../admin/connectors/Field";
import { Dispatch, SetStateAction } from "react";
import { Button, Text } from "@tremor/react";
import { EmbeddingDetails } from "@/app/admin/embeddings/EmbeddingModelSelectionForm";
export function LiteLLMModelForm({
setShowTentativeModel,
currentValues,
provider,
}: {
setShowTentativeModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
currentValues: CloudEmbeddingModel | null;
provider: EmbeddingDetails;
}) {
return (
<div>
<Formik
initialValues={
currentValues || {
model_name: "",
model_dim: 768,
normalize: false,
query_prefix: "",
passage_prefix: "",
provider_type: "LiteLLM",
api_key: "",
enabled: true,
api_url: provider.api_url,
description: "",
index_name: "",
pricePerMillion: 0,
mtebScore: 0,
maxContext: 4096,
max_tokens: 1024,
}
}
validationSchema={Yup.object().shape({
model_name: Yup.string().required("Model name is required"),
model_dim: Yup.number().required("Model dimension is required"),
normalize: Yup.boolean().required(),
query_prefix: Yup.string(),
passage_prefix: Yup.string(),
provider_type: Yup.string().required("Provider type is required"),
api_key: Yup.string().optional(),
enabled: Yup.boolean(),
api_url: Yup.string().required("API base URL is required"),
description: Yup.string(),
index_name: Yup.string().nullable(),
pricePerMillion: Yup.number(),
mtebScore: Yup.number(),
maxContext: Yup.number(),
max_tokens: Yup.number(),
})}
onSubmit={async (values) => {
setShowTentativeModel(values as CloudEmbeddingModel);
}}
>
{({ isSubmitting }) => (
<Form>
<Text className="text-xl text-text-900 font-bold mb-4">
Add a new model to LiteLLM proxy at {provider.api_url}
</Text>
<TextFormField
name="model_name"
label="Model Name:"
subtext="The name of the LiteLLM model"
placeholder="e.g. 'all-MiniLM-L6-v2'"
autoCompleteDisabled={true}
/>
<TextFormField
name="model_dim"
label="Model Dimension:"
subtext="The dimension of the model's embeddings"
placeholder="e.g. '1536'"
type="number"
autoCompleteDisabled={true}
/>
<BooleanFormField
removeIndent
name="normalize"
label="Normalize"
subtext="Whether to normalize the embeddings"
/>
<TextFormField
name="query_prefix"
label="Query Prefix:"
subtext="Prefix for query embeddings"
autoCompleteDisabled={true}
/>
<TextFormField
name="passage_prefix"
label="Passage Prefix:"
subtext="Prefix for passage embeddings"
autoCompleteDisabled={true}
/>
<Button
type="submit"
disabled={isSubmitting}
className="w-64 mx-auto"
>
Configure LiteLLM Model
</Button>
</Form>
)}
</Formik>
</div>
);
}

View File

@@ -2,6 +2,7 @@ import {
  CohereIcon,
  GoogleIcon,
  IconProps,
+  LiteLLMIcon,
  MicrosoftIcon,
  NomicIcon,
  OpenAIIcon,
@@ -14,11 +15,13 @@ export enum EmbeddingProvider {
  COHERE = "Cohere",
  VOYAGE = "Voyage",
  GOOGLE = "Google",
+  LITELLM = "LiteLLM",
}

export interface CloudEmbeddingProvider {
  provider_type: EmbeddingProvider;
  api_key?: string;
+  api_url?: string;
  custom_config?: Record<string, string>;
  docsLink?: string;
@@ -44,6 +47,7 @@ export interface EmbeddingModelDescriptor {
  provider_type: string | null;
  description: string;
  api_key: string | null;
+  api_url: string | null;
  index_name: string | null;
}
@@ -70,7 +74,7 @@ export interface FullEmbeddingModelResponse {
}

export interface CloudEmbeddingProviderFull extends CloudEmbeddingProvider {
-  configured: boolean;
+  configured?: boolean;
}

export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
@@ -87,6 +91,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
    index_name: "",
    provider_type: null,
    api_key: null,
+    api_url: null,
  },
  {
    model_name: "intfloat/e5-base-v2",
@@ -99,6 +104,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
    passage_prefix: "passage: ",
    index_name: "",
    provider_type: null,
+    api_url: null,
    api_key: null,
  },
  {
@@ -113,6 +119,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
    index_name: "",
    provider_type: null,
    api_key: null,
+    api_url: null,
  },
  {
    model_name: "intfloat/multilingual-e5-base",
@@ -126,6 +133,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
    index_name: "",
    provider_type: null,
    api_key: null,
+    api_url: null,
  },
  {
    model_name: "intfloat/multilingual-e5-small",
@@ -139,9 +147,19 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
    index_name: "",
    provider_type: null,
    api_key: null,
+    api_url: null,
  },
];

+export const LITELLM_CLOUD_PROVIDER: CloudEmbeddingProvider = {
+  provider_type: EmbeddingProvider.LITELLM,
+  website: "https://github.com/BerriAI/litellm",
+  icon: LiteLLMIcon,
+  description: "Open-source library to call LLM APIs using OpenAI format",
+  apiLink: "https://docs.litellm.ai/docs/proxy/quick_start",
+  embedding_models: [], // No default embedding models
+};
+
export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
  {
    provider_type: EmbeddingProvider.COHERE,
@@ -169,6 +187,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
        passage_prefix: "",
        index_name: "",
        api_key: null,
+        api_url: null,
      },
      {
        model_name: "embed-english-light-v3.0",
@@ -185,6 +204,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
        passage_prefix: "",
        index_name: "",
        api_key: null,
+        api_url: null,
      },
    ],
  },
@@ -213,6 +233,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
        enabled: false,
        index_name: "",
        api_key: null,
+        api_url: null,
      },
      {
        provider_type: EmbeddingProvider.OPENAI,
@@ -229,6 +250,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
        maxContext: 8191,
        index_name: "",
        api_key: null,
+        api_url: null,
      },
    ],
  },
@@ -258,6 +280,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
        passage_prefix: "",
        index_name: "",
        api_key: null,
+        api_url: null,
      },
      {
        provider_type: EmbeddingProvider.GOOGLE,
@@ -273,6 +296,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
        passage_prefix: "",
        index_name: "",
        api_key: null,
+        api_url: null,
      },
    ],
  },
@@ -301,6 +325,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
        passage_prefix: "",
        index_name: "",
        api_key: null,
+        api_url: null,
      },
      {
        provider_type: EmbeddingProvider.VOYAGE,
@@ -317,6 +342,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
        passage_prefix: "",
        index_name: "",
        api_key: null,
+        api_url: null,
      },
    ],
  },

View File

@@ -48,6 +48,7 @@ import jiraSVG from "../../../public/Jira.svg";
import confluenceSVG from "../../../public/Confluence.svg";
import openAISVG from "../../../public/Openai.svg";
import openSourceIcon from "../../../public/OpenSource.png";
+import litellmIcon from "../../../public/LiteLLM.jpg";

import awsWEBP from "../../../public/Amazon.webp";
import azureIcon from "../../../public/Azure.png";
@@ -267,6 +268,20 @@ export const ColorSlackIcon = ({
  );
};

+export const LiteLLMIcon = ({
+  size = 16,
+  className = defaultTailwindCSS,
+}: IconProps) => {
+  return (
+    <div
+      style={{ width: `${size + 4}px`, height: `${size + 4}px` }}
+      className={`w-[${size + 4}px] h-[${size + 4}px] -m-0.5 ` + className}
+    >
+      <Image src={litellmIcon} alt="Logo" width="96" height="96" />
+    </div>
+  );
+};
+
export const OpenSourceIcon = ({
  size = 16,
  className = defaultTailwindCSS,