Add litellm proxy embeddings (#2291)

* add litellm proxy

* formatting

* move `api_url` to cloud provider + nits

* remove log

* typing

* quick typing fix

* update LiteLLM selection logic

* remove logs + validate functionality

* rename proxy var

* update path casing

* remove pricing for custom models

* functional values
This commit is contained in:
pablodanswer 2024-09-02 09:08:35 -07:00 committed by GitHub
parent 910821c723
commit 299cb5035c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 524 additions and 50 deletions

View File

@ -0,0 +1,26 @@
"""Add base_url to CloudEmbeddingProvider
Revision ID: bceb1e139447
Revises: 1f60f60c3401
Create Date: 2024-08-28 17:00:52.554580
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "bceb1e139447"
down_revision = "1f60f60c3401"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the nullable ``api_url`` column to the ``embedding_provider`` table."""
    op.add_column(
        "embedding_provider", sa.Column("api_url", sa.String(), nullable=True)
    )
def downgrade() -> None:
    """Drop the ``api_url`` column added by ``upgrade``."""
    op.drop_column("embedding_provider", "api_url")

View File

@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import SearchSettings
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
@ -50,6 +51,7 @@ def upsert_cloud_embedding_provider(
setattr(existing_provider, key, value)
else:
new_provider = CloudEmbeddingProviderModel(**provider.model_dump())
db_session.add(new_provider)
existing_provider = new_provider
db_session.commit()
@ -157,12 +159,19 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
def remove_embedding_provider(
    db_session: Session, provider_type: EmbeddingProvider
) -> None:
    """Delete a cloud embedding provider and its dependent SearchSettings rows.

    SearchSettings referencing the provider are deleted first — presumably so
    no settings row is left pointing at a removed provider; confirm against
    the FK/relationship definitions in the models module.

    Args:
        db_session: Active SQLAlchemy session; committed before returning.
        provider_type: Enum value identifying the provider to remove.
    """
    # Remove dependent search settings before the provider itself.
    db_session.execute(
        delete(SearchSettings).where(SearchSettings.provider_type == provider_type)
    )

    # Delete the embedding provider
    db_session.execute(
        delete(CloudEmbeddingProviderModel).where(
            CloudEmbeddingProviderModel.provider_type == provider_type
        )
    )

    db_session.commit()
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
# Remove LLMProvider's dependent relationships

View File

@ -607,6 +607,10 @@ class SearchSettings(Base):
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
cloud_provider='{self.cloud_provider.provider_type if self.cloud_provider else 'None'}')>"
@property
def api_url(self) -> str | None:
    # Convenience accessor: the API URL of the linked cloud provider,
    # or None when no cloud provider is associated with these settings.
    return self.cloud_provider.api_url if self.cloud_provider is not None else None
@property
def api_key(self) -> str | None:
return self.cloud_provider.api_key if self.cloud_provider is not None else None
@ -1085,6 +1089,7 @@ class CloudEmbeddingProvider(Base):
provider_type: Mapped[EmbeddingProvider] = mapped_column(
Enum(EmbeddingProvider), primary_key=True
)
api_url: Mapped[str | None] = mapped_column(String, nullable=True)
api_key: Mapped[str | None] = mapped_column(EncryptedString())
search_settings: Mapped[list["SearchSettings"]] = relationship(
"SearchSettings",

View File

@ -115,6 +115,13 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
return latest_settings
def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
    """Return every SearchSettings row, ordered by descending id (newest first)."""
    stmt = select(SearchSettings).order_by(SearchSettings.id.desc())
    rows = db_session.execute(stmt).scalars()
    return list(rows.all())
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
if db_session is None:
with Session(get_sqlalchemy_engine()) as db_session:
@ -234,6 +241,7 @@ def get_old_default_embedding_model() -> IndexingSetting:
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
index_name="danswer_chunk",
multipass_indexing=False,
api_url=None,
)
@ -246,4 +254,5 @@ def get_new_default_embedding_model() -> IndexingSetting:
passage_prefix=ASYM_PASSAGE_PREFIX,
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
multipass_indexing=False,
api_url=None,
)

View File

@ -32,6 +32,7 @@ class IndexingEmbedder(ABC):
passage_prefix: str | None,
provider_type: EmbeddingProvider | None,
api_key: str | None,
api_url: str | None,
):
self.model_name = model_name
self.normalize = normalize
@ -39,6 +40,7 @@ class IndexingEmbedder(ABC):
self.passage_prefix = passage_prefix
self.provider_type = provider_type
self.api_key = api_key
self.api_url = api_url
self.embedding_model = EmbeddingModel(
model_name=model_name,
@ -47,6 +49,7 @@ class IndexingEmbedder(ABC):
normalize=normalize,
api_key=api_key,
provider_type=provider_type,
api_url=api_url,
# The below are globally set, this flow always uses the indexing one
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
@ -70,9 +73,16 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
passage_prefix: str | None,
provider_type: EmbeddingProvider | None = None,
api_key: str | None = None,
api_url: str | None = None,
):
super().__init__(
model_name, normalize, query_prefix, passage_prefix, provider_type, api_key
model_name,
normalize,
query_prefix,
passage_prefix,
provider_type,
api_key,
api_url,
)
@log_function_time()
@ -156,7 +166,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
title_embed_dict[title] = title_embedding
new_embedded_chunk = IndexChunk(
**chunk.model_dump(),
**chunk.dict(),
embeddings=ChunkEmbedding(
full_embedding=chunk_embeddings[0],
mini_chunk_embeddings=chunk_embeddings[1:],
@ -179,6 +189,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
passage_prefix=search_settings.passage_prefix,
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
)
@ -202,4 +213,5 @@ def get_embedding_model_from_search_settings(
passage_prefix=search_settings.passage_prefix,
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
)

View File

@ -99,6 +99,7 @@ class EmbeddingModelDetail(BaseModel):
normalize: bool
query_prefix: str | None
passage_prefix: str | None
api_url: str | None = None
provider_type: EmbeddingProvider | None = None
api_key: str | None = None
@ -117,6 +118,7 @@ class EmbeddingModelDetail(BaseModel):
passage_prefix=search_settings.passage_prefix,
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
)

View File

@ -90,6 +90,7 @@ class EmbeddingModel:
query_prefix: str | None,
passage_prefix: str | None,
api_key: str | None,
api_url: str | None,
provider_type: EmbeddingProvider | None,
retrim_content: bool = False,
) -> None:
@ -100,6 +101,7 @@ class EmbeddingModel:
self.normalize = normalize
self.model_name = model_name
self.retrim_content = retrim_content
self.api_url = api_url
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
@ -157,6 +159,7 @@ class EmbeddingModel:
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
api_url=self.api_url,
)
response = self._make_model_server_request(embed_request)
@ -226,6 +229,7 @@ class EmbeddingModel:
passage_prefix=search_settings.passage_prefix,
api_key=search_settings.api_key,
provider_type=search_settings.provider_type,
api_url=search_settings.api_url,
retrim_content=retrim_content,
)

View File

@ -81,6 +81,7 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
num_rerank=search_settings.num_rerank,
# Multilingual Expansion
multilingual_expansion=search_settings.multilingual_expansion,
api_url=search_settings.api_url,
)

View File

@ -9,7 +9,9 @@ from danswer.db.llm import fetch_existing_embedding_providers
from danswer.db.llm import remove_embedding_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.models import User
from danswer.db.search_settings import get_all_search_settings
from danswer.db.search_settings import get_current_db_embedding_provider
from danswer.indexing.models import EmbeddingModelDetail
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
@ -20,6 +22,7 @@ from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
logger = setup_logger()
@ -37,6 +40,7 @@ def test_embedding_configuration(
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
api_key=test_llm_request.api_key,
api_url=test_llm_request.api_url,
provider_type=test_llm_request.provider_type,
normalize=False,
query_prefix=None,
@ -56,6 +60,15 @@ def test_embedding_configuration(
raise HTTPException(status_code=400, detail=error_msg)
@admin_router.get("", response_model=list[EmbeddingModelDetail])
def list_embedding_models(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[EmbeddingModelDetail]:
    """Admin endpoint: list every embedding model recorded in search settings.

    Returns one EmbeddingModelDetail per SearchSettings row (current and
    past/secondary settings alike, per get_all_search_settings).
    """
    search_settings = get_all_search_settings(db_session)
    return [EmbeddingModelDetail.from_db_model(setting) for setting in search_settings]
@admin_router.get("/embedding-provider")
def list_embedding_providers(
_: User | None = Depends(current_admin_user),

View File

@ -11,11 +11,13 @@ if TYPE_CHECKING:
class TestEmbeddingRequest(BaseModel):
    """Request body for validating an embedding provider configuration.

    Depending on the provider type, either an API key (hosted providers) or
    an API URL (proxy-style providers such as LiteLLM) may be supplied.
    """

    provider_type: EmbeddingProvider
    # Optional API key for hosted cloud providers.
    api_key: str | None = None
    # Optional base URL for proxy-style providers (e.g. LiteLLM).
    api_url: str | None = None
class CloudEmbeddingProvider(BaseModel):
provider_type: EmbeddingProvider
api_key: str | None = None
api_url: str | None = None
@classmethod
def from_request(
@ -24,9 +26,11 @@ class CloudEmbeddingProvider(BaseModel):
return cls(
provider_type=cloud_provider_model.provider_type,
api_key=cloud_provider_model.api_key,
api_url=cloud_provider_model.api_url,
)
class CloudEmbeddingProviderCreationRequest(BaseModel):
    """Request body for creating/updating a cloud embedding provider.

    api_key and api_url are both optional; which one is required depends on
    the provider type (URL for LiteLLM proxy, key for hosted providers).
    """

    provider_type: EmbeddingProvider
    # Optional API key for hosted cloud providers.
    api_key: str | None = None
    # Optional base URL for proxy-style providers (e.g. LiteLLM).
    api_url: str | None = None

View File

@ -45,7 +45,7 @@ def set_new_search_settings(
if search_settings_new.index_name:
logger.warning("Index name was specified by request, this is not suggested")
# Validate cloud provider exists
# Validate cloud provider exists or create new LiteLLM provider
if search_settings_new.provider_type is not None:
cloud_provider = get_embedding_provider_from_provider_type(
db_session, provider_type=search_settings_new.provider_type
@ -133,7 +133,7 @@ def cancel_new_embedding(
@router.get("/get-current-search-settings")
def get_curr_search_settings(
def get_current_search_settings_endpoint(
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SavedSearchSettings:
@ -142,7 +142,7 @@ def get_curr_search_settings(
@router.get("/get-secondary-search-settings")
def get_sec_search_settings(
def get_secondary_search_settings_endpoint(
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SavedSearchSettings | None:

View File

@ -2,6 +2,7 @@ import json
from typing import Any
from typing import Optional
import httpx
import openai
import vertexai # type: ignore
import voyageai # type: ignore
@ -235,6 +236,22 @@ def get_local_reranking_model(
return _RERANK_MODEL
def embed_with_litellm_proxy(
    texts: list[str],
    api_url: str,
    model: str,
    api_key: str | None = None,
) -> list[Embedding]:
    """Embed ``texts`` through a LiteLLM proxy's OpenAI-compatible endpoint.

    Fixes over the original: (1) an optional bearer token can now be sent,
    so proxies configured with a master/API key no longer reject every
    request; (2) response items are sorted by their ``index`` field, so the
    returned embeddings are guaranteed to line up with the input order
    instead of relying on the server preserving it.

    Args:
        texts: Non-empty list of strings to embed.
        api_url: Full URL of the proxy's embeddings endpoint.
        model: Model name the proxy should route the request to.
        api_key: Optional bearer token, sent as an ``Authorization`` header.
            Defaults to None (no header) for backward compatibility.

    Returns:
        One embedding per input text, in input order.

    Raises:
        httpx.HTTPStatusError: If the proxy responds with a non-2xx status.
    """
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    with httpx.Client() as client:
        response = client.post(
            api_url,
            json={
                "model": model,
                "input": texts,
            },
            headers=headers,
        )
        response.raise_for_status()
        result = response.json()

    # OpenAI-compatible responses tag each item with an "index"; sort on it
    # rather than trusting the server's ordering of the "data" array.
    items = sorted(result["data"], key=lambda item: item.get("index", 0))
    return [item["embedding"] for item in items]
@simple_log_function_time()
def embed_text(
texts: list[str],
@ -245,21 +262,37 @@ def embed_text(
api_key: str | None,
provider_type: EmbeddingProvider | None,
prefix: str | None,
api_url: str | None,
) -> list[Embedding]:
logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")
if not all(texts):
logger.error("Empty strings provided for embedding")
raise ValueError("Empty strings are not allowed for embedding.")
# Third party API based embedding model
if not texts:
logger.error("No texts provided for embedding")
raise ValueError("No texts provided for embedding.")
if provider_type == EmbeddingProvider.LITELLM:
logger.debug(f"Using LiteLLM proxy for embedding with URL: {api_url}")
if not api_url:
logger.error("API URL not provided for LiteLLM proxy")
raise ValueError("API URL is required for LiteLLM proxy embedding.")
try:
return embed_with_litellm_proxy(texts, api_url, model_name or "")
except Exception as e:
logger.exception(f"Error during LiteLLM proxy embedding: {str(e)}")
raise
elif provider_type is not None:
logger.debug(f"Embedding text with provider: {provider_type}")
logger.debug(f"Using cloud provider {provider_type} for embedding")
if api_key is None:
logger.error("API key not provided for cloud model")
raise RuntimeError("API key not provided for cloud model")
if prefix:
# This may change in the future if some providers require the user
# to manually append a prefix but this is not the case currently
logger.warning("Prefix provided for cloud model, which is not supported")
raise ValueError(
"Prefix string is not valid for cloud models. "
"Cloud models take an explicit text type instead."
@ -274,14 +307,15 @@ def embed_text(
text_type=text_type,
)
# Check for None values in embeddings
if any(embedding is None for embedding in embeddings):
error_message = "Embeddings contain None values\n"
error_message += "Corresponding texts:\n"
error_message += "\n".join(texts)
logger.error(error_message)
raise ValueError(error_message)
elif model_name is not None:
logger.debug(f"Using local model {model_name} for embedding")
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
local_model = get_embedding_model(
@ -296,10 +330,12 @@ def embed_text(
]
else:
logger.error("Neither model name nor provider specified for embedding")
raise ValueError(
"Either model name or provider must be provided to run embeddings."
)
logger.info(f"Successfully embedded {len(texts)} texts")
return embeddings
@ -344,6 +380,7 @@ async def process_embed_request(
api_key=embed_request.api_key,
provider_type=embed_request.provider_type,
text_type=embed_request.text_type,
api_url=embed_request.api_url,
prefix=prefix,
)
return EmbedResponse(embeddings=embeddings)

View File

@ -61,6 +61,7 @@ PRESERVED_SEARCH_FIELDS = [
"provider_type",
"api_key",
"model_name",
"api_url",
"index_name",
"multipass_indexing",
"model_dim",

View File

@ -6,6 +6,7 @@ class EmbeddingProvider(str, Enum):
COHERE = "cohere"
VOYAGE = "voyage"
GOOGLE = "google"
LITELLM = "litellm"
class RerankerProvider(str, Enum):

View File

@ -18,6 +18,7 @@ class EmbedRequest(BaseModel):
text_type: EmbedTextType
manual_query_prefix: str | None = None
manual_passage_prefix: str | None = None
api_url: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}

View File

@ -32,6 +32,7 @@ def openai_embedding_model() -> EmbeddingModel:
passage_prefix=None,
api_key=os.getenv("OPENAI_API_KEY"),
provider_type=EmbeddingProvider.OPENAI,
api_url=None,
)
@ -51,6 +52,7 @@ def cohere_embedding_model() -> EmbeddingModel:
passage_prefix=None,
api_key=os.getenv("COHERE_API_KEY"),
provider_type=EmbeddingProvider.COHERE,
api_url=None,
)
@ -70,6 +72,7 @@ def local_nomic_embedding_model() -> EmbeddingModel:
passage_prefix="search_document: ",
api_key=None,
provider_type=None,
api_url=None,
)

BIN
web/public/LiteLLM.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

View File

@ -2,3 +2,5 @@ export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
export const EMBEDDING_PROVIDERS_ADMIN_URL =
"/api/admin/embedding/embedding-provider";
export const EMBEDDING_MODELS_ADMIN_URL = "/api/admin/embedding";

View File

@ -24,10 +24,14 @@ import { ChangeCredentialsModal } from "./modals/ChangeCredentialsModal";
import { ModelSelectionConfirmationModal } from "./modals/ModelSelectionModal";
import { AlreadyPickedModal } from "./modals/AlreadyPickedModal";
import { ModelOption } from "../../../components/embedding/ModelSelector";
import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../configuration/llm/constants";
import {
EMBEDDING_MODELS_ADMIN_URL,
EMBEDDING_PROVIDERS_ADMIN_URL,
} from "../configuration/llm/constants";
export interface EmbeddingDetails {
api_key: string;
api_key?: string;
api_url?: string;
custom_config: any;
provider_type: EmbeddingProvider;
}
@ -77,12 +81,20 @@ export function EmbeddingModelSelection({
const [showDeleteCredentialsModal, setShowDeleteCredentialsModal] =
useState<boolean>(false);
const [showAddConnectorPopup, setShowAddConnectorPopup] =
useState<boolean>(false);
const { data: embeddingModelDetails } = useSWR<CloudEmbeddingModel[]>(
EMBEDDING_MODELS_ADMIN_URL,
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
const { data: embeddingProviderDetails } = useSWR<EmbeddingDetails[]>(
EMBEDDING_PROVIDERS_ADMIN_URL,
errorHandlingFetcher
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
const { data: connectors } = useSWR<Connector<any>[]>(
@ -175,6 +187,7 @@ export function EmbeddingModelSelection({
{showTentativeProvider && (
<ProviderCreationModal
isProxy={showTentativeProvider.provider_type == "LiteLLM"}
selectedProvider={showTentativeProvider}
onConfirm={() => {
setShowTentativeProvider(showUnconfiguredProvider);
@ -189,8 +202,10 @@ export function EmbeddingModelSelection({
}}
/>
)}
{changeCredentialsProvider && (
<ChangeCredentialsModal
isProxy={changeCredentialsProvider.provider_type == "LiteLLM"}
useFileUpload={changeCredentialsProvider.provider_type == "Google"}
onDeleted={() => {
clientsideRemoveProvider(changeCredentialsProvider);
@ -277,6 +292,7 @@ export function EmbeddingModelSelection({
{modelTab == "cloud" && (
<CloudEmbeddingPage
embeddingModelDetails={embeddingModelDetails}
setShowModelInQueue={setShowModelInQueue}
setShowTentativeModel={setShowTentativeModel}
currentModel={selectedProvider}

View File

@ -21,6 +21,7 @@ export interface AdvancedSearchConfiguration {
multipass_indexing: boolean;
multilingual_expansion: string[];
disable_rerank_for_streaming: boolean;
api_url: string | null;
}
export interface SavedSearchSettings extends RerankingDetails {
@ -33,6 +34,7 @@ export interface SavedSearchSettings extends RerankingDetails {
multipass_indexing: boolean;
multilingual_expansion: string[];
disable_rerank_for_streaming: boolean;
api_url: string | null;
provider_type: EmbeddingProvider | null;
}

View File

@ -15,14 +15,16 @@ export function ChangeCredentialsModal({
onCancel,
onDeleted,
useFileUpload,
isProxy = false,
}: {
provider: CloudEmbeddingProvider;
onConfirm: () => void;
onCancel: () => void;
onDeleted: () => void;
useFileUpload: boolean;
isProxy?: boolean;
}) {
const [apiKey, setApiKey] = useState("");
const [apiKeyOrUrl, setApiKeyOrUrl] = useState("");
const [testError, setTestError] = useState<string>("");
const [fileName, setFileName] = useState<string>("");
const fileInputRef = useRef<HTMLInputElement>(null);
@ -50,7 +52,7 @@ export function ChangeCredentialsModal({
let jsonContent;
try {
jsonContent = JSON.parse(fileContent);
setApiKey(JSON.stringify(jsonContent));
setApiKeyOrUrl(JSON.stringify(jsonContent));
} catch (parseError) {
throw new Error(
"Failed to parse JSON file. Please ensure it's a valid JSON."
@ -62,7 +64,7 @@ export function ChangeCredentialsModal({
? error.message
: "An unknown error occurred while processing the file."
);
setApiKey("");
setApiKeyOrUrl("");
clearFileInput();
}
}
@ -74,7 +76,7 @@ export function ChangeCredentialsModal({
try {
const response = await fetch(
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type}`,
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type.toLowerCase()}`,
{
method: "DELETE",
}
@ -105,7 +107,10 @@ export function ChangeCredentialsModal({
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider_type: provider.provider_type.toLowerCase().split(" ")[0],
api_key: apiKey,
[isProxy ? "api_url" : "api_key"]: apiKeyOrUrl,
[isProxy ? "api_key" : "api_url"]: isProxy
? provider.api_key
: provider.api_url,
}),
});
@ -119,7 +124,7 @@ export function ChangeCredentialsModal({
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider_type: provider.provider_type.toLowerCase().split(" ")[0],
api_key: apiKey,
[isProxy ? "api_url" : "api_key"]: apiKeyOrUrl,
is_default_provider: false,
is_configured: true,
}),
@ -128,7 +133,8 @@ export function ChangeCredentialsModal({
if (!updateResponse.ok) {
const errorData = await updateResponse.json();
throw new Error(
errorData.detail || "Failed to update provider- check your API key"
errorData.detail ||
`Failed to update provider- check your ${isProxy ? "API URL" : "API key"}`
);
}
@ -144,12 +150,12 @@ export function ChangeCredentialsModal({
<Modal
width="max-w-3xl"
icon={provider.icon}
title={`Modify your ${provider.provider_type} key`}
title={`Modify your ${provider.provider_type} ${isProxy ? "URL" : "key"}`}
onOutsideClick={onCancel}
>
<div className="mb-4">
<Subtitle className="font-bold text-lg">
Want to swap out your key?
Want to swap out your {isProxy ? "URL" : "key"}?
</Subtitle>
<a
href={provider.apiLink}
@ -185,9 +191,9 @@ export function ChangeCredentialsModal({
px-3
bg-background-emphasis
`}
value={apiKey}
onChange={(e: any) => setApiKey(e.target.value)}
placeholder="Paste your API key here"
value={apiKeyOrUrl}
onChange={(e: any) => setApiKeyOrUrl(e.target.value)}
placeholder={`Paste your ${isProxy ? "API URL" : "API key"} here`}
/>
</>
)}
@ -203,15 +209,15 @@ export function ChangeCredentialsModal({
<Button
color="blue"
onClick={() => handleSubmit()}
disabled={!apiKey}
disabled={!apiKeyOrUrl}
>
Swap Key
Swap {isProxy ? "URL" : "Key"}
</Button>
</div>
<Divider />
<Subtitle className="mt-4 font-bold text-lg mb-2">
You can also delete your key.
You can also delete your {isProxy ? "URL" : "key"}.
</Subtitle>
<Text className="mb-2">
This is only possible if you have already switched to a different
@ -219,7 +225,7 @@ export function ChangeCredentialsModal({
</Text>
<Button onClick={handleDelete} color="red">
Delete key
Delete {isProxy ? "URL" : "key"}
</Button>
{deletionError && (
<Callout title="Error" color="red" className="mt-4">

View File

@ -13,11 +13,13 @@ export function ProviderCreationModal({
onConfirm,
onCancel,
existingProvider,
isProxy,
}: {
selectedProvider: CloudEmbeddingProvider;
onConfirm: () => void;
onCancel: () => void;
existingProvider?: CloudEmbeddingProvider;
isProxy?: boolean;
}) {
const useFileUpload = selectedProvider.provider_type == "Google";
@ -29,6 +31,7 @@ export function ProviderCreationModal({
provider_type:
existingProvider?.provider_type || selectedProvider.provider_type,
api_key: existingProvider?.api_key || "",
api_url: existingProvider?.api_url || "",
custom_config: existingProvider?.custom_config
? Object.entries(existingProvider.custom_config)
: [],
@ -37,9 +40,14 @@ export function ProviderCreationModal({
const validationSchema = Yup.object({
provider_type: Yup.string().required("Provider type is required"),
api_key: useFileUpload
api_key: isProxy
? Yup.string()
: useFileUpload
? Yup.string()
: Yup.string().required("API Key is required"),
api_url: isProxy
? Yup.string().required("API URL is required")
: Yup.string(),
custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)),
});
@ -87,6 +95,7 @@ export function ProviderCreationModal({
body: JSON.stringify({
provider_type: values.provider_type.toLowerCase().split(" ")[0],
api_key: values.api_key,
api_url: values.api_url,
}),
}
);
@ -169,12 +178,19 @@ export function ProviderCreationModal({
target="_blank"
href={selectedProvider.apiLink}
>
API KEY
{isProxy ? "API URL" : "API KEY"}
</a>
</Text>
<div className="flex w-full flex-col gap-y-2">
{useFileUpload ? (
{isProxy ? (
<TextFormField
name="api_url"
label="API URL"
placeholder="API URL"
type="text"
/>
) : useFileUpload ? (
<>
<Label>Upload JSON File</Label>
<input

View File

@ -1,6 +1,6 @@
"use client";
import { Text, Title } from "@tremor/react";
import { Button, Card, Text, Title } from "@tremor/react";
import {
CloudEmbeddingProvider,
@ -8,15 +8,19 @@ import {
AVAILABLE_CLOUD_PROVIDERS,
CloudEmbeddingProviderFull,
EmbeddingModelDescriptor,
EmbeddingProvider,
LITELLM_CLOUD_PROVIDER,
} from "../../../../components/embedding/interfaces";
import { EmbeddingDetails } from "../EmbeddingModelSelectionForm";
import { FiExternalLink, FiInfo } from "react-icons/fi";
import { HoverPopup } from "@/components/HoverPopup";
import { Dispatch, SetStateAction } from "react";
import { Dispatch, SetStateAction, useEffect, useState } from "react";
import { LiteLLMModelForm } from "@/components/embedding/LiteLLMModelForm";
export default function CloudEmbeddingPage({
currentModel,
embeddingProviderDetails,
embeddingModelDetails,
newEnabledProviders,
newUnenabledProviders,
setShowTentativeProvider,
@ -30,6 +34,7 @@ export default function CloudEmbeddingPage({
currentModel: EmbeddingModelDescriptor | CloudEmbeddingModel;
setAlreadySelectedModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
newUnenabledProviders: string[];
embeddingModelDetails?: CloudEmbeddingModel[];
embeddingProviderDetails?: EmbeddingDetails[];
newEnabledProviders: string[];
setShowTentativeProvider: React.Dispatch<
@ -61,6 +66,17 @@ export default function CloudEmbeddingPage({
))!),
})
);
const [liteLLMProvider, setLiteLLMProvider] = useState<
EmbeddingDetails | undefined
>(undefined);
useEffect(() => {
const foundProvider = embeddingProviderDetails?.find(
(provider) =>
provider.provider_type === EmbeddingProvider.LITELLM.toLowerCase()
);
setLiteLLMProvider(foundProvider);
}, [embeddingProviderDetails]);
return (
<div>
@ -122,6 +138,127 @@ export default function CloudEmbeddingPage({
</div>
</div>
))}
<Text className="mt-6">
Alternatively, you can use a self-hosted model using the LiteLLM
proxy. This allows you to leverage various LLM providers through a
unified interface that you control.{" "}
<a
href="https://docs.litellm.ai/"
target="_blank"
rel="noopener noreferrer"
className="text-blue-500 hover:underline"
>
Learn more about LiteLLM
</a>
</Text>
<div key={LITELLM_CLOUD_PROVIDER.provider_type} className="mt-4 w-full">
<div className="flex items-center mb-2">
{LITELLM_CLOUD_PROVIDER.icon({ size: 40 })}
<h2 className="ml-2 mt-2 text-xl font-bold">
{LITELLM_CLOUD_PROVIDER.provider_type}{" "}
{LITELLM_CLOUD_PROVIDER.provider_type == "Cohere" &&
"(recommended)"}
</h2>
<HoverPopup
mainContent={
<FiInfo className="ml-2 mt-2 cursor-pointer" size={18} />
}
popupContent={
<div className="text-sm text-text-800 w-52">
<div className="my-auto">
{LITELLM_CLOUD_PROVIDER.description}
</div>
</div>
}
style="dark"
/>
</div>
<div className="w-full flex flex-col items-start">
{!liteLLMProvider ? (
<button
onClick={() => setShowTentativeProvider(LITELLM_CLOUD_PROVIDER)}
className="mb-2 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 text-sm cursor-pointer"
>
Provide API URL
</button>
) : (
<button
onClick={() =>
setChangeCredentialsProvider(LITELLM_CLOUD_PROVIDER)
}
className="mb-2 hover:underline text-sm cursor-pointer"
>
Modify API URL
</button>
)}
{!liteLLMProvider && (
<Card className="mt-2 w-full max-w-4xl bg-gray-50 border border-gray-200">
<div className="p-4">
<Text className="text-lg font-semibold mb-2">
API URL Required
</Text>
<Text className="text-sm text-gray-600 mb-4">
Before you can add models, you need to provide an API URL
for your LiteLLM proxy. Click the &quot;Provide API
URL&quot; button above to set up your LiteLLM configuration.
</Text>
<div className="flex items-center">
<FiInfo className="text-blue-500 mr-2" size={18} />
<Text className="text-sm text-blue-500">
Once configured, you&apos;ll be able to add and manage
your LiteLLM models here.
</Text>
</div>
</div>
</Card>
)}
{liteLLMProvider && (
<>
<div className="flex mb-4 flex-wrap gap-4">
{embeddingModelDetails
?.filter(
(model) =>
model.provider_type ===
EmbeddingProvider.LITELLM.toLowerCase()
)
.map((model) => (
<CloudModelCard
key={model.model_name}
model={model}
provider={LITELLM_CLOUD_PROVIDER}
currentModel={currentModel}
setAlreadySelectedModel={setAlreadySelectedModel}
setShowTentativeModel={setShowTentativeModel}
setShowModelInQueue={setShowModelInQueue}
setShowTentativeProvider={setShowTentativeProvider}
/>
))}
</div>
<Card
className={`mt-2 w-full max-w-4xl ${
currentModel.provider_type === EmbeddingProvider.LITELLM
? "border-2 border-blue-500"
: ""
}`}
>
<LiteLLMModelForm
provider={liteLLMProvider}
currentValues={
currentModel.provider_type === EmbeddingProvider.LITELLM
? (currentModel as CloudEmbeddingModel)
: null
}
setShowTentativeModel={setShowTentativeModel}
/>
</Card>
</>
)}
</div>
</div>
</div>
</div>
);
@ -146,7 +283,9 @@ export function CloudModelCard({
React.SetStateAction<CloudEmbeddingProvider | null>
>;
}) {
const enabled = model.model_name === currentModel.model_name;
const enabled =
model.model_name === currentModel.model_name &&
model.provider_type == currentModel.provider_type;
return (
<div
@ -169,9 +308,12 @@ export function CloudModelCard({
</a>
</div>
<p className="text-sm text-gray-600 mb-2">{model.description}</p>
{model?.provider_type?.toLowerCase() !=
EmbeddingProvider.LITELLM.toLowerCase() && (
<div className="text-xs text-gray-500 mb-2">
${model.pricePerMillion}/M tokens
</div>
)}
<div className="mt-3">
<button
className={`w-full p-2 rounded-lg text-sm ${
@ -182,7 +324,10 @@ export function CloudModelCard({
onClick={() => {
if (enabled) {
setAlreadySelectedModel(model);
} else if (provider.configured) {
} else if (
provider.configured ||
provider.provider_type === EmbeddingProvider.LITELLM
) {
setShowTentativeModel(model);
} else {
setShowModelInQueue(model);

View File

@ -41,6 +41,7 @@ export default function EmbeddingForm() {
multipass_indexing: true,
multilingual_expansion: [],
disable_rerank_for_streaming: false,
api_url: null,
});
const [rerankingDetails, setRerankingDetails] = useState<RerankingDetails>({
@ -116,6 +117,7 @@ export default function EmbeddingForm() {
multilingual_expansion: searchSettings.multilingual_expansion,
disable_rerank_for_streaming:
searchSettings.disable_rerank_for_streaming,
api_url: null,
});
setRerankingDetails({
rerank_api_key: searchSettings.rerank_api_key,

View File

@ -41,6 +41,7 @@ export function CustomModelForm({
api_key: null,
provider_type: null,
index_name: null,
api_url: null,
});
}}
>
@ -106,12 +107,12 @@ export function CustomModelForm({
/>
<BooleanFormField
removeIndent
name="normalize"
label="Normalize Embeddings"
subtext="Whether or not to normalize the embeddings generated by the model. When in doubt, leave this checked."
/>
<div className="flex mt-6">
<Button
type="submit"
disabled={isSubmitting}
@ -119,7 +120,6 @@ export function CustomModelForm({
>
Choose
</Button>
</div>
</Form>
)}
</Formik>

View File

@ -0,0 +1,116 @@
import { CloudEmbeddingModel, CloudEmbeddingProvider } from "./interfaces";
import { Formik, Form } from "formik";
import * as Yup from "yup";
import { TextFormField, BooleanFormField } from "../admin/connectors/Field";
import { Dispatch, SetStateAction } from "react";
import { Button, Text } from "@tremor/react";
import { EmbeddingDetails } from "@/app/admin/embeddings/EmbeddingModelSelectionForm";
// Form for registering an embedding model served behind a LiteLLM proxy.
// Unlike the hosted cloud providers there is no preset model catalog, so the
// user supplies the model's name/dimension/prefixes here. Submitting does NOT
// call any API — it only hands the values up to the parent as a "tentative"
// model for confirmation.
export function LiteLLMModelForm({
  setShowTentativeModel,
  currentValues,
  provider,
}: {
  // Parent callback that receives the chosen model for confirmation.
  setShowTentativeModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
  // Existing model values when editing; null when adding a fresh model.
  currentValues: CloudEmbeddingModel | null;
  // Configured LiteLLM provider; only its api_url is read here.
  provider: EmbeddingDetails;
}) {
  return (
    <div>
      <Formik
        initialValues={
          currentValues || {
            model_name: "",
            // NOTE(review): 768 / 4096 / 1024 look like placeholder defaults
            // for a typical sentence-transformer — confirm they suit common
            // LiteLLM deployments.
            model_dim: 768,
            normalize: false,
            query_prefix: "",
            passage_prefix: "",
            provider_type: "LiteLLM",
            api_key: "",
            enabled: true,
            // The proxy URL comes from the provider config, not user input.
            api_url: provider.api_url,
            description: "",
            index_name: "",
            pricePerMillion: 0,
            mtebScore: 0,
            maxContext: 4096,
            max_tokens: 1024,
          }
        }
        validationSchema={Yup.object().shape({
          model_name: Yup.string().required("Model name is required"),
          model_dim: Yup.number().required("Model dimension is required"),
          normalize: Yup.boolean().required(),
          query_prefix: Yup.string(),
          passage_prefix: Yup.string(),
          provider_type: Yup.string().required("Provider type is required"),
          // API key is optional — presumably the proxy itself may be
          // unauthenticated; verify against backend expectations.
          api_key: Yup.string().optional(),
          enabled: Yup.boolean(),
          api_url: Yup.string().required("API base URL is required"),
          description: Yup.string(),
          index_name: Yup.string().nullable(),
          pricePerMillion: Yup.number(),
          mtebScore: Yup.number(),
          maxContext: Yup.number(),
          max_tokens: Yup.number(),
        })}
        onSubmit={async (values) => {
          // Surface the values to the parent; persistence happens upstream.
          setShowTentativeModel(values as CloudEmbeddingModel);
        }}
      >
        {({ isSubmitting }) => (
          <Form>
            <Text className="text-xl text-text-900 font-bold mb-4">
              Add a new model to LiteLLM proxy at {provider.api_url}
            </Text>
            <TextFormField
              name="model_name"
              label="Model Name:"
              subtext="The name of the LiteLLM model"
              placeholder="e.g. 'all-MiniLM-L6-v2'"
              autoCompleteDisabled={true}
            />
            <TextFormField
              name="model_dim"
              label="Model Dimension:"
              subtext="The dimension of the model's embeddings"
              placeholder="e.g. '1536'"
              type="number"
              autoCompleteDisabled={true}
            />
            <BooleanFormField
              removeIndent
              name="normalize"
              label="Normalize"
              subtext="Whether to normalize the embeddings"
            />
            <TextFormField
              name="query_prefix"
              label="Query Prefix:"
              subtext="Prefix for query embeddings"
              autoCompleteDisabled={true}
            />
            <TextFormField
              name="passage_prefix"
              label="Passage Prefix:"
              subtext="Prefix for passage embeddings"
              autoCompleteDisabled={true}
            />
            <Button
              type="submit"
              disabled={isSubmitting}
              className="w-64 mx-auto"
            >
              Configure LiteLLM Model
            </Button>
          </Form>
        )}
      </Formik>
    </div>
  );
}

View File

@ -2,6 +2,7 @@ import {
CohereIcon,
GoogleIcon,
IconProps,
LiteLLMIcon,
MicrosoftIcon,
NomicIcon,
OpenAIIcon,
@ -14,11 +15,13 @@ export enum EmbeddingProvider {
COHERE = "Cohere",
VOYAGE = "Voyage",
GOOGLE = "Google",
LITELLM = "LiteLLM",
}
export interface CloudEmbeddingProvider {
provider_type: EmbeddingProvider;
api_key?: string;
api_url?: string;
custom_config?: Record<string, string>;
docsLink?: string;
@ -44,6 +47,7 @@ export interface EmbeddingModelDescriptor {
provider_type: string | null;
description: string;
api_key: string | null;
api_url: string | null;
index_name: string | null;
}
@ -70,7 +74,7 @@ export interface FullEmbeddingModelResponse {
}
export interface CloudEmbeddingProviderFull extends CloudEmbeddingProvider {
configured: boolean;
configured?: boolean;
}
export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
@ -87,6 +91,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
index_name: "",
provider_type: null,
api_key: null,
api_url: null,
},
{
model_name: "intfloat/e5-base-v2",
@ -99,6 +104,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
passage_prefix: "passage: ",
index_name: "",
provider_type: null,
api_url: null,
api_key: null,
},
{
@ -113,6 +119,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
index_name: "",
provider_type: null,
api_key: null,
api_url: null,
},
{
model_name: "intfloat/multilingual-e5-base",
@ -126,6 +133,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
index_name: "",
provider_type: null,
api_key: null,
api_url: null,
},
{
model_name: "intfloat/multilingual-e5-small",
@ -139,9 +147,19 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
index_name: "",
provider_type: null,
api_key: null,
api_url: null,
},
];
// LiteLLM self-hosted proxy provider. Unlike the hosted cloud providers it
// ships with no preset model catalog — users register models against their
// own proxy, so `embedding_models` starts empty.
export const LITELLM_CLOUD_PROVIDER: CloudEmbeddingProvider = {
  provider_type: EmbeddingProvider.LITELLM,
  icon: LiteLLMIcon,
  description: "Open-source library to call LLM APIs using OpenAI format",
  website: "https://github.com/BerriAI/litellm",
  apiLink: "https://docs.litellm.ai/docs/proxy/quick_start",
  embedding_models: [],
};
export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
{
provider_type: EmbeddingProvider.COHERE,
@ -169,6 +187,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
passage_prefix: "",
index_name: "",
api_key: null,
api_url: null,
},
{
model_name: "embed-english-light-v3.0",
@ -185,6 +204,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
passage_prefix: "",
index_name: "",
api_key: null,
api_url: null,
},
],
},
@ -213,6 +233,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
enabled: false,
index_name: "",
api_key: null,
api_url: null,
},
{
provider_type: EmbeddingProvider.OPENAI,
@ -229,6 +250,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
maxContext: 8191,
index_name: "",
api_key: null,
api_url: null,
},
],
},
@ -258,6 +280,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
passage_prefix: "",
index_name: "",
api_key: null,
api_url: null,
},
{
provider_type: EmbeddingProvider.GOOGLE,
@ -273,6 +296,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
passage_prefix: "",
index_name: "",
api_key: null,
api_url: null,
},
],
},
@ -301,6 +325,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
passage_prefix: "",
index_name: "",
api_key: null,
api_url: null,
},
{
provider_type: EmbeddingProvider.VOYAGE,
@ -317,6 +342,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
passage_prefix: "",
index_name: "",
api_key: null,
api_url: null,
},
],
},

View File

@ -48,6 +48,7 @@ import jiraSVG from "../../../public/Jira.svg";
import confluenceSVG from "../../../public/Confluence.svg";
import openAISVG from "../../../public/Openai.svg";
import openSourceIcon from "../../../public/OpenSource.png";
import litellmIcon from "../../../public/LiteLLM.jpg";
import awsWEBP from "../../../public/Amazon.webp";
import azureIcon from "../../../public/Azure.png";
@ -267,6 +268,20 @@ export const ColorSlackIcon = ({
);
};
// Icon for the LiteLLM provider. The wrapper is sized 4px larger than the
// requested icon size, with a negative margin so it still occupies the same
// footprint as the other provider icons.
export const LiteLLMIcon = ({
  size = 16,
  className = defaultTailwindCSS,
}: IconProps) => {
  const boxPx = size + 4;
  const boxDim = `${boxPx}px`;
  return (
    <div
      style={{ width: boxDim, height: boxDim }}
      className={`w-[${boxPx}px] h-[${boxPx}px] -m-0.5 ` + className}
    >
      <Image src={litellmIcon} alt="Logo" width="96" height="96" />
    </div>
  );
};
export const OpenSourceIcon = ({
size = 16,
className = defaultTailwindCSS,