mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-18 12:00:58 +02:00
Add litellm proxy embeddings (#2291)
* add litellm proxy * formatting * move `api_url` to cloud provider + nits * remove log * typing * quick tuyping fix * update LiteLLM selection logic * remove logs + validate functionality * rename proxy var * update path casing * remove pricing for custom models * functional values
This commit is contained in:
parent
910821c723
commit
299cb5035c
@ -0,0 +1,26 @@
|
||||
"""Add base_url to CloudEmbeddingProvider
|
||||
|
||||
Revision ID: bceb1e139447
|
||||
Revises: 1f60f60c3401
|
||||
Create Date: 2024-08-28 17:00:52.554580
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bceb1e139447"
|
||||
down_revision = "1f60f60c3401"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("api_url", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("embedding_provider", "api_url")
|
@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||
from danswer.db.models import LLMProvider__UserGroup
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
@ -50,6 +51,7 @@ def upsert_cloud_embedding_provider(
|
||||
setattr(existing_provider, key, value)
|
||||
else:
|
||||
new_provider = CloudEmbeddingProviderModel(**provider.model_dump())
|
||||
|
||||
db_session.add(new_provider)
|
||||
existing_provider = new_provider
|
||||
db_session.commit()
|
||||
@ -157,12 +159,19 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
|
||||
def remove_embedding_provider(
|
||||
db_session: Session, provider_type: EmbeddingProvider
|
||||
) -> None:
|
||||
db_session.execute(
|
||||
delete(SearchSettings).where(SearchSettings.provider_type == provider_type)
|
||||
)
|
||||
|
||||
# Delete the embedding provider
|
||||
db_session.execute(
|
||||
delete(CloudEmbeddingProviderModel).where(
|
||||
CloudEmbeddingProviderModel.provider_type == provider_type
|
||||
)
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
||||
# Remove LLMProvider's dependent relationships
|
||||
|
@ -607,6 +607,10 @@ class SearchSettings(Base):
|
||||
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
|
||||
cloud_provider='{self.cloud_provider.provider_type if self.cloud_provider else 'None'}')>"
|
||||
|
||||
@property
|
||||
def api_url(self) -> str | None:
|
||||
return self.cloud_provider.api_url if self.cloud_provider is not None else None
|
||||
|
||||
@property
|
||||
def api_key(self) -> str | None:
|
||||
return self.cloud_provider.api_key if self.cloud_provider is not None else None
|
||||
@ -1085,6 +1089,7 @@ class CloudEmbeddingProvider(Base):
|
||||
provider_type: Mapped[EmbeddingProvider] = mapped_column(
|
||||
Enum(EmbeddingProvider), primary_key=True
|
||||
)
|
||||
api_url: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString())
|
||||
search_settings: Mapped[list["SearchSettings"]] = relationship(
|
||||
"SearchSettings",
|
||||
|
@ -115,6 +115,13 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
|
||||
return latest_settings
|
||||
|
||||
|
||||
def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
|
||||
query = select(SearchSettings).order_by(SearchSettings.id.desc())
|
||||
result = db_session.execute(query)
|
||||
all_settings = result.scalars().all()
|
||||
return list(all_settings)
|
||||
|
||||
|
||||
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
|
||||
if db_session is None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
@ -234,6 +241,7 @@ def get_old_default_embedding_model() -> IndexingSetting:
|
||||
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
|
||||
index_name="danswer_chunk",
|
||||
multipass_indexing=False,
|
||||
api_url=None,
|
||||
)
|
||||
|
||||
|
||||
@ -246,4 +254,5 @@ def get_new_default_embedding_model() -> IndexingSetting:
|
||||
passage_prefix=ASYM_PASSAGE_PREFIX,
|
||||
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
|
||||
multipass_indexing=False,
|
||||
api_url=None,
|
||||
)
|
||||
|
@ -32,6 +32,7 @@ class IndexingEmbedder(ABC):
|
||||
passage_prefix: str | None,
|
||||
provider_type: EmbeddingProvider | None,
|
||||
api_key: str | None,
|
||||
api_url: str | None,
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.normalize = normalize
|
||||
@ -39,6 +40,7 @@ class IndexingEmbedder(ABC):
|
||||
self.passage_prefix = passage_prefix
|
||||
self.provider_type = provider_type
|
||||
self.api_key = api_key
|
||||
self.api_url = api_url
|
||||
|
||||
self.embedding_model = EmbeddingModel(
|
||||
model_name=model_name,
|
||||
@ -47,6 +49,7 @@ class IndexingEmbedder(ABC):
|
||||
normalize=normalize,
|
||||
api_key=api_key,
|
||||
provider_type=provider_type,
|
||||
api_url=api_url,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
@ -70,9 +73,16 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
passage_prefix: str | None,
|
||||
provider_type: EmbeddingProvider | None = None,
|
||||
api_key: str | None = None,
|
||||
api_url: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
model_name, normalize, query_prefix, passage_prefix, provider_type, api_key
|
||||
model_name,
|
||||
normalize,
|
||||
query_prefix,
|
||||
passage_prefix,
|
||||
provider_type,
|
||||
api_key,
|
||||
api_url,
|
||||
)
|
||||
|
||||
@log_function_time()
|
||||
@ -156,7 +166,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
title_embed_dict[title] = title_embedding
|
||||
|
||||
new_embedded_chunk = IndexChunk(
|
||||
**chunk.model_dump(),
|
||||
**chunk.dict(),
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=chunk_embeddings[0],
|
||||
mini_chunk_embeddings=chunk_embeddings[1:],
|
||||
@ -179,6 +189,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
passage_prefix=search_settings.passage_prefix,
|
||||
provider_type=search_settings.provider_type,
|
||||
api_key=search_settings.api_key,
|
||||
api_url=search_settings.api_url,
|
||||
)
|
||||
|
||||
|
||||
@ -202,4 +213,5 @@ def get_embedding_model_from_search_settings(
|
||||
passage_prefix=search_settings.passage_prefix,
|
||||
provider_type=search_settings.provider_type,
|
||||
api_key=search_settings.api_key,
|
||||
api_url=search_settings.api_url,
|
||||
)
|
||||
|
@ -99,6 +99,7 @@ class EmbeddingModelDetail(BaseModel):
|
||||
normalize: bool
|
||||
query_prefix: str | None
|
||||
passage_prefix: str | None
|
||||
api_url: str | None = None
|
||||
provider_type: EmbeddingProvider | None = None
|
||||
api_key: str | None = None
|
||||
|
||||
@ -117,6 +118,7 @@ class EmbeddingModelDetail(BaseModel):
|
||||
passage_prefix=search_settings.passage_prefix,
|
||||
provider_type=search_settings.provider_type,
|
||||
api_key=search_settings.api_key,
|
||||
api_url=search_settings.api_url,
|
||||
)
|
||||
|
||||
|
||||
|
@ -90,6 +90,7 @@ class EmbeddingModel:
|
||||
query_prefix: str | None,
|
||||
passage_prefix: str | None,
|
||||
api_key: str | None,
|
||||
api_url: str | None,
|
||||
provider_type: EmbeddingProvider | None,
|
||||
retrim_content: bool = False,
|
||||
) -> None:
|
||||
@ -100,6 +101,7 @@ class EmbeddingModel:
|
||||
self.normalize = normalize
|
||||
self.model_name = model_name
|
||||
self.retrim_content = retrim_content
|
||||
self.api_url = api_url
|
||||
self.tokenizer = get_tokenizer(
|
||||
model_name=model_name, provider_type=provider_type
|
||||
)
|
||||
@ -157,6 +159,7 @@ class EmbeddingModel:
|
||||
text_type=text_type,
|
||||
manual_query_prefix=self.query_prefix,
|
||||
manual_passage_prefix=self.passage_prefix,
|
||||
api_url=self.api_url,
|
||||
)
|
||||
|
||||
response = self._make_model_server_request(embed_request)
|
||||
@ -226,6 +229,7 @@ class EmbeddingModel:
|
||||
passage_prefix=search_settings.passage_prefix,
|
||||
api_key=search_settings.api_key,
|
||||
provider_type=search_settings.provider_type,
|
||||
api_url=search_settings.api_url,
|
||||
retrim_content=retrim_content,
|
||||
)
|
||||
|
||||
|
@ -81,6 +81,7 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
||||
num_rerank=search_settings.num_rerank,
|
||||
# Multilingual Expansion
|
||||
multilingual_expansion=search_settings.multilingual_expansion,
|
||||
api_url=search_settings.api_url,
|
||||
)
|
||||
|
||||
|
||||
|
@ -9,7 +9,9 @@ from danswer.db.llm import fetch_existing_embedding_providers
|
||||
from danswer.db.llm import remove_embedding_provider
|
||||
from danswer.db.llm import upsert_cloud_embedding_provider
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_all_search_settings
|
||||
from danswer.db.search_settings import get_current_db_embedding_provider
|
||||
from danswer.indexing.models import EmbeddingModelDetail
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
@ -20,6 +22,7 @@ from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@ -37,6 +40,7 @@ def test_embedding_configuration(
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
api_key=test_llm_request.api_key,
|
||||
api_url=test_llm_request.api_url,
|
||||
provider_type=test_llm_request.provider_type,
|
||||
normalize=False,
|
||||
query_prefix=None,
|
||||
@ -56,6 +60,15 @@ def test_embedding_configuration(
|
||||
raise HTTPException(status_code=400, detail=error_msg)
|
||||
|
||||
|
||||
@admin_router.get("", response_model=list[EmbeddingModelDetail])
|
||||
def list_embedding_models(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[EmbeddingModelDetail]:
|
||||
search_settings = get_all_search_settings(db_session)
|
||||
return [EmbeddingModelDetail.from_db_model(setting) for setting in search_settings]
|
||||
|
||||
|
||||
@admin_router.get("/embedding-provider")
|
||||
def list_embedding_providers(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
|
@ -11,11 +11,13 @@ if TYPE_CHECKING:
|
||||
class TestEmbeddingRequest(BaseModel):
|
||||
provider_type: EmbeddingProvider
|
||||
api_key: str | None = None
|
||||
api_url: str | None = None
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(BaseModel):
|
||||
provider_type: EmbeddingProvider
|
||||
api_key: str | None = None
|
||||
api_url: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_request(
|
||||
@ -24,9 +26,11 @@ class CloudEmbeddingProvider(BaseModel):
|
||||
return cls(
|
||||
provider_type=cloud_provider_model.provider_type,
|
||||
api_key=cloud_provider_model.api_key,
|
||||
api_url=cloud_provider_model.api_url,
|
||||
)
|
||||
|
||||
|
||||
class CloudEmbeddingProviderCreationRequest(BaseModel):
|
||||
provider_type: EmbeddingProvider
|
||||
api_key: str | None = None
|
||||
api_url: str | None = None
|
||||
|
@ -45,7 +45,7 @@ def set_new_search_settings(
|
||||
if search_settings_new.index_name:
|
||||
logger.warning("Index name was specified by request, this is not suggested")
|
||||
|
||||
# Validate cloud provider exists
|
||||
# Validate cloud provider exists or create new LiteLLM provider
|
||||
if search_settings_new.provider_type is not None:
|
||||
cloud_provider = get_embedding_provider_from_provider_type(
|
||||
db_session, provider_type=search_settings_new.provider_type
|
||||
@ -133,7 +133,7 @@ def cancel_new_embedding(
|
||||
|
||||
|
||||
@router.get("/get-current-search-settings")
|
||||
def get_curr_search_settings(
|
||||
def get_current_search_settings_endpoint(
|
||||
_: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SavedSearchSettings:
|
||||
@ -142,7 +142,7 @@ def get_curr_search_settings(
|
||||
|
||||
|
||||
@router.get("/get-secondary-search-settings")
|
||||
def get_sec_search_settings(
|
||||
def get_secondary_search_settings_endpoint(
|
||||
_: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SavedSearchSettings | None:
|
||||
|
@ -2,6 +2,7 @@ import json
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
import vertexai # type: ignore
|
||||
import voyageai # type: ignore
|
||||
@ -235,6 +236,22 @@ def get_local_reranking_model(
|
||||
return _RERANK_MODEL
|
||||
|
||||
|
||||
def embed_with_litellm_proxy(
|
||||
texts: list[str], api_url: str, model: str
|
||||
) -> list[Embedding]:
|
||||
with httpx.Client() as client:
|
||||
response = client.post(
|
||||
api_url,
|
||||
json={
|
||||
"model": model,
|
||||
"input": texts,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return [embedding["embedding"] for embedding in result["data"]]
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def embed_text(
|
||||
texts: list[str],
|
||||
@ -245,21 +262,37 @@ def embed_text(
|
||||
api_key: str | None,
|
||||
provider_type: EmbeddingProvider | None,
|
||||
prefix: str | None,
|
||||
api_url: str | None,
|
||||
) -> list[Embedding]:
|
||||
logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")
|
||||
|
||||
if not all(texts):
|
||||
logger.error("Empty strings provided for embedding")
|
||||
raise ValueError("Empty strings are not allowed for embedding.")
|
||||
|
||||
# Third party API based embedding model
|
||||
if not texts:
|
||||
logger.error("No texts provided for embedding")
|
||||
raise ValueError("No texts provided for embedding.")
|
||||
|
||||
if provider_type == EmbeddingProvider.LITELLM:
|
||||
logger.debug(f"Using LiteLLM proxy for embedding with URL: {api_url}")
|
||||
if not api_url:
|
||||
logger.error("API URL not provided for LiteLLM proxy")
|
||||
raise ValueError("API URL is required for LiteLLM proxy embedding.")
|
||||
try:
|
||||
return embed_with_litellm_proxy(texts, api_url, model_name or "")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during LiteLLM proxy embedding: {str(e)}")
|
||||
raise
|
||||
|
||||
elif provider_type is not None:
|
||||
logger.debug(f"Embedding text with provider: {provider_type}")
|
||||
logger.debug(f"Using cloud provider {provider_type} for embedding")
|
||||
if api_key is None:
|
||||
logger.error("API key not provided for cloud model")
|
||||
raise RuntimeError("API key not provided for cloud model")
|
||||
|
||||
if prefix:
|
||||
# This may change in the future if some providers require the user
|
||||
# to manually append a prefix but this is not the case currently
|
||||
logger.warning("Prefix provided for cloud model, which is not supported")
|
||||
raise ValueError(
|
||||
"Prefix string is not valid for cloud models. "
|
||||
"Cloud models take an explicit text type instead."
|
||||
@ -274,14 +307,15 @@ def embed_text(
|
||||
text_type=text_type,
|
||||
)
|
||||
|
||||
# Check for None values in embeddings
|
||||
if any(embedding is None for embedding in embeddings):
|
||||
error_message = "Embeddings contain None values\n"
|
||||
error_message += "Corresponding texts:\n"
|
||||
error_message += "\n".join(texts)
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
elif model_name is not None:
|
||||
logger.debug(f"Using local model {model_name} for embedding")
|
||||
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
|
||||
|
||||
local_model = get_embedding_model(
|
||||
@ -296,10 +330,12 @@ def embed_text(
|
||||
]
|
||||
|
||||
else:
|
||||
logger.error("Neither model name nor provider specified for embedding")
|
||||
raise ValueError(
|
||||
"Either model name or provider must be provided to run embeddings."
|
||||
)
|
||||
|
||||
logger.info(f"Successfully embedded {len(texts)} texts")
|
||||
return embeddings
|
||||
|
||||
|
||||
@ -344,6 +380,7 @@ async def process_embed_request(
|
||||
api_key=embed_request.api_key,
|
||||
provider_type=embed_request.provider_type,
|
||||
text_type=embed_request.text_type,
|
||||
api_url=embed_request.api_url,
|
||||
prefix=prefix,
|
||||
)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
|
@ -61,6 +61,7 @@ PRESERVED_SEARCH_FIELDS = [
|
||||
"provider_type",
|
||||
"api_key",
|
||||
"model_name",
|
||||
"api_url",
|
||||
"index_name",
|
||||
"multipass_indexing",
|
||||
"model_dim",
|
||||
|
@ -6,6 +6,7 @@ class EmbeddingProvider(str, Enum):
|
||||
COHERE = "cohere"
|
||||
VOYAGE = "voyage"
|
||||
GOOGLE = "google"
|
||||
LITELLM = "litellm"
|
||||
|
||||
|
||||
class RerankerProvider(str, Enum):
|
||||
|
@ -18,6 +18,7 @@ class EmbedRequest(BaseModel):
|
||||
text_type: EmbedTextType
|
||||
manual_query_prefix: str | None = None
|
||||
manual_passage_prefix: str | None = None
|
||||
api_url: str | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
@ -32,6 +32,7 @@ def openai_embedding_model() -> EmbeddingModel:
|
||||
passage_prefix=None,
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
provider_type=EmbeddingProvider.OPENAI,
|
||||
api_url=None,
|
||||
)
|
||||
|
||||
|
||||
@ -51,6 +52,7 @@ def cohere_embedding_model() -> EmbeddingModel:
|
||||
passage_prefix=None,
|
||||
api_key=os.getenv("COHERE_API_KEY"),
|
||||
provider_type=EmbeddingProvider.COHERE,
|
||||
api_url=None,
|
||||
)
|
||||
|
||||
|
||||
@ -70,6 +72,7 @@ def local_nomic_embedding_model() -> EmbeddingModel:
|
||||
passage_prefix="search_document: ",
|
||||
api_key=None,
|
||||
provider_type=None,
|
||||
api_url=None,
|
||||
)
|
||||
|
||||
|
||||
|
BIN
web/public/LiteLLM.jpg
Normal file
BIN
web/public/LiteLLM.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 12 KiB |
@ -2,3 +2,5 @@ export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
|
||||
|
||||
export const EMBEDDING_PROVIDERS_ADMIN_URL =
|
||||
"/api/admin/embedding/embedding-provider";
|
||||
|
||||
export const EMBEDDING_MODELS_ADMIN_URL = "/api/admin/embedding";
|
||||
|
@ -24,10 +24,14 @@ import { ChangeCredentialsModal } from "./modals/ChangeCredentialsModal";
|
||||
import { ModelSelectionConfirmationModal } from "./modals/ModelSelectionModal";
|
||||
import { AlreadyPickedModal } from "./modals/AlreadyPickedModal";
|
||||
import { ModelOption } from "../../../components/embedding/ModelSelector";
|
||||
import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../configuration/llm/constants";
|
||||
import {
|
||||
EMBEDDING_MODELS_ADMIN_URL,
|
||||
EMBEDDING_PROVIDERS_ADMIN_URL,
|
||||
} from "../configuration/llm/constants";
|
||||
|
||||
export interface EmbeddingDetails {
|
||||
api_key: string;
|
||||
api_key?: string;
|
||||
api_url?: string;
|
||||
custom_config: any;
|
||||
provider_type: EmbeddingProvider;
|
||||
}
|
||||
@ -77,12 +81,20 @@ export function EmbeddingModelSelection({
|
||||
|
||||
const [showDeleteCredentialsModal, setShowDeleteCredentialsModal] =
|
||||
useState<boolean>(false);
|
||||
|
||||
const [showAddConnectorPopup, setShowAddConnectorPopup] =
|
||||
useState<boolean>(false);
|
||||
|
||||
const { data: embeddingModelDetails } = useSWR<CloudEmbeddingModel[]>(
|
||||
EMBEDDING_MODELS_ADMIN_URL,
|
||||
errorHandlingFetcher,
|
||||
{ refreshInterval: 5000 } // 5 seconds
|
||||
);
|
||||
|
||||
const { data: embeddingProviderDetails } = useSWR<EmbeddingDetails[]>(
|
||||
EMBEDDING_PROVIDERS_ADMIN_URL,
|
||||
errorHandlingFetcher
|
||||
errorHandlingFetcher,
|
||||
{ refreshInterval: 5000 } // 5 seconds
|
||||
);
|
||||
|
||||
const { data: connectors } = useSWR<Connector<any>[]>(
|
||||
@ -175,6 +187,7 @@ export function EmbeddingModelSelection({
|
||||
|
||||
{showTentativeProvider && (
|
||||
<ProviderCreationModal
|
||||
isProxy={showTentativeProvider.provider_type == "LiteLLM"}
|
||||
selectedProvider={showTentativeProvider}
|
||||
onConfirm={() => {
|
||||
setShowTentativeProvider(showUnconfiguredProvider);
|
||||
@ -189,8 +202,10 @@ export function EmbeddingModelSelection({
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
{changeCredentialsProvider && (
|
||||
<ChangeCredentialsModal
|
||||
isProxy={changeCredentialsProvider.provider_type == "LiteLLM"}
|
||||
useFileUpload={changeCredentialsProvider.provider_type == "Google"}
|
||||
onDeleted={() => {
|
||||
clientsideRemoveProvider(changeCredentialsProvider);
|
||||
@ -277,6 +292,7 @@ export function EmbeddingModelSelection({
|
||||
|
||||
{modelTab == "cloud" && (
|
||||
<CloudEmbeddingPage
|
||||
embeddingModelDetails={embeddingModelDetails}
|
||||
setShowModelInQueue={setShowModelInQueue}
|
||||
setShowTentativeModel={setShowTentativeModel}
|
||||
currentModel={selectedProvider}
|
||||
|
@ -21,6 +21,7 @@ export interface AdvancedSearchConfiguration {
|
||||
multipass_indexing: boolean;
|
||||
multilingual_expansion: string[];
|
||||
disable_rerank_for_streaming: boolean;
|
||||
api_url: string | null;
|
||||
}
|
||||
|
||||
export interface SavedSearchSettings extends RerankingDetails {
|
||||
@ -33,6 +34,7 @@ export interface SavedSearchSettings extends RerankingDetails {
|
||||
multipass_indexing: boolean;
|
||||
multilingual_expansion: string[];
|
||||
disable_rerank_for_streaming: boolean;
|
||||
api_url: string | null;
|
||||
provider_type: EmbeddingProvider | null;
|
||||
}
|
||||
|
||||
|
@ -15,14 +15,16 @@ export function ChangeCredentialsModal({
|
||||
onCancel,
|
||||
onDeleted,
|
||||
useFileUpload,
|
||||
isProxy = false,
|
||||
}: {
|
||||
provider: CloudEmbeddingProvider;
|
||||
onConfirm: () => void;
|
||||
onCancel: () => void;
|
||||
onDeleted: () => void;
|
||||
useFileUpload: boolean;
|
||||
isProxy?: boolean;
|
||||
}) {
|
||||
const [apiKey, setApiKey] = useState("");
|
||||
const [apiKeyOrUrl, setApiKeyOrUrl] = useState("");
|
||||
const [testError, setTestError] = useState<string>("");
|
||||
const [fileName, setFileName] = useState<string>("");
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
@ -50,7 +52,7 @@ export function ChangeCredentialsModal({
|
||||
let jsonContent;
|
||||
try {
|
||||
jsonContent = JSON.parse(fileContent);
|
||||
setApiKey(JSON.stringify(jsonContent));
|
||||
setApiKeyOrUrl(JSON.stringify(jsonContent));
|
||||
} catch (parseError) {
|
||||
throw new Error(
|
||||
"Failed to parse JSON file. Please ensure it's a valid JSON."
|
||||
@ -62,7 +64,7 @@ export function ChangeCredentialsModal({
|
||||
? error.message
|
||||
: "An unknown error occurred while processing the file."
|
||||
);
|
||||
setApiKey("");
|
||||
setApiKeyOrUrl("");
|
||||
clearFileInput();
|
||||
}
|
||||
}
|
||||
@ -74,7 +76,7 @@ export function ChangeCredentialsModal({
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type}`,
|
||||
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type.toLowerCase()}`,
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
@ -105,7 +107,10 @@ export function ChangeCredentialsModal({
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider_type: provider.provider_type.toLowerCase().split(" ")[0],
|
||||
api_key: apiKey,
|
||||
[isProxy ? "api_url" : "api_key"]: apiKeyOrUrl,
|
||||
[isProxy ? "api_key" : "api_url"]: isProxy
|
||||
? provider.api_key
|
||||
: provider.api_url,
|
||||
}),
|
||||
});
|
||||
|
||||
@ -119,7 +124,7 @@ export function ChangeCredentialsModal({
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider_type: provider.provider_type.toLowerCase().split(" ")[0],
|
||||
api_key: apiKey,
|
||||
[isProxy ? "api_url" : "api_key"]: apiKeyOrUrl,
|
||||
is_default_provider: false,
|
||||
is_configured: true,
|
||||
}),
|
||||
@ -128,7 +133,8 @@ export function ChangeCredentialsModal({
|
||||
if (!updateResponse.ok) {
|
||||
const errorData = await updateResponse.json();
|
||||
throw new Error(
|
||||
errorData.detail || "Failed to update provider- check your API key"
|
||||
errorData.detail ||
|
||||
`Failed to update provider- check your ${isProxy ? "API URL" : "API key"}`
|
||||
);
|
||||
}
|
||||
|
||||
@ -144,12 +150,12 @@ export function ChangeCredentialsModal({
|
||||
<Modal
|
||||
width="max-w-3xl"
|
||||
icon={provider.icon}
|
||||
title={`Modify your ${provider.provider_type} key`}
|
||||
title={`Modify your ${provider.provider_type} ${isProxy ? "URL" : "key"}`}
|
||||
onOutsideClick={onCancel}
|
||||
>
|
||||
<div className="mb-4">
|
||||
<Subtitle className="font-bold text-lg">
|
||||
Want to swap out your key?
|
||||
Want to swap out your {isProxy ? "URL" : "key"}?
|
||||
</Subtitle>
|
||||
<a
|
||||
href={provider.apiLink}
|
||||
@ -185,9 +191,9 @@ export function ChangeCredentialsModal({
|
||||
px-3
|
||||
bg-background-emphasis
|
||||
`}
|
||||
value={apiKey}
|
||||
onChange={(e: any) => setApiKey(e.target.value)}
|
||||
placeholder="Paste your API key here"
|
||||
value={apiKeyOrUrl}
|
||||
onChange={(e: any) => setApiKeyOrUrl(e.target.value)}
|
||||
placeholder={`Paste your ${isProxy ? "API URL" : "API key"} here`}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
@ -203,15 +209,15 @@ export function ChangeCredentialsModal({
|
||||
<Button
|
||||
color="blue"
|
||||
onClick={() => handleSubmit()}
|
||||
disabled={!apiKey}
|
||||
disabled={!apiKeyOrUrl}
|
||||
>
|
||||
Swap Key
|
||||
Swap {isProxy ? "URL" : "Key"}
|
||||
</Button>
|
||||
</div>
|
||||
<Divider />
|
||||
|
||||
<Subtitle className="mt-4 font-bold text-lg mb-2">
|
||||
You can also delete your key.
|
||||
You can also delete your {isProxy ? "URL" : "key"}.
|
||||
</Subtitle>
|
||||
<Text className="mb-2">
|
||||
This is only possible if you have already switched to a different
|
||||
@ -219,7 +225,7 @@ export function ChangeCredentialsModal({
|
||||
</Text>
|
||||
|
||||
<Button onClick={handleDelete} color="red">
|
||||
Delete key
|
||||
Delete {isProxy ? "URL" : "key"}
|
||||
</Button>
|
||||
{deletionError && (
|
||||
<Callout title="Error" color="red" className="mt-4">
|
||||
|
@ -13,11 +13,13 @@ export function ProviderCreationModal({
|
||||
onConfirm,
|
||||
onCancel,
|
||||
existingProvider,
|
||||
isProxy,
|
||||
}: {
|
||||
selectedProvider: CloudEmbeddingProvider;
|
||||
onConfirm: () => void;
|
||||
onCancel: () => void;
|
||||
existingProvider?: CloudEmbeddingProvider;
|
||||
isProxy?: boolean;
|
||||
}) {
|
||||
const useFileUpload = selectedProvider.provider_type == "Google";
|
||||
|
||||
@ -29,6 +31,7 @@ export function ProviderCreationModal({
|
||||
provider_type:
|
||||
existingProvider?.provider_type || selectedProvider.provider_type,
|
||||
api_key: existingProvider?.api_key || "",
|
||||
api_url: existingProvider?.api_url || "",
|
||||
custom_config: existingProvider?.custom_config
|
||||
? Object.entries(existingProvider.custom_config)
|
||||
: [],
|
||||
@ -37,9 +40,14 @@ export function ProviderCreationModal({
|
||||
|
||||
const validationSchema = Yup.object({
|
||||
provider_type: Yup.string().required("Provider type is required"),
|
||||
api_key: useFileUpload
|
||||
api_key: isProxy
|
||||
? Yup.string()
|
||||
: Yup.string().required("API Key is required"),
|
||||
: useFileUpload
|
||||
? Yup.string()
|
||||
: Yup.string().required("API Key is required"),
|
||||
api_url: isProxy
|
||||
? Yup.string().required("API URL is required")
|
||||
: Yup.string(),
|
||||
custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)),
|
||||
});
|
||||
|
||||
@ -87,6 +95,7 @@ export function ProviderCreationModal({
|
||||
body: JSON.stringify({
|
||||
provider_type: values.provider_type.toLowerCase().split(" ")[0],
|
||||
api_key: values.api_key,
|
||||
api_url: values.api_url,
|
||||
}),
|
||||
}
|
||||
);
|
||||
@ -169,12 +178,19 @@ export function ProviderCreationModal({
|
||||
target="_blank"
|
||||
href={selectedProvider.apiLink}
|
||||
>
|
||||
API KEY
|
||||
{isProxy ? "API URL" : "API KEY"}
|
||||
</a>
|
||||
</Text>
|
||||
|
||||
<div className="flex w-full flex-col gap-y-2">
|
||||
{useFileUpload ? (
|
||||
{isProxy ? (
|
||||
<TextFormField
|
||||
name="api_url"
|
||||
label="API URL"
|
||||
placeholder="API URL"
|
||||
type="text"
|
||||
/>
|
||||
) : useFileUpload ? (
|
||||
<>
|
||||
<Label>Upload JSON File</Label>
|
||||
<input
|
||||
|
@ -1,6 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { Text, Title } from "@tremor/react";
|
||||
import { Button, Card, Text, Title } from "@tremor/react";
|
||||
|
||||
import {
|
||||
CloudEmbeddingProvider,
|
||||
@ -8,15 +8,19 @@ import {
|
||||
AVAILABLE_CLOUD_PROVIDERS,
|
||||
CloudEmbeddingProviderFull,
|
||||
EmbeddingModelDescriptor,
|
||||
EmbeddingProvider,
|
||||
LITELLM_CLOUD_PROVIDER,
|
||||
} from "../../../../components/embedding/interfaces";
|
||||
import { EmbeddingDetails } from "../EmbeddingModelSelectionForm";
|
||||
import { FiExternalLink, FiInfo } from "react-icons/fi";
|
||||
import { HoverPopup } from "@/components/HoverPopup";
|
||||
import { Dispatch, SetStateAction } from "react";
|
||||
import { Dispatch, SetStateAction, useEffect, useState } from "react";
|
||||
import { LiteLLMModelForm } from "@/components/embedding/LiteLLMModelForm";
|
||||
|
||||
export default function CloudEmbeddingPage({
|
||||
currentModel,
|
||||
embeddingProviderDetails,
|
||||
embeddingModelDetails,
|
||||
newEnabledProviders,
|
||||
newUnenabledProviders,
|
||||
setShowTentativeProvider,
|
||||
@ -30,6 +34,7 @@ export default function CloudEmbeddingPage({
|
||||
currentModel: EmbeddingModelDescriptor | CloudEmbeddingModel;
|
||||
setAlreadySelectedModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
|
||||
newUnenabledProviders: string[];
|
||||
embeddingModelDetails?: CloudEmbeddingModel[];
|
||||
embeddingProviderDetails?: EmbeddingDetails[];
|
||||
newEnabledProviders: string[];
|
||||
setShowTentativeProvider: React.Dispatch<
|
||||
@ -61,6 +66,17 @@ export default function CloudEmbeddingPage({
|
||||
))!),
|
||||
})
|
||||
);
|
||||
const [liteLLMProvider, setLiteLLMProvider] = useState<
|
||||
EmbeddingDetails | undefined
|
||||
>(undefined);
|
||||
|
||||
useEffect(() => {
|
||||
const foundProvider = embeddingProviderDetails?.find(
|
||||
(provider) =>
|
||||
provider.provider_type === EmbeddingProvider.LITELLM.toLowerCase()
|
||||
);
|
||||
setLiteLLMProvider(foundProvider);
|
||||
}, [embeddingProviderDetails]);
|
||||
|
||||
return (
|
||||
<div>
|
||||
@ -122,6 +138,127 @@ export default function CloudEmbeddingPage({
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
|
||||
<Text className="mt-6">
|
||||
Alternatively, you can use a self-hosted model using the LiteLLM
|
||||
proxy. This allows you to leverage various LLM providers through a
|
||||
unified interface that you control.{" "}
|
||||
<a
|
||||
href="https://docs.litellm.ai/"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-blue-500 hover:underline"
|
||||
>
|
||||
Learn more about LiteLLM
|
||||
</a>
|
||||
</Text>
|
||||
|
||||
<div key={LITELLM_CLOUD_PROVIDER.provider_type} className="mt-4 w-full">
|
||||
<div className="flex items-center mb-2">
|
||||
{LITELLM_CLOUD_PROVIDER.icon({ size: 40 })}
|
||||
<h2 className="ml-2 mt-2 text-xl font-bold">
|
||||
{LITELLM_CLOUD_PROVIDER.provider_type}{" "}
|
||||
{LITELLM_CLOUD_PROVIDER.provider_type == "Cohere" &&
|
||||
"(recommended)"}
|
||||
</h2>
|
||||
<HoverPopup
|
||||
mainContent={
|
||||
<FiInfo className="ml-2 mt-2 cursor-pointer" size={18} />
|
||||
}
|
||||
popupContent={
|
||||
<div className="text-sm text-text-800 w-52">
|
||||
<div className="my-auto">
|
||||
{LITELLM_CLOUD_PROVIDER.description}
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
style="dark"
|
||||
/>
|
||||
</div>
|
||||
<div className="w-full flex flex-col items-start">
|
||||
{!liteLLMProvider ? (
|
||||
<button
|
||||
onClick={() => setShowTentativeProvider(LITELLM_CLOUD_PROVIDER)}
|
||||
className="mb-2 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 text-sm cursor-pointer"
|
||||
>
|
||||
Provide API URL
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
onClick={() =>
|
||||
setChangeCredentialsProvider(LITELLM_CLOUD_PROVIDER)
|
||||
}
|
||||
className="mb-2 hover:underline text-sm cursor-pointer"
|
||||
>
|
||||
Modify API URL
|
||||
</button>
|
||||
)}
|
||||
|
||||
{!liteLLMProvider && (
|
||||
<Card className="mt-2 w-full max-w-4xl bg-gray-50 border border-gray-200">
|
||||
<div className="p-4">
|
||||
<Text className="text-lg font-semibold mb-2">
|
||||
API URL Required
|
||||
</Text>
|
||||
<Text className="text-sm text-gray-600 mb-4">
|
||||
Before you can add models, you need to provide an API URL
|
||||
for your LiteLLM proxy. Click the "Provide API
|
||||
URL" button above to set up your LiteLLM configuration.
|
||||
</Text>
|
||||
<div className="flex items-center">
|
||||
<FiInfo className="text-blue-500 mr-2" size={18} />
|
||||
<Text className="text-sm text-blue-500">
|
||||
Once configured, you'll be able to add and manage
|
||||
your LiteLLM models here.
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
)}
|
||||
{liteLLMProvider && (
|
||||
<>
|
||||
<div className="flex mb-4 flex-wrap gap-4">
|
||||
{embeddingModelDetails
|
||||
?.filter(
|
||||
(model) =>
|
||||
model.provider_type ===
|
||||
EmbeddingProvider.LITELLM.toLowerCase()
|
||||
)
|
||||
.map((model) => (
|
||||
<CloudModelCard
|
||||
key={model.model_name}
|
||||
model={model}
|
||||
provider={LITELLM_CLOUD_PROVIDER}
|
||||
currentModel={currentModel}
|
||||
setAlreadySelectedModel={setAlreadySelectedModel}
|
||||
setShowTentativeModel={setShowTentativeModel}
|
||||
setShowModelInQueue={setShowModelInQueue}
|
||||
setShowTentativeProvider={setShowTentativeProvider}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
|
||||
<Card
|
||||
className={`mt-2 w-full max-w-4xl ${
|
||||
currentModel.provider_type === EmbeddingProvider.LITELLM
|
||||
? "border-2 border-blue-500"
|
||||
: ""
|
||||
}`}
|
||||
>
|
||||
<LiteLLMModelForm
|
||||
provider={liteLLMProvider}
|
||||
currentValues={
|
||||
currentModel.provider_type === EmbeddingProvider.LITELLM
|
||||
? (currentModel as CloudEmbeddingModel)
|
||||
: null
|
||||
}
|
||||
setShowTentativeModel={setShowTentativeModel}
|
||||
/>
|
||||
</Card>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
@ -146,7 +283,9 @@ export function CloudModelCard({
|
||||
React.SetStateAction<CloudEmbeddingProvider | null>
|
||||
>;
|
||||
}) {
|
||||
const enabled = model.model_name === currentModel.model_name;
|
||||
const enabled =
|
||||
model.model_name === currentModel.model_name &&
|
||||
model.provider_type == currentModel.provider_type;
|
||||
|
||||
return (
|
||||
<div
|
||||
@ -169,9 +308,12 @@ export function CloudModelCard({
|
||||
</a>
|
||||
</div>
|
||||
<p className="text-sm text-gray-600 mb-2">{model.description}</p>
|
||||
<div className="text-xs text-gray-500 mb-2">
|
||||
${model.pricePerMillion}/M tokens
|
||||
</div>
|
||||
{model?.provider_type?.toLowerCase() !=
|
||||
EmbeddingProvider.LITELLM.toLowerCase() && (
|
||||
<div className="text-xs text-gray-500 mb-2">
|
||||
${model.pricePerMillion}/M tokens
|
||||
</div>
|
||||
)}
|
||||
<div className="mt-3">
|
||||
<button
|
||||
className={`w-full p-2 rounded-lg text-sm ${
|
||||
@ -182,7 +324,10 @@ export function CloudModelCard({
|
||||
onClick={() => {
|
||||
if (enabled) {
|
||||
setAlreadySelectedModel(model);
|
||||
} else if (provider.configured) {
|
||||
} else if (
|
||||
provider.configured ||
|
||||
provider.provider_type === EmbeddingProvider.LITELLM
|
||||
) {
|
||||
setShowTentativeModel(model);
|
||||
} else {
|
||||
setShowModelInQueue(model);
|
||||
|
@ -41,6 +41,7 @@ export default function EmbeddingForm() {
|
||||
multipass_indexing: true,
|
||||
multilingual_expansion: [],
|
||||
disable_rerank_for_streaming: false,
|
||||
api_url: null,
|
||||
});
|
||||
|
||||
const [rerankingDetails, setRerankingDetails] = useState<RerankingDetails>({
|
||||
@ -116,6 +117,7 @@ export default function EmbeddingForm() {
|
||||
multilingual_expansion: searchSettings.multilingual_expansion,
|
||||
disable_rerank_for_streaming:
|
||||
searchSettings.disable_rerank_for_streaming,
|
||||
api_url: null,
|
||||
});
|
||||
setRerankingDetails({
|
||||
rerank_api_key: searchSettings.rerank_api_key,
|
||||
|
@ -41,6 +41,7 @@ export function CustomModelForm({
|
||||
api_key: null,
|
||||
provider_type: null,
|
||||
index_name: null,
|
||||
api_url: null,
|
||||
});
|
||||
}}
|
||||
>
|
||||
@ -106,20 +107,19 @@ export function CustomModelForm({
|
||||
/>
|
||||
|
||||
<BooleanFormField
|
||||
removeIndent
|
||||
name="normalize"
|
||||
label="Normalize Embeddings"
|
||||
subtext="Whether or not to normalize the embeddings generated by the model. When in doubt, leave this checked."
|
||||
/>
|
||||
|
||||
<div className="flex mt-6">
|
||||
<Button
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
className="w-64 mx-auto"
|
||||
>
|
||||
Choose
|
||||
</Button>
|
||||
</div>
|
||||
<Button
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
className="w-64 mx-auto"
|
||||
>
|
||||
Choose
|
||||
</Button>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
|
116
web/src/components/embedding/LiteLLMModelForm.tsx
Normal file
116
web/src/components/embedding/LiteLLMModelForm.tsx
Normal file
@ -0,0 +1,116 @@
|
||||
import { CloudEmbeddingModel, CloudEmbeddingProvider } from "./interfaces";
|
||||
import { Formik, Form } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import { TextFormField, BooleanFormField } from "../admin/connectors/Field";
|
||||
import { Dispatch, SetStateAction } from "react";
|
||||
import { Button, Text } from "@tremor/react";
|
||||
import { EmbeddingDetails } from "@/app/admin/embeddings/EmbeddingModelSelectionForm";
|
||||
|
||||
export function LiteLLMModelForm({
|
||||
setShowTentativeModel,
|
||||
currentValues,
|
||||
provider,
|
||||
}: {
|
||||
setShowTentativeModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
|
||||
currentValues: CloudEmbeddingModel | null;
|
||||
provider: EmbeddingDetails;
|
||||
}) {
|
||||
return (
|
||||
<div>
|
||||
<Formik
|
||||
initialValues={
|
||||
currentValues || {
|
||||
model_name: "",
|
||||
model_dim: 768,
|
||||
normalize: false,
|
||||
query_prefix: "",
|
||||
passage_prefix: "",
|
||||
provider_type: "LiteLLM",
|
||||
api_key: "",
|
||||
enabled: true,
|
||||
api_url: provider.api_url,
|
||||
description: "",
|
||||
index_name: "",
|
||||
pricePerMillion: 0,
|
||||
mtebScore: 0,
|
||||
maxContext: 4096,
|
||||
max_tokens: 1024,
|
||||
}
|
||||
}
|
||||
validationSchema={Yup.object().shape({
|
||||
model_name: Yup.string().required("Model name is required"),
|
||||
model_dim: Yup.number().required("Model dimension is required"),
|
||||
normalize: Yup.boolean().required(),
|
||||
query_prefix: Yup.string(),
|
||||
passage_prefix: Yup.string(),
|
||||
provider_type: Yup.string().required("Provider type is required"),
|
||||
api_key: Yup.string().optional(),
|
||||
enabled: Yup.boolean(),
|
||||
api_url: Yup.string().required("API base URL is required"),
|
||||
description: Yup.string(),
|
||||
index_name: Yup.string().nullable(),
|
||||
pricePerMillion: Yup.number(),
|
||||
mtebScore: Yup.number(),
|
||||
maxContext: Yup.number(),
|
||||
max_tokens: Yup.number(),
|
||||
})}
|
||||
onSubmit={async (values) => {
|
||||
setShowTentativeModel(values as CloudEmbeddingModel);
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting }) => (
|
||||
<Form>
|
||||
<Text className="text-xl text-text-900 font-bold mb-4">
|
||||
Add a new model to LiteLLM proxy at {provider.api_url}
|
||||
</Text>
|
||||
<TextFormField
|
||||
name="model_name"
|
||||
label="Model Name:"
|
||||
subtext="The name of the LiteLLM model"
|
||||
placeholder="e.g. 'all-MiniLM-L6-v2'"
|
||||
autoCompleteDisabled={true}
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
name="model_dim"
|
||||
label="Model Dimension:"
|
||||
subtext="The dimension of the model's embeddings"
|
||||
placeholder="e.g. '1536'"
|
||||
type="number"
|
||||
autoCompleteDisabled={true}
|
||||
/>
|
||||
|
||||
<BooleanFormField
|
||||
removeIndent
|
||||
name="normalize"
|
||||
label="Normalize"
|
||||
subtext="Whether to normalize the embeddings"
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
name="query_prefix"
|
||||
label="Query Prefix:"
|
||||
subtext="Prefix for query embeddings"
|
||||
autoCompleteDisabled={true}
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
name="passage_prefix"
|
||||
label="Passage Prefix:"
|
||||
subtext="Prefix for passage embeddings"
|
||||
autoCompleteDisabled={true}
|
||||
/>
|
||||
|
||||
<Button
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
className="w-64 mx-auto"
|
||||
>
|
||||
Configure LiteLLM Model
|
||||
</Button>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</div>
|
||||
);
|
||||
}
|
@ -2,6 +2,7 @@ import {
|
||||
CohereIcon,
|
||||
GoogleIcon,
|
||||
IconProps,
|
||||
LiteLLMIcon,
|
||||
MicrosoftIcon,
|
||||
NomicIcon,
|
||||
OpenAIIcon,
|
||||
@ -14,11 +15,13 @@ export enum EmbeddingProvider {
|
||||
COHERE = "Cohere",
|
||||
VOYAGE = "Voyage",
|
||||
GOOGLE = "Google",
|
||||
LITELLM = "LiteLLM",
|
||||
}
|
||||
|
||||
export interface CloudEmbeddingProvider {
|
||||
provider_type: EmbeddingProvider;
|
||||
api_key?: string;
|
||||
api_url?: string;
|
||||
custom_config?: Record<string, string>;
|
||||
docsLink?: string;
|
||||
|
||||
@ -44,6 +47,7 @@ export interface EmbeddingModelDescriptor {
|
||||
provider_type: string | null;
|
||||
description: string;
|
||||
api_key: string | null;
|
||||
api_url: string | null;
|
||||
index_name: string | null;
|
||||
}
|
||||
|
||||
@ -70,7 +74,7 @@ export interface FullEmbeddingModelResponse {
|
||||
}
|
||||
|
||||
export interface CloudEmbeddingProviderFull extends CloudEmbeddingProvider {
|
||||
configured: boolean;
|
||||
configured?: boolean;
|
||||
}
|
||||
|
||||
export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
||||
@ -87,6 +91,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
||||
index_name: "",
|
||||
provider_type: null,
|
||||
api_key: null,
|
||||
api_url: null,
|
||||
},
|
||||
{
|
||||
model_name: "intfloat/e5-base-v2",
|
||||
@ -99,6 +104,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
||||
passage_prefix: "passage: ",
|
||||
index_name: "",
|
||||
provider_type: null,
|
||||
api_url: null,
|
||||
api_key: null,
|
||||
},
|
||||
{
|
||||
@ -113,6 +119,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
||||
index_name: "",
|
||||
provider_type: null,
|
||||
api_key: null,
|
||||
api_url: null,
|
||||
},
|
||||
{
|
||||
model_name: "intfloat/multilingual-e5-base",
|
||||
@ -126,6 +133,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
||||
index_name: "",
|
||||
provider_type: null,
|
||||
api_key: null,
|
||||
api_url: null,
|
||||
},
|
||||
{
|
||||
model_name: "intfloat/multilingual-e5-small",
|
||||
@ -139,9 +147,19 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
||||
index_name: "",
|
||||
provider_type: null,
|
||||
api_key: null,
|
||||
api_url: null,
|
||||
},
|
||||
];
|
||||
|
||||
export const LITELLM_CLOUD_PROVIDER: CloudEmbeddingProvider = {
|
||||
provider_type: EmbeddingProvider.LITELLM,
|
||||
website: "https://github.com/BerriAI/litellm",
|
||||
icon: LiteLLMIcon,
|
||||
description: "Open-source library to call LLM APIs using OpenAI format",
|
||||
apiLink: "https://docs.litellm.ai/docs/proxy/quick_start",
|
||||
embedding_models: [], // No default embedding models
|
||||
};
|
||||
|
||||
export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
{
|
||||
provider_type: EmbeddingProvider.COHERE,
|
||||
@ -169,6 +187,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
passage_prefix: "",
|
||||
index_name: "",
|
||||
api_key: null,
|
||||
api_url: null,
|
||||
},
|
||||
{
|
||||
model_name: "embed-english-light-v3.0",
|
||||
@ -185,6 +204,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
passage_prefix: "",
|
||||
index_name: "",
|
||||
api_key: null,
|
||||
api_url: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
@ -213,6 +233,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
enabled: false,
|
||||
index_name: "",
|
||||
api_key: null,
|
||||
api_url: null,
|
||||
},
|
||||
{
|
||||
provider_type: EmbeddingProvider.OPENAI,
|
||||
@ -229,6 +250,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
maxContext: 8191,
|
||||
index_name: "",
|
||||
api_key: null,
|
||||
api_url: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
@ -258,6 +280,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
passage_prefix: "",
|
||||
index_name: "",
|
||||
api_key: null,
|
||||
api_url: null,
|
||||
},
|
||||
{
|
||||
provider_type: EmbeddingProvider.GOOGLE,
|
||||
@ -273,6 +296,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
passage_prefix: "",
|
||||
index_name: "",
|
||||
api_key: null,
|
||||
api_url: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
@ -301,6 +325,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
passage_prefix: "",
|
||||
index_name: "",
|
||||
api_key: null,
|
||||
api_url: null,
|
||||
},
|
||||
{
|
||||
provider_type: EmbeddingProvider.VOYAGE,
|
||||
@ -317,6 +342,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
passage_prefix: "",
|
||||
index_name: "",
|
||||
api_key: null,
|
||||
api_url: null,
|
||||
},
|
||||
],
|
||||
},
|
||||
|
@ -48,6 +48,7 @@ import jiraSVG from "../../../public/Jira.svg";
|
||||
import confluenceSVG from "../../../public/Confluence.svg";
|
||||
import openAISVG from "../../../public/Openai.svg";
|
||||
import openSourceIcon from "../../../public/OpenSource.png";
|
||||
import litellmIcon from "../../../public/LiteLLM.jpg";
|
||||
|
||||
import awsWEBP from "../../../public/Amazon.webp";
|
||||
import azureIcon from "../../../public/Azure.png";
|
||||
@ -267,6 +268,20 @@ export const ColorSlackIcon = ({
|
||||
);
|
||||
};
|
||||
|
||||
export const LiteLLMIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<div
|
||||
style={{ width: `${size + 4}px`, height: `${size + 4}px` }}
|
||||
className={`w-[${size + 4}px] h-[${size + 4}px] -m-0.5 ` + className}
|
||||
>
|
||||
<Image src={litellmIcon} alt="Logo" width="96" height="96" />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export const OpenSourceIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
|
Loading…
x
Reference in New Issue
Block a user