mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-18 20:10:58 +02:00
Add litellm proxy embeddings (#2291)
* add litellm proxy * formatting * move `api_url` to cloud provider + nits * remove log * typing * quick tuyping fix * update LiteLLM selection logic * remove logs + validate functionality * rename proxy var * update path casing * remove pricing for custom models * functional values
This commit is contained in:
parent
910821c723
commit
299cb5035c
@ -0,0 +1,26 @@
|
|||||||
|
"""Add base_url to CloudEmbeddingProvider
|
||||||
|
|
||||||
|
Revision ID: bceb1e139447
|
||||||
|
Revises: 1f60f60c3401
|
||||||
|
Create Date: 2024-08-28 17:00:52.554580
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "bceb1e139447"
|
||||||
|
down_revision = "1f60f60c3401"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"embedding_provider", sa.Column("api_url", sa.String(), nullable=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("embedding_provider", "api_url")
|
@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
|
|||||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||||
from danswer.db.models import LLMProvider__UserGroup
|
from danswer.db.models import LLMProvider__UserGroup
|
||||||
|
from danswer.db.models import SearchSettings
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
from danswer.db.models import User__UserGroup
|
from danswer.db.models import User__UserGroup
|
||||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||||
@ -50,6 +51,7 @@ def upsert_cloud_embedding_provider(
|
|||||||
setattr(existing_provider, key, value)
|
setattr(existing_provider, key, value)
|
||||||
else:
|
else:
|
||||||
new_provider = CloudEmbeddingProviderModel(**provider.model_dump())
|
new_provider = CloudEmbeddingProviderModel(**provider.model_dump())
|
||||||
|
|
||||||
db_session.add(new_provider)
|
db_session.add(new_provider)
|
||||||
existing_provider = new_provider
|
existing_provider = new_provider
|
||||||
db_session.commit()
|
db_session.commit()
|
||||||
@ -157,12 +159,19 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
|
|||||||
def remove_embedding_provider(
|
def remove_embedding_provider(
|
||||||
db_session: Session, provider_type: EmbeddingProvider
|
db_session: Session, provider_type: EmbeddingProvider
|
||||||
) -> None:
|
) -> None:
|
||||||
|
db_session.execute(
|
||||||
|
delete(SearchSettings).where(SearchSettings.provider_type == provider_type)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delete the embedding provider
|
||||||
db_session.execute(
|
db_session.execute(
|
||||||
delete(CloudEmbeddingProviderModel).where(
|
delete(CloudEmbeddingProviderModel).where(
|
||||||
CloudEmbeddingProviderModel.provider_type == provider_type
|
CloudEmbeddingProviderModel.provider_type == provider_type
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
|
||||||
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
||||||
# Remove LLMProvider's dependent relationships
|
# Remove LLMProvider's dependent relationships
|
||||||
|
@ -607,6 +607,10 @@ class SearchSettings(Base):
|
|||||||
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
|
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
|
||||||
cloud_provider='{self.cloud_provider.provider_type if self.cloud_provider else 'None'}')>"
|
cloud_provider='{self.cloud_provider.provider_type if self.cloud_provider else 'None'}')>"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def api_url(self) -> str | None:
|
||||||
|
return self.cloud_provider.api_url if self.cloud_provider is not None else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def api_key(self) -> str | None:
|
def api_key(self) -> str | None:
|
||||||
return self.cloud_provider.api_key if self.cloud_provider is not None else None
|
return self.cloud_provider.api_key if self.cloud_provider is not None else None
|
||||||
@ -1085,6 +1089,7 @@ class CloudEmbeddingProvider(Base):
|
|||||||
provider_type: Mapped[EmbeddingProvider] = mapped_column(
|
provider_type: Mapped[EmbeddingProvider] = mapped_column(
|
||||||
Enum(EmbeddingProvider), primary_key=True
|
Enum(EmbeddingProvider), primary_key=True
|
||||||
)
|
)
|
||||||
|
api_url: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||||
api_key: Mapped[str | None] = mapped_column(EncryptedString())
|
api_key: Mapped[str | None] = mapped_column(EncryptedString())
|
||||||
search_settings: Mapped[list["SearchSettings"]] = relationship(
|
search_settings: Mapped[list["SearchSettings"]] = relationship(
|
||||||
"SearchSettings",
|
"SearchSettings",
|
||||||
|
@ -115,6 +115,13 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
|
|||||||
return latest_settings
|
return latest_settings
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
|
||||||
|
query = select(SearchSettings).order_by(SearchSettings.id.desc())
|
||||||
|
result = db_session.execute(query)
|
||||||
|
all_settings = result.scalars().all()
|
||||||
|
return list(all_settings)
|
||||||
|
|
||||||
|
|
||||||
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
|
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
|
||||||
if db_session is None:
|
if db_session is None:
|
||||||
with Session(get_sqlalchemy_engine()) as db_session:
|
with Session(get_sqlalchemy_engine()) as db_session:
|
||||||
@ -234,6 +241,7 @@ def get_old_default_embedding_model() -> IndexingSetting:
|
|||||||
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
|
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
|
||||||
index_name="danswer_chunk",
|
index_name="danswer_chunk",
|
||||||
multipass_indexing=False,
|
multipass_indexing=False,
|
||||||
|
api_url=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -246,4 +254,5 @@ def get_new_default_embedding_model() -> IndexingSetting:
|
|||||||
passage_prefix=ASYM_PASSAGE_PREFIX,
|
passage_prefix=ASYM_PASSAGE_PREFIX,
|
||||||
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
|
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
|
||||||
multipass_indexing=False,
|
multipass_indexing=False,
|
||||||
|
api_url=None,
|
||||||
)
|
)
|
||||||
|
@ -32,6 +32,7 @@ class IndexingEmbedder(ABC):
|
|||||||
passage_prefix: str | None,
|
passage_prefix: str | None,
|
||||||
provider_type: EmbeddingProvider | None,
|
provider_type: EmbeddingProvider | None,
|
||||||
api_key: str | None,
|
api_key: str | None,
|
||||||
|
api_url: str | None,
|
||||||
):
|
):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.normalize = normalize
|
self.normalize = normalize
|
||||||
@ -39,6 +40,7 @@ class IndexingEmbedder(ABC):
|
|||||||
self.passage_prefix = passage_prefix
|
self.passage_prefix = passage_prefix
|
||||||
self.provider_type = provider_type
|
self.provider_type = provider_type
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
self.api_url = api_url
|
||||||
|
|
||||||
self.embedding_model = EmbeddingModel(
|
self.embedding_model = EmbeddingModel(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@ -47,6 +49,7 @@ class IndexingEmbedder(ABC):
|
|||||||
normalize=normalize,
|
normalize=normalize,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
provider_type=provider_type,
|
provider_type=provider_type,
|
||||||
|
api_url=api_url,
|
||||||
# The below are globally set, this flow always uses the indexing one
|
# The below are globally set, this flow always uses the indexing one
|
||||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||||
@ -70,9 +73,16 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
|||||||
passage_prefix: str | None,
|
passage_prefix: str | None,
|
||||||
provider_type: EmbeddingProvider | None = None,
|
provider_type: EmbeddingProvider | None = None,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
|
api_url: str | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_name, normalize, query_prefix, passage_prefix, provider_type, api_key
|
model_name,
|
||||||
|
normalize,
|
||||||
|
query_prefix,
|
||||||
|
passage_prefix,
|
||||||
|
provider_type,
|
||||||
|
api_key,
|
||||||
|
api_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
@log_function_time()
|
@log_function_time()
|
||||||
@ -156,7 +166,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
|||||||
title_embed_dict[title] = title_embedding
|
title_embed_dict[title] = title_embedding
|
||||||
|
|
||||||
new_embedded_chunk = IndexChunk(
|
new_embedded_chunk = IndexChunk(
|
||||||
**chunk.model_dump(),
|
**chunk.dict(),
|
||||||
embeddings=ChunkEmbedding(
|
embeddings=ChunkEmbedding(
|
||||||
full_embedding=chunk_embeddings[0],
|
full_embedding=chunk_embeddings[0],
|
||||||
mini_chunk_embeddings=chunk_embeddings[1:],
|
mini_chunk_embeddings=chunk_embeddings[1:],
|
||||||
@ -179,6 +189,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
|||||||
passage_prefix=search_settings.passage_prefix,
|
passage_prefix=search_settings.passage_prefix,
|
||||||
provider_type=search_settings.provider_type,
|
provider_type=search_settings.provider_type,
|
||||||
api_key=search_settings.api_key,
|
api_key=search_settings.api_key,
|
||||||
|
api_url=search_settings.api_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -202,4 +213,5 @@ def get_embedding_model_from_search_settings(
|
|||||||
passage_prefix=search_settings.passage_prefix,
|
passage_prefix=search_settings.passage_prefix,
|
||||||
provider_type=search_settings.provider_type,
|
provider_type=search_settings.provider_type,
|
||||||
api_key=search_settings.api_key,
|
api_key=search_settings.api_key,
|
||||||
|
api_url=search_settings.api_url,
|
||||||
)
|
)
|
||||||
|
@ -99,6 +99,7 @@ class EmbeddingModelDetail(BaseModel):
|
|||||||
normalize: bool
|
normalize: bool
|
||||||
query_prefix: str | None
|
query_prefix: str | None
|
||||||
passage_prefix: str | None
|
passage_prefix: str | None
|
||||||
|
api_url: str | None = None
|
||||||
provider_type: EmbeddingProvider | None = None
|
provider_type: EmbeddingProvider | None = None
|
||||||
api_key: str | None = None
|
api_key: str | None = None
|
||||||
|
|
||||||
@ -117,6 +118,7 @@ class EmbeddingModelDetail(BaseModel):
|
|||||||
passage_prefix=search_settings.passage_prefix,
|
passage_prefix=search_settings.passage_prefix,
|
||||||
provider_type=search_settings.provider_type,
|
provider_type=search_settings.provider_type,
|
||||||
api_key=search_settings.api_key,
|
api_key=search_settings.api_key,
|
||||||
|
api_url=search_settings.api_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -90,6 +90,7 @@ class EmbeddingModel:
|
|||||||
query_prefix: str | None,
|
query_prefix: str | None,
|
||||||
passage_prefix: str | None,
|
passage_prefix: str | None,
|
||||||
api_key: str | None,
|
api_key: str | None,
|
||||||
|
api_url: str | None,
|
||||||
provider_type: EmbeddingProvider | None,
|
provider_type: EmbeddingProvider | None,
|
||||||
retrim_content: bool = False,
|
retrim_content: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -100,6 +101,7 @@ class EmbeddingModel:
|
|||||||
self.normalize = normalize
|
self.normalize = normalize
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.retrim_content = retrim_content
|
self.retrim_content = retrim_content
|
||||||
|
self.api_url = api_url
|
||||||
self.tokenizer = get_tokenizer(
|
self.tokenizer = get_tokenizer(
|
||||||
model_name=model_name, provider_type=provider_type
|
model_name=model_name, provider_type=provider_type
|
||||||
)
|
)
|
||||||
@ -157,6 +159,7 @@ class EmbeddingModel:
|
|||||||
text_type=text_type,
|
text_type=text_type,
|
||||||
manual_query_prefix=self.query_prefix,
|
manual_query_prefix=self.query_prefix,
|
||||||
manual_passage_prefix=self.passage_prefix,
|
manual_passage_prefix=self.passage_prefix,
|
||||||
|
api_url=self.api_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = self._make_model_server_request(embed_request)
|
response = self._make_model_server_request(embed_request)
|
||||||
@ -226,6 +229,7 @@ class EmbeddingModel:
|
|||||||
passage_prefix=search_settings.passage_prefix,
|
passage_prefix=search_settings.passage_prefix,
|
||||||
api_key=search_settings.api_key,
|
api_key=search_settings.api_key,
|
||||||
provider_type=search_settings.provider_type,
|
provider_type=search_settings.provider_type,
|
||||||
|
api_url=search_settings.api_url,
|
||||||
retrim_content=retrim_content,
|
retrim_content=retrim_content,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -81,6 +81,7 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
|||||||
num_rerank=search_settings.num_rerank,
|
num_rerank=search_settings.num_rerank,
|
||||||
# Multilingual Expansion
|
# Multilingual Expansion
|
||||||
multilingual_expansion=search_settings.multilingual_expansion,
|
multilingual_expansion=search_settings.multilingual_expansion,
|
||||||
|
api_url=search_settings.api_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,7 +9,9 @@ from danswer.db.llm import fetch_existing_embedding_providers
|
|||||||
from danswer.db.llm import remove_embedding_provider
|
from danswer.db.llm import remove_embedding_provider
|
||||||
from danswer.db.llm import upsert_cloud_embedding_provider
|
from danswer.db.llm import upsert_cloud_embedding_provider
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
|
from danswer.db.search_settings import get_all_search_settings
|
||||||
from danswer.db.search_settings import get_current_db_embedding_provider
|
from danswer.db.search_settings import get_current_db_embedding_provider
|
||||||
|
from danswer.indexing.models import EmbeddingModelDetail
|
||||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||||
@ -20,6 +22,7 @@ from shared_configs.configs import MODEL_SERVER_PORT
|
|||||||
from shared_configs.enums import EmbeddingProvider
|
from shared_configs.enums import EmbeddingProvider
|
||||||
from shared_configs.enums import EmbedTextType
|
from shared_configs.enums import EmbedTextType
|
||||||
|
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
@ -37,6 +40,7 @@ def test_embedding_configuration(
|
|||||||
server_host=MODEL_SERVER_HOST,
|
server_host=MODEL_SERVER_HOST,
|
||||||
server_port=MODEL_SERVER_PORT,
|
server_port=MODEL_SERVER_PORT,
|
||||||
api_key=test_llm_request.api_key,
|
api_key=test_llm_request.api_key,
|
||||||
|
api_url=test_llm_request.api_url,
|
||||||
provider_type=test_llm_request.provider_type,
|
provider_type=test_llm_request.provider_type,
|
||||||
normalize=False,
|
normalize=False,
|
||||||
query_prefix=None,
|
query_prefix=None,
|
||||||
@ -56,6 +60,15 @@ def test_embedding_configuration(
|
|||||||
raise HTTPException(status_code=400, detail=error_msg)
|
raise HTTPException(status_code=400, detail=error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.get("", response_model=list[EmbeddingModelDetail])
|
||||||
|
def list_embedding_models(
|
||||||
|
_: User | None = Depends(current_admin_user),
|
||||||
|
db_session: Session = Depends(get_session),
|
||||||
|
) -> list[EmbeddingModelDetail]:
|
||||||
|
search_settings = get_all_search_settings(db_session)
|
||||||
|
return [EmbeddingModelDetail.from_db_model(setting) for setting in search_settings]
|
||||||
|
|
||||||
|
|
||||||
@admin_router.get("/embedding-provider")
|
@admin_router.get("/embedding-provider")
|
||||||
def list_embedding_providers(
|
def list_embedding_providers(
|
||||||
_: User | None = Depends(current_admin_user),
|
_: User | None = Depends(current_admin_user),
|
||||||
|
@ -11,11 +11,13 @@ if TYPE_CHECKING:
|
|||||||
class TestEmbeddingRequest(BaseModel):
|
class TestEmbeddingRequest(BaseModel):
|
||||||
provider_type: EmbeddingProvider
|
provider_type: EmbeddingProvider
|
||||||
api_key: str | None = None
|
api_key: str | None = None
|
||||||
|
api_url: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class CloudEmbeddingProvider(BaseModel):
|
class CloudEmbeddingProvider(BaseModel):
|
||||||
provider_type: EmbeddingProvider
|
provider_type: EmbeddingProvider
|
||||||
api_key: str | None = None
|
api_key: str | None = None
|
||||||
|
api_url: str | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_request(
|
def from_request(
|
||||||
@ -24,9 +26,11 @@ class CloudEmbeddingProvider(BaseModel):
|
|||||||
return cls(
|
return cls(
|
||||||
provider_type=cloud_provider_model.provider_type,
|
provider_type=cloud_provider_model.provider_type,
|
||||||
api_key=cloud_provider_model.api_key,
|
api_key=cloud_provider_model.api_key,
|
||||||
|
api_url=cloud_provider_model.api_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CloudEmbeddingProviderCreationRequest(BaseModel):
|
class CloudEmbeddingProviderCreationRequest(BaseModel):
|
||||||
provider_type: EmbeddingProvider
|
provider_type: EmbeddingProvider
|
||||||
api_key: str | None = None
|
api_key: str | None = None
|
||||||
|
api_url: str | None = None
|
||||||
|
@ -45,7 +45,7 @@ def set_new_search_settings(
|
|||||||
if search_settings_new.index_name:
|
if search_settings_new.index_name:
|
||||||
logger.warning("Index name was specified by request, this is not suggested")
|
logger.warning("Index name was specified by request, this is not suggested")
|
||||||
|
|
||||||
# Validate cloud provider exists
|
# Validate cloud provider exists or create new LiteLLM provider
|
||||||
if search_settings_new.provider_type is not None:
|
if search_settings_new.provider_type is not None:
|
||||||
cloud_provider = get_embedding_provider_from_provider_type(
|
cloud_provider = get_embedding_provider_from_provider_type(
|
||||||
db_session, provider_type=search_settings_new.provider_type
|
db_session, provider_type=search_settings_new.provider_type
|
||||||
@ -133,7 +133,7 @@ def cancel_new_embedding(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/get-current-search-settings")
|
@router.get("/get-current-search-settings")
|
||||||
def get_curr_search_settings(
|
def get_current_search_settings_endpoint(
|
||||||
_: User | None = Depends(current_user),
|
_: User | None = Depends(current_user),
|
||||||
db_session: Session = Depends(get_session),
|
db_session: Session = Depends(get_session),
|
||||||
) -> SavedSearchSettings:
|
) -> SavedSearchSettings:
|
||||||
@ -142,7 +142,7 @@ def get_curr_search_settings(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/get-secondary-search-settings")
|
@router.get("/get-secondary-search-settings")
|
||||||
def get_sec_search_settings(
|
def get_secondary_search_settings_endpoint(
|
||||||
_: User | None = Depends(current_user),
|
_: User | None = Depends(current_user),
|
||||||
db_session: Session = Depends(get_session),
|
db_session: Session = Depends(get_session),
|
||||||
) -> SavedSearchSettings | None:
|
) -> SavedSearchSettings | None:
|
||||||
|
@ -2,6 +2,7 @@ import json
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
import openai
|
import openai
|
||||||
import vertexai # type: ignore
|
import vertexai # type: ignore
|
||||||
import voyageai # type: ignore
|
import voyageai # type: ignore
|
||||||
@ -235,6 +236,22 @@ def get_local_reranking_model(
|
|||||||
return _RERANK_MODEL
|
return _RERANK_MODEL
|
||||||
|
|
||||||
|
|
||||||
|
def embed_with_litellm_proxy(
|
||||||
|
texts: list[str], api_url: str, model: str
|
||||||
|
) -> list[Embedding]:
|
||||||
|
with httpx.Client() as client:
|
||||||
|
response = client.post(
|
||||||
|
api_url,
|
||||||
|
json={
|
||||||
|
"model": model,
|
||||||
|
"input": texts,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
return [embedding["embedding"] for embedding in result["data"]]
|
||||||
|
|
||||||
|
|
||||||
@simple_log_function_time()
|
@simple_log_function_time()
|
||||||
def embed_text(
|
def embed_text(
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
@ -245,21 +262,37 @@ def embed_text(
|
|||||||
api_key: str | None,
|
api_key: str | None,
|
||||||
provider_type: EmbeddingProvider | None,
|
provider_type: EmbeddingProvider | None,
|
||||||
prefix: str | None,
|
prefix: str | None,
|
||||||
|
api_url: str | None,
|
||||||
) -> list[Embedding]:
|
) -> list[Embedding]:
|
||||||
|
logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")
|
||||||
|
|
||||||
if not all(texts):
|
if not all(texts):
|
||||||
|
logger.error("Empty strings provided for embedding")
|
||||||
raise ValueError("Empty strings are not allowed for embedding.")
|
raise ValueError("Empty strings are not allowed for embedding.")
|
||||||
|
|
||||||
# Third party API based embedding model
|
|
||||||
if not texts:
|
if not texts:
|
||||||
|
logger.error("No texts provided for embedding")
|
||||||
raise ValueError("No texts provided for embedding.")
|
raise ValueError("No texts provided for embedding.")
|
||||||
|
|
||||||
|
if provider_type == EmbeddingProvider.LITELLM:
|
||||||
|
logger.debug(f"Using LiteLLM proxy for embedding with URL: {api_url}")
|
||||||
|
if not api_url:
|
||||||
|
logger.error("API URL not provided for LiteLLM proxy")
|
||||||
|
raise ValueError("API URL is required for LiteLLM proxy embedding.")
|
||||||
|
try:
|
||||||
|
return embed_with_litellm_proxy(texts, api_url, model_name or "")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Error during LiteLLM proxy embedding: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
elif provider_type is not None:
|
elif provider_type is not None:
|
||||||
logger.debug(f"Embedding text with provider: {provider_type}")
|
logger.debug(f"Using cloud provider {provider_type} for embedding")
|
||||||
if api_key is None:
|
if api_key is None:
|
||||||
|
logger.error("API key not provided for cloud model")
|
||||||
raise RuntimeError("API key not provided for cloud model")
|
raise RuntimeError("API key not provided for cloud model")
|
||||||
|
|
||||||
if prefix:
|
if prefix:
|
||||||
# This may change in the future if some providers require the user
|
logger.warning("Prefix provided for cloud model, which is not supported")
|
||||||
# to manually append a prefix but this is not the case currently
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Prefix string is not valid for cloud models. "
|
"Prefix string is not valid for cloud models. "
|
||||||
"Cloud models take an explicit text type instead."
|
"Cloud models take an explicit text type instead."
|
||||||
@ -274,14 +307,15 @@ def embed_text(
|
|||||||
text_type=text_type,
|
text_type=text_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check for None values in embeddings
|
|
||||||
if any(embedding is None for embedding in embeddings):
|
if any(embedding is None for embedding in embeddings):
|
||||||
error_message = "Embeddings contain None values\n"
|
error_message = "Embeddings contain None values\n"
|
||||||
error_message += "Corresponding texts:\n"
|
error_message += "Corresponding texts:\n"
|
||||||
error_message += "\n".join(texts)
|
error_message += "\n".join(texts)
|
||||||
|
logger.error(error_message)
|
||||||
raise ValueError(error_message)
|
raise ValueError(error_message)
|
||||||
|
|
||||||
elif model_name is not None:
|
elif model_name is not None:
|
||||||
|
logger.debug(f"Using local model {model_name} for embedding")
|
||||||
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
|
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
|
||||||
|
|
||||||
local_model = get_embedding_model(
|
local_model = get_embedding_model(
|
||||||
@ -296,10 +330,12 @@ def embed_text(
|
|||||||
]
|
]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
logger.error("Neither model name nor provider specified for embedding")
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Either model name or provider must be provided to run embeddings."
|
"Either model name or provider must be provided to run embeddings."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(f"Successfully embedded {len(texts)} texts")
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
@ -344,6 +380,7 @@ async def process_embed_request(
|
|||||||
api_key=embed_request.api_key,
|
api_key=embed_request.api_key,
|
||||||
provider_type=embed_request.provider_type,
|
provider_type=embed_request.provider_type,
|
||||||
text_type=embed_request.text_type,
|
text_type=embed_request.text_type,
|
||||||
|
api_url=embed_request.api_url,
|
||||||
prefix=prefix,
|
prefix=prefix,
|
||||||
)
|
)
|
||||||
return EmbedResponse(embeddings=embeddings)
|
return EmbedResponse(embeddings=embeddings)
|
||||||
|
@ -61,6 +61,7 @@ PRESERVED_SEARCH_FIELDS = [
|
|||||||
"provider_type",
|
"provider_type",
|
||||||
"api_key",
|
"api_key",
|
||||||
"model_name",
|
"model_name",
|
||||||
|
"api_url",
|
||||||
"index_name",
|
"index_name",
|
||||||
"multipass_indexing",
|
"multipass_indexing",
|
||||||
"model_dim",
|
"model_dim",
|
||||||
|
@ -6,6 +6,7 @@ class EmbeddingProvider(str, Enum):
|
|||||||
COHERE = "cohere"
|
COHERE = "cohere"
|
||||||
VOYAGE = "voyage"
|
VOYAGE = "voyage"
|
||||||
GOOGLE = "google"
|
GOOGLE = "google"
|
||||||
|
LITELLM = "litellm"
|
||||||
|
|
||||||
|
|
||||||
class RerankerProvider(str, Enum):
|
class RerankerProvider(str, Enum):
|
||||||
|
@ -18,6 +18,7 @@ class EmbedRequest(BaseModel):
|
|||||||
text_type: EmbedTextType
|
text_type: EmbedTextType
|
||||||
manual_query_prefix: str | None = None
|
manual_query_prefix: str | None = None
|
||||||
manual_passage_prefix: str | None = None
|
manual_passage_prefix: str | None = None
|
||||||
|
api_url: str | None = None
|
||||||
|
|
||||||
# This disables the "model_" protected namespace for pydantic
|
# This disables the "model_" protected namespace for pydantic
|
||||||
model_config = {"protected_namespaces": ()}
|
model_config = {"protected_namespaces": ()}
|
||||||
|
@ -32,6 +32,7 @@ def openai_embedding_model() -> EmbeddingModel:
|
|||||||
passage_prefix=None,
|
passage_prefix=None,
|
||||||
api_key=os.getenv("OPENAI_API_KEY"),
|
api_key=os.getenv("OPENAI_API_KEY"),
|
||||||
provider_type=EmbeddingProvider.OPENAI,
|
provider_type=EmbeddingProvider.OPENAI,
|
||||||
|
api_url=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -51,6 +52,7 @@ def cohere_embedding_model() -> EmbeddingModel:
|
|||||||
passage_prefix=None,
|
passage_prefix=None,
|
||||||
api_key=os.getenv("COHERE_API_KEY"),
|
api_key=os.getenv("COHERE_API_KEY"),
|
||||||
provider_type=EmbeddingProvider.COHERE,
|
provider_type=EmbeddingProvider.COHERE,
|
||||||
|
api_url=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -70,6 +72,7 @@ def local_nomic_embedding_model() -> EmbeddingModel:
|
|||||||
passage_prefix="search_document: ",
|
passage_prefix="search_document: ",
|
||||||
api_key=None,
|
api_key=None,
|
||||||
provider_type=None,
|
provider_type=None,
|
||||||
|
api_url=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
BIN
web/public/LiteLLM.jpg
Normal file
BIN
web/public/LiteLLM.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 12 KiB |
@ -2,3 +2,5 @@ export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
|
|||||||
|
|
||||||
export const EMBEDDING_PROVIDERS_ADMIN_URL =
|
export const EMBEDDING_PROVIDERS_ADMIN_URL =
|
||||||
"/api/admin/embedding/embedding-provider";
|
"/api/admin/embedding/embedding-provider";
|
||||||
|
|
||||||
|
export const EMBEDDING_MODELS_ADMIN_URL = "/api/admin/embedding";
|
||||||
|
@ -24,10 +24,14 @@ import { ChangeCredentialsModal } from "./modals/ChangeCredentialsModal";
|
|||||||
import { ModelSelectionConfirmationModal } from "./modals/ModelSelectionModal";
|
import { ModelSelectionConfirmationModal } from "./modals/ModelSelectionModal";
|
||||||
import { AlreadyPickedModal } from "./modals/AlreadyPickedModal";
|
import { AlreadyPickedModal } from "./modals/AlreadyPickedModal";
|
||||||
import { ModelOption } from "../../../components/embedding/ModelSelector";
|
import { ModelOption } from "../../../components/embedding/ModelSelector";
|
||||||
import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../configuration/llm/constants";
|
import {
|
||||||
|
EMBEDDING_MODELS_ADMIN_URL,
|
||||||
|
EMBEDDING_PROVIDERS_ADMIN_URL,
|
||||||
|
} from "../configuration/llm/constants";
|
||||||
|
|
||||||
export interface EmbeddingDetails {
|
export interface EmbeddingDetails {
|
||||||
api_key: string;
|
api_key?: string;
|
||||||
|
api_url?: string;
|
||||||
custom_config: any;
|
custom_config: any;
|
||||||
provider_type: EmbeddingProvider;
|
provider_type: EmbeddingProvider;
|
||||||
}
|
}
|
||||||
@ -77,12 +81,20 @@ export function EmbeddingModelSelection({
|
|||||||
|
|
||||||
const [showDeleteCredentialsModal, setShowDeleteCredentialsModal] =
|
const [showDeleteCredentialsModal, setShowDeleteCredentialsModal] =
|
||||||
useState<boolean>(false);
|
useState<boolean>(false);
|
||||||
|
|
||||||
const [showAddConnectorPopup, setShowAddConnectorPopup] =
|
const [showAddConnectorPopup, setShowAddConnectorPopup] =
|
||||||
useState<boolean>(false);
|
useState<boolean>(false);
|
||||||
|
|
||||||
|
const { data: embeddingModelDetails } = useSWR<CloudEmbeddingModel[]>(
|
||||||
|
EMBEDDING_MODELS_ADMIN_URL,
|
||||||
|
errorHandlingFetcher,
|
||||||
|
{ refreshInterval: 5000 } // 5 seconds
|
||||||
|
);
|
||||||
|
|
||||||
const { data: embeddingProviderDetails } = useSWR<EmbeddingDetails[]>(
|
const { data: embeddingProviderDetails } = useSWR<EmbeddingDetails[]>(
|
||||||
EMBEDDING_PROVIDERS_ADMIN_URL,
|
EMBEDDING_PROVIDERS_ADMIN_URL,
|
||||||
errorHandlingFetcher
|
errorHandlingFetcher,
|
||||||
|
{ refreshInterval: 5000 } // 5 seconds
|
||||||
);
|
);
|
||||||
|
|
||||||
const { data: connectors } = useSWR<Connector<any>[]>(
|
const { data: connectors } = useSWR<Connector<any>[]>(
|
||||||
@ -175,6 +187,7 @@ export function EmbeddingModelSelection({
|
|||||||
|
|
||||||
{showTentativeProvider && (
|
{showTentativeProvider && (
|
||||||
<ProviderCreationModal
|
<ProviderCreationModal
|
||||||
|
isProxy={showTentativeProvider.provider_type == "LiteLLM"}
|
||||||
selectedProvider={showTentativeProvider}
|
selectedProvider={showTentativeProvider}
|
||||||
onConfirm={() => {
|
onConfirm={() => {
|
||||||
setShowTentativeProvider(showUnconfiguredProvider);
|
setShowTentativeProvider(showUnconfiguredProvider);
|
||||||
@ -189,8 +202,10 @@ export function EmbeddingModelSelection({
|
|||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{changeCredentialsProvider && (
|
{changeCredentialsProvider && (
|
||||||
<ChangeCredentialsModal
|
<ChangeCredentialsModal
|
||||||
|
isProxy={changeCredentialsProvider.provider_type == "LiteLLM"}
|
||||||
useFileUpload={changeCredentialsProvider.provider_type == "Google"}
|
useFileUpload={changeCredentialsProvider.provider_type == "Google"}
|
||||||
onDeleted={() => {
|
onDeleted={() => {
|
||||||
clientsideRemoveProvider(changeCredentialsProvider);
|
clientsideRemoveProvider(changeCredentialsProvider);
|
||||||
@ -277,6 +292,7 @@ export function EmbeddingModelSelection({
|
|||||||
|
|
||||||
{modelTab == "cloud" && (
|
{modelTab == "cloud" && (
|
||||||
<CloudEmbeddingPage
|
<CloudEmbeddingPage
|
||||||
|
embeddingModelDetails={embeddingModelDetails}
|
||||||
setShowModelInQueue={setShowModelInQueue}
|
setShowModelInQueue={setShowModelInQueue}
|
||||||
setShowTentativeModel={setShowTentativeModel}
|
setShowTentativeModel={setShowTentativeModel}
|
||||||
currentModel={selectedProvider}
|
currentModel={selectedProvider}
|
||||||
|
@ -21,6 +21,7 @@ export interface AdvancedSearchConfiguration {
|
|||||||
multipass_indexing: boolean;
|
multipass_indexing: boolean;
|
||||||
multilingual_expansion: string[];
|
multilingual_expansion: string[];
|
||||||
disable_rerank_for_streaming: boolean;
|
disable_rerank_for_streaming: boolean;
|
||||||
|
api_url: string | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface SavedSearchSettings extends RerankingDetails {
|
export interface SavedSearchSettings extends RerankingDetails {
|
||||||
@ -33,6 +34,7 @@ export interface SavedSearchSettings extends RerankingDetails {
|
|||||||
multipass_indexing: boolean;
|
multipass_indexing: boolean;
|
||||||
multilingual_expansion: string[];
|
multilingual_expansion: string[];
|
||||||
disable_rerank_for_streaming: boolean;
|
disable_rerank_for_streaming: boolean;
|
||||||
|
api_url: string | null;
|
||||||
provider_type: EmbeddingProvider | null;
|
provider_type: EmbeddingProvider | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,14 +15,16 @@ export function ChangeCredentialsModal({
|
|||||||
onCancel,
|
onCancel,
|
||||||
onDeleted,
|
onDeleted,
|
||||||
useFileUpload,
|
useFileUpload,
|
||||||
|
isProxy = false,
|
||||||
}: {
|
}: {
|
||||||
provider: CloudEmbeddingProvider;
|
provider: CloudEmbeddingProvider;
|
||||||
onConfirm: () => void;
|
onConfirm: () => void;
|
||||||
onCancel: () => void;
|
onCancel: () => void;
|
||||||
onDeleted: () => void;
|
onDeleted: () => void;
|
||||||
useFileUpload: boolean;
|
useFileUpload: boolean;
|
||||||
|
isProxy?: boolean;
|
||||||
}) {
|
}) {
|
||||||
const [apiKey, setApiKey] = useState("");
|
const [apiKeyOrUrl, setApiKeyOrUrl] = useState("");
|
||||||
const [testError, setTestError] = useState<string>("");
|
const [testError, setTestError] = useState<string>("");
|
||||||
const [fileName, setFileName] = useState<string>("");
|
const [fileName, setFileName] = useState<string>("");
|
||||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||||
@ -50,7 +52,7 @@ export function ChangeCredentialsModal({
|
|||||||
let jsonContent;
|
let jsonContent;
|
||||||
try {
|
try {
|
||||||
jsonContent = JSON.parse(fileContent);
|
jsonContent = JSON.parse(fileContent);
|
||||||
setApiKey(JSON.stringify(jsonContent));
|
setApiKeyOrUrl(JSON.stringify(jsonContent));
|
||||||
} catch (parseError) {
|
} catch (parseError) {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
"Failed to parse JSON file. Please ensure it's a valid JSON."
|
"Failed to parse JSON file. Please ensure it's a valid JSON."
|
||||||
@ -62,7 +64,7 @@ export function ChangeCredentialsModal({
|
|||||||
? error.message
|
? error.message
|
||||||
: "An unknown error occurred while processing the file."
|
: "An unknown error occurred while processing the file."
|
||||||
);
|
);
|
||||||
setApiKey("");
|
setApiKeyOrUrl("");
|
||||||
clearFileInput();
|
clearFileInput();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -74,7 +76,7 @@ export function ChangeCredentialsModal({
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await fetch(
|
const response = await fetch(
|
||||||
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type}`,
|
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type.toLowerCase()}`,
|
||||||
{
|
{
|
||||||
method: "DELETE",
|
method: "DELETE",
|
||||||
}
|
}
|
||||||
@ -105,7 +107,10 @@ export function ChangeCredentialsModal({
|
|||||||
headers: { "Content-Type": "application/json" },
|
headers: { "Content-Type": "application/json" },
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
provider_type: provider.provider_type.toLowerCase().split(" ")[0],
|
provider_type: provider.provider_type.toLowerCase().split(" ")[0],
|
||||||
api_key: apiKey,
|
[isProxy ? "api_url" : "api_key"]: apiKeyOrUrl,
|
||||||
|
[isProxy ? "api_key" : "api_url"]: isProxy
|
||||||
|
? provider.api_key
|
||||||
|
: provider.api_url,
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -119,7 +124,7 @@ export function ChangeCredentialsModal({
|
|||||||
headers: { "Content-Type": "application/json" },
|
headers: { "Content-Type": "application/json" },
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
provider_type: provider.provider_type.toLowerCase().split(" ")[0],
|
provider_type: provider.provider_type.toLowerCase().split(" ")[0],
|
||||||
api_key: apiKey,
|
[isProxy ? "api_url" : "api_key"]: apiKeyOrUrl,
|
||||||
is_default_provider: false,
|
is_default_provider: false,
|
||||||
is_configured: true,
|
is_configured: true,
|
||||||
}),
|
}),
|
||||||
@ -128,7 +133,8 @@ export function ChangeCredentialsModal({
|
|||||||
if (!updateResponse.ok) {
|
if (!updateResponse.ok) {
|
||||||
const errorData = await updateResponse.json();
|
const errorData = await updateResponse.json();
|
||||||
throw new Error(
|
throw new Error(
|
||||||
errorData.detail || "Failed to update provider- check your API key"
|
errorData.detail ||
|
||||||
|
`Failed to update provider- check your ${isProxy ? "API URL" : "API key"}`
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -144,12 +150,12 @@ export function ChangeCredentialsModal({
|
|||||||
<Modal
|
<Modal
|
||||||
width="max-w-3xl"
|
width="max-w-3xl"
|
||||||
icon={provider.icon}
|
icon={provider.icon}
|
||||||
title={`Modify your ${provider.provider_type} key`}
|
title={`Modify your ${provider.provider_type} ${isProxy ? "URL" : "key"}`}
|
||||||
onOutsideClick={onCancel}
|
onOutsideClick={onCancel}
|
||||||
>
|
>
|
||||||
<div className="mb-4">
|
<div className="mb-4">
|
||||||
<Subtitle className="font-bold text-lg">
|
<Subtitle className="font-bold text-lg">
|
||||||
Want to swap out your key?
|
Want to swap out your {isProxy ? "URL" : "key"}?
|
||||||
</Subtitle>
|
</Subtitle>
|
||||||
<a
|
<a
|
||||||
href={provider.apiLink}
|
href={provider.apiLink}
|
||||||
@ -185,9 +191,9 @@ export function ChangeCredentialsModal({
|
|||||||
px-3
|
px-3
|
||||||
bg-background-emphasis
|
bg-background-emphasis
|
||||||
`}
|
`}
|
||||||
value={apiKey}
|
value={apiKeyOrUrl}
|
||||||
onChange={(e: any) => setApiKey(e.target.value)}
|
onChange={(e: any) => setApiKeyOrUrl(e.target.value)}
|
||||||
placeholder="Paste your API key here"
|
placeholder={`Paste your ${isProxy ? "API URL" : "API key"} here`}
|
||||||
/>
|
/>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
@ -203,15 +209,15 @@ export function ChangeCredentialsModal({
|
|||||||
<Button
|
<Button
|
||||||
color="blue"
|
color="blue"
|
||||||
onClick={() => handleSubmit()}
|
onClick={() => handleSubmit()}
|
||||||
disabled={!apiKey}
|
disabled={!apiKeyOrUrl}
|
||||||
>
|
>
|
||||||
Swap Key
|
Swap {isProxy ? "URL" : "Key"}
|
||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
<Divider />
|
<Divider />
|
||||||
|
|
||||||
<Subtitle className="mt-4 font-bold text-lg mb-2">
|
<Subtitle className="mt-4 font-bold text-lg mb-2">
|
||||||
You can also delete your key.
|
You can also delete your {isProxy ? "URL" : "key"}.
|
||||||
</Subtitle>
|
</Subtitle>
|
||||||
<Text className="mb-2">
|
<Text className="mb-2">
|
||||||
This is only possible if you have already switched to a different
|
This is only possible if you have already switched to a different
|
||||||
@ -219,7 +225,7 @@ export function ChangeCredentialsModal({
|
|||||||
</Text>
|
</Text>
|
||||||
|
|
||||||
<Button onClick={handleDelete} color="red">
|
<Button onClick={handleDelete} color="red">
|
||||||
Delete key
|
Delete {isProxy ? "URL" : "key"}
|
||||||
</Button>
|
</Button>
|
||||||
{deletionError && (
|
{deletionError && (
|
||||||
<Callout title="Error" color="red" className="mt-4">
|
<Callout title="Error" color="red" className="mt-4">
|
||||||
|
@ -13,11 +13,13 @@ export function ProviderCreationModal({
|
|||||||
onConfirm,
|
onConfirm,
|
||||||
onCancel,
|
onCancel,
|
||||||
existingProvider,
|
existingProvider,
|
||||||
|
isProxy,
|
||||||
}: {
|
}: {
|
||||||
selectedProvider: CloudEmbeddingProvider;
|
selectedProvider: CloudEmbeddingProvider;
|
||||||
onConfirm: () => void;
|
onConfirm: () => void;
|
||||||
onCancel: () => void;
|
onCancel: () => void;
|
||||||
existingProvider?: CloudEmbeddingProvider;
|
existingProvider?: CloudEmbeddingProvider;
|
||||||
|
isProxy?: boolean;
|
||||||
}) {
|
}) {
|
||||||
const useFileUpload = selectedProvider.provider_type == "Google";
|
const useFileUpload = selectedProvider.provider_type == "Google";
|
||||||
|
|
||||||
@ -29,6 +31,7 @@ export function ProviderCreationModal({
|
|||||||
provider_type:
|
provider_type:
|
||||||
existingProvider?.provider_type || selectedProvider.provider_type,
|
existingProvider?.provider_type || selectedProvider.provider_type,
|
||||||
api_key: existingProvider?.api_key || "",
|
api_key: existingProvider?.api_key || "",
|
||||||
|
api_url: existingProvider?.api_url || "",
|
||||||
custom_config: existingProvider?.custom_config
|
custom_config: existingProvider?.custom_config
|
||||||
? Object.entries(existingProvider.custom_config)
|
? Object.entries(existingProvider.custom_config)
|
||||||
: [],
|
: [],
|
||||||
@ -37,9 +40,14 @@ export function ProviderCreationModal({
|
|||||||
|
|
||||||
const validationSchema = Yup.object({
|
const validationSchema = Yup.object({
|
||||||
provider_type: Yup.string().required("Provider type is required"),
|
provider_type: Yup.string().required("Provider type is required"),
|
||||||
api_key: useFileUpload
|
api_key: isProxy
|
||||||
? Yup.string()
|
? Yup.string()
|
||||||
: Yup.string().required("API Key is required"),
|
: useFileUpload
|
||||||
|
? Yup.string()
|
||||||
|
: Yup.string().required("API Key is required"),
|
||||||
|
api_url: isProxy
|
||||||
|
? Yup.string().required("API URL is required")
|
||||||
|
: Yup.string(),
|
||||||
custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)),
|
custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)),
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -87,6 +95,7 @@ export function ProviderCreationModal({
|
|||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
provider_type: values.provider_type.toLowerCase().split(" ")[0],
|
provider_type: values.provider_type.toLowerCase().split(" ")[0],
|
||||||
api_key: values.api_key,
|
api_key: values.api_key,
|
||||||
|
api_url: values.api_url,
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
@ -169,12 +178,19 @@ export function ProviderCreationModal({
|
|||||||
target="_blank"
|
target="_blank"
|
||||||
href={selectedProvider.apiLink}
|
href={selectedProvider.apiLink}
|
||||||
>
|
>
|
||||||
API KEY
|
{isProxy ? "API URL" : "API KEY"}
|
||||||
</a>
|
</a>
|
||||||
</Text>
|
</Text>
|
||||||
|
|
||||||
<div className="flex w-full flex-col gap-y-2">
|
<div className="flex w-full flex-col gap-y-2">
|
||||||
{useFileUpload ? (
|
{isProxy ? (
|
||||||
|
<TextFormField
|
||||||
|
name="api_url"
|
||||||
|
label="API URL"
|
||||||
|
placeholder="API URL"
|
||||||
|
type="text"
|
||||||
|
/>
|
||||||
|
) : useFileUpload ? (
|
||||||
<>
|
<>
|
||||||
<Label>Upload JSON File</Label>
|
<Label>Upload JSON File</Label>
|
||||||
<input
|
<input
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { Text, Title } from "@tremor/react";
|
import { Button, Card, Text, Title } from "@tremor/react";
|
||||||
|
|
||||||
import {
|
import {
|
||||||
CloudEmbeddingProvider,
|
CloudEmbeddingProvider,
|
||||||
@ -8,15 +8,19 @@ import {
|
|||||||
AVAILABLE_CLOUD_PROVIDERS,
|
AVAILABLE_CLOUD_PROVIDERS,
|
||||||
CloudEmbeddingProviderFull,
|
CloudEmbeddingProviderFull,
|
||||||
EmbeddingModelDescriptor,
|
EmbeddingModelDescriptor,
|
||||||
|
EmbeddingProvider,
|
||||||
|
LITELLM_CLOUD_PROVIDER,
|
||||||
} from "../../../../components/embedding/interfaces";
|
} from "../../../../components/embedding/interfaces";
|
||||||
import { EmbeddingDetails } from "../EmbeddingModelSelectionForm";
|
import { EmbeddingDetails } from "../EmbeddingModelSelectionForm";
|
||||||
import { FiExternalLink, FiInfo } from "react-icons/fi";
|
import { FiExternalLink, FiInfo } from "react-icons/fi";
|
||||||
import { HoverPopup } from "@/components/HoverPopup";
|
import { HoverPopup } from "@/components/HoverPopup";
|
||||||
import { Dispatch, SetStateAction } from "react";
|
import { Dispatch, SetStateAction, useEffect, useState } from "react";
|
||||||
|
import { LiteLLMModelForm } from "@/components/embedding/LiteLLMModelForm";
|
||||||
|
|
||||||
export default function CloudEmbeddingPage({
|
export default function CloudEmbeddingPage({
|
||||||
currentModel,
|
currentModel,
|
||||||
embeddingProviderDetails,
|
embeddingProviderDetails,
|
||||||
|
embeddingModelDetails,
|
||||||
newEnabledProviders,
|
newEnabledProviders,
|
||||||
newUnenabledProviders,
|
newUnenabledProviders,
|
||||||
setShowTentativeProvider,
|
setShowTentativeProvider,
|
||||||
@ -30,6 +34,7 @@ export default function CloudEmbeddingPage({
|
|||||||
currentModel: EmbeddingModelDescriptor | CloudEmbeddingModel;
|
currentModel: EmbeddingModelDescriptor | CloudEmbeddingModel;
|
||||||
setAlreadySelectedModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
|
setAlreadySelectedModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
|
||||||
newUnenabledProviders: string[];
|
newUnenabledProviders: string[];
|
||||||
|
embeddingModelDetails?: CloudEmbeddingModel[];
|
||||||
embeddingProviderDetails?: EmbeddingDetails[];
|
embeddingProviderDetails?: EmbeddingDetails[];
|
||||||
newEnabledProviders: string[];
|
newEnabledProviders: string[];
|
||||||
setShowTentativeProvider: React.Dispatch<
|
setShowTentativeProvider: React.Dispatch<
|
||||||
@ -61,6 +66,17 @@ export default function CloudEmbeddingPage({
|
|||||||
))!),
|
))!),
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
const [liteLLMProvider, setLiteLLMProvider] = useState<
|
||||||
|
EmbeddingDetails | undefined
|
||||||
|
>(undefined);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const foundProvider = embeddingProviderDetails?.find(
|
||||||
|
(provider) =>
|
||||||
|
provider.provider_type === EmbeddingProvider.LITELLM.toLowerCase()
|
||||||
|
);
|
||||||
|
setLiteLLMProvider(foundProvider);
|
||||||
|
}, [embeddingProviderDetails]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div>
|
<div>
|
||||||
@ -122,6 +138,127 @@ export default function CloudEmbeddingPage({
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
))}
|
))}
|
||||||
|
|
||||||
|
<Text className="mt-6">
|
||||||
|
Alternatively, you can use a self-hosted model using the LiteLLM
|
||||||
|
proxy. This allows you to leverage various LLM providers through a
|
||||||
|
unified interface that you control.{" "}
|
||||||
|
<a
|
||||||
|
href="https://docs.litellm.ai/"
|
||||||
|
target="_blank"
|
||||||
|
rel="noopener noreferrer"
|
||||||
|
className="text-blue-500 hover:underline"
|
||||||
|
>
|
||||||
|
Learn more about LiteLLM
|
||||||
|
</a>
|
||||||
|
</Text>
|
||||||
|
|
||||||
|
<div key={LITELLM_CLOUD_PROVIDER.provider_type} className="mt-4 w-full">
|
||||||
|
<div className="flex items-center mb-2">
|
||||||
|
{LITELLM_CLOUD_PROVIDER.icon({ size: 40 })}
|
||||||
|
<h2 className="ml-2 mt-2 text-xl font-bold">
|
||||||
|
{LITELLM_CLOUD_PROVIDER.provider_type}{" "}
|
||||||
|
{LITELLM_CLOUD_PROVIDER.provider_type == "Cohere" &&
|
||||||
|
"(recommended)"}
|
||||||
|
</h2>
|
||||||
|
<HoverPopup
|
||||||
|
mainContent={
|
||||||
|
<FiInfo className="ml-2 mt-2 cursor-pointer" size={18} />
|
||||||
|
}
|
||||||
|
popupContent={
|
||||||
|
<div className="text-sm text-text-800 w-52">
|
||||||
|
<div className="my-auto">
|
||||||
|
{LITELLM_CLOUD_PROVIDER.description}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
}
|
||||||
|
style="dark"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div className="w-full flex flex-col items-start">
|
||||||
|
{!liteLLMProvider ? (
|
||||||
|
<button
|
||||||
|
onClick={() => setShowTentativeProvider(LITELLM_CLOUD_PROVIDER)}
|
||||||
|
className="mb-2 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 text-sm cursor-pointer"
|
||||||
|
>
|
||||||
|
Provide API URL
|
||||||
|
</button>
|
||||||
|
) : (
|
||||||
|
<button
|
||||||
|
onClick={() =>
|
||||||
|
setChangeCredentialsProvider(LITELLM_CLOUD_PROVIDER)
|
||||||
|
}
|
||||||
|
className="mb-2 hover:underline text-sm cursor-pointer"
|
||||||
|
>
|
||||||
|
Modify API URL
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{!liteLLMProvider && (
|
||||||
|
<Card className="mt-2 w-full max-w-4xl bg-gray-50 border border-gray-200">
|
||||||
|
<div className="p-4">
|
||||||
|
<Text className="text-lg font-semibold mb-2">
|
||||||
|
API URL Required
|
||||||
|
</Text>
|
||||||
|
<Text className="text-sm text-gray-600 mb-4">
|
||||||
|
Before you can add models, you need to provide an API URL
|
||||||
|
for your LiteLLM proxy. Click the "Provide API
|
||||||
|
URL" button above to set up your LiteLLM configuration.
|
||||||
|
</Text>
|
||||||
|
<div className="flex items-center">
|
||||||
|
<FiInfo className="text-blue-500 mr-2" size={18} />
|
||||||
|
<Text className="text-sm text-blue-500">
|
||||||
|
Once configured, you'll be able to add and manage
|
||||||
|
your LiteLLM models here.
|
||||||
|
</Text>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</Card>
|
||||||
|
)}
|
||||||
|
{liteLLMProvider && (
|
||||||
|
<>
|
||||||
|
<div className="flex mb-4 flex-wrap gap-4">
|
||||||
|
{embeddingModelDetails
|
||||||
|
?.filter(
|
||||||
|
(model) =>
|
||||||
|
model.provider_type ===
|
||||||
|
EmbeddingProvider.LITELLM.toLowerCase()
|
||||||
|
)
|
||||||
|
.map((model) => (
|
||||||
|
<CloudModelCard
|
||||||
|
key={model.model_name}
|
||||||
|
model={model}
|
||||||
|
provider={LITELLM_CLOUD_PROVIDER}
|
||||||
|
currentModel={currentModel}
|
||||||
|
setAlreadySelectedModel={setAlreadySelectedModel}
|
||||||
|
setShowTentativeModel={setShowTentativeModel}
|
||||||
|
setShowModelInQueue={setShowModelInQueue}
|
||||||
|
setShowTentativeProvider={setShowTentativeProvider}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Card
|
||||||
|
className={`mt-2 w-full max-w-4xl ${
|
||||||
|
currentModel.provider_type === EmbeddingProvider.LITELLM
|
||||||
|
? "border-2 border-blue-500"
|
||||||
|
: ""
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
<LiteLLMModelForm
|
||||||
|
provider={liteLLMProvider}
|
||||||
|
currentValues={
|
||||||
|
currentModel.provider_type === EmbeddingProvider.LITELLM
|
||||||
|
? (currentModel as CloudEmbeddingModel)
|
||||||
|
: null
|
||||||
|
}
|
||||||
|
setShowTentativeModel={setShowTentativeModel}
|
||||||
|
/>
|
||||||
|
</Card>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
@ -146,7 +283,9 @@ export function CloudModelCard({
|
|||||||
React.SetStateAction<CloudEmbeddingProvider | null>
|
React.SetStateAction<CloudEmbeddingProvider | null>
|
||||||
>;
|
>;
|
||||||
}) {
|
}) {
|
||||||
const enabled = model.model_name === currentModel.model_name;
|
const enabled =
|
||||||
|
model.model_name === currentModel.model_name &&
|
||||||
|
model.provider_type == currentModel.provider_type;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
@ -169,9 +308,12 @@ export function CloudModelCard({
|
|||||||
</a>
|
</a>
|
||||||
</div>
|
</div>
|
||||||
<p className="text-sm text-gray-600 mb-2">{model.description}</p>
|
<p className="text-sm text-gray-600 mb-2">{model.description}</p>
|
||||||
<div className="text-xs text-gray-500 mb-2">
|
{model?.provider_type?.toLowerCase() !=
|
||||||
${model.pricePerMillion}/M tokens
|
EmbeddingProvider.LITELLM.toLowerCase() && (
|
||||||
</div>
|
<div className="text-xs text-gray-500 mb-2">
|
||||||
|
${model.pricePerMillion}/M tokens
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
<div className="mt-3">
|
<div className="mt-3">
|
||||||
<button
|
<button
|
||||||
className={`w-full p-2 rounded-lg text-sm ${
|
className={`w-full p-2 rounded-lg text-sm ${
|
||||||
@ -182,7 +324,10 @@ export function CloudModelCard({
|
|||||||
onClick={() => {
|
onClick={() => {
|
||||||
if (enabled) {
|
if (enabled) {
|
||||||
setAlreadySelectedModel(model);
|
setAlreadySelectedModel(model);
|
||||||
} else if (provider.configured) {
|
} else if (
|
||||||
|
provider.configured ||
|
||||||
|
provider.provider_type === EmbeddingProvider.LITELLM
|
||||||
|
) {
|
||||||
setShowTentativeModel(model);
|
setShowTentativeModel(model);
|
||||||
} else {
|
} else {
|
||||||
setShowModelInQueue(model);
|
setShowModelInQueue(model);
|
||||||
|
@ -41,6 +41,7 @@ export default function EmbeddingForm() {
|
|||||||
multipass_indexing: true,
|
multipass_indexing: true,
|
||||||
multilingual_expansion: [],
|
multilingual_expansion: [],
|
||||||
disable_rerank_for_streaming: false,
|
disable_rerank_for_streaming: false,
|
||||||
|
api_url: null,
|
||||||
});
|
});
|
||||||
|
|
||||||
const [rerankingDetails, setRerankingDetails] = useState<RerankingDetails>({
|
const [rerankingDetails, setRerankingDetails] = useState<RerankingDetails>({
|
||||||
@ -116,6 +117,7 @@ export default function EmbeddingForm() {
|
|||||||
multilingual_expansion: searchSettings.multilingual_expansion,
|
multilingual_expansion: searchSettings.multilingual_expansion,
|
||||||
disable_rerank_for_streaming:
|
disable_rerank_for_streaming:
|
||||||
searchSettings.disable_rerank_for_streaming,
|
searchSettings.disable_rerank_for_streaming,
|
||||||
|
api_url: null,
|
||||||
});
|
});
|
||||||
setRerankingDetails({
|
setRerankingDetails({
|
||||||
rerank_api_key: searchSettings.rerank_api_key,
|
rerank_api_key: searchSettings.rerank_api_key,
|
||||||
|
@ -41,6 +41,7 @@ export function CustomModelForm({
|
|||||||
api_key: null,
|
api_key: null,
|
||||||
provider_type: null,
|
provider_type: null,
|
||||||
index_name: null,
|
index_name: null,
|
||||||
|
api_url: null,
|
||||||
});
|
});
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
@ -106,20 +107,19 @@ export function CustomModelForm({
|
|||||||
/>
|
/>
|
||||||
|
|
||||||
<BooleanFormField
|
<BooleanFormField
|
||||||
|
removeIndent
|
||||||
name="normalize"
|
name="normalize"
|
||||||
label="Normalize Embeddings"
|
label="Normalize Embeddings"
|
||||||
subtext="Whether or not to normalize the embeddings generated by the model. When in doubt, leave this checked."
|
subtext="Whether or not to normalize the embeddings generated by the model. When in doubt, leave this checked."
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<div className="flex mt-6">
|
<Button
|
||||||
<Button
|
type="submit"
|
||||||
type="submit"
|
disabled={isSubmitting}
|
||||||
disabled={isSubmitting}
|
className="w-64 mx-auto"
|
||||||
className="w-64 mx-auto"
|
>
|
||||||
>
|
Choose
|
||||||
Choose
|
</Button>
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
</Form>
|
</Form>
|
||||||
)}
|
)}
|
||||||
</Formik>
|
</Formik>
|
||||||
|
116
web/src/components/embedding/LiteLLMModelForm.tsx
Normal file
116
web/src/components/embedding/LiteLLMModelForm.tsx
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
import { CloudEmbeddingModel, CloudEmbeddingProvider } from "./interfaces";
|
||||||
|
import { Formik, Form } from "formik";
|
||||||
|
import * as Yup from "yup";
|
||||||
|
import { TextFormField, BooleanFormField } from "../admin/connectors/Field";
|
||||||
|
import { Dispatch, SetStateAction } from "react";
|
||||||
|
import { Button, Text } from "@tremor/react";
|
||||||
|
import { EmbeddingDetails } from "@/app/admin/embeddings/EmbeddingModelSelectionForm";
|
||||||
|
|
||||||
|
export function LiteLLMModelForm({
|
||||||
|
setShowTentativeModel,
|
||||||
|
currentValues,
|
||||||
|
provider,
|
||||||
|
}: {
|
||||||
|
setShowTentativeModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
|
||||||
|
currentValues: CloudEmbeddingModel | null;
|
||||||
|
provider: EmbeddingDetails;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<div>
|
||||||
|
<Formik
|
||||||
|
initialValues={
|
||||||
|
currentValues || {
|
||||||
|
model_name: "",
|
||||||
|
model_dim: 768,
|
||||||
|
normalize: false,
|
||||||
|
query_prefix: "",
|
||||||
|
passage_prefix: "",
|
||||||
|
provider_type: "LiteLLM",
|
||||||
|
api_key: "",
|
||||||
|
enabled: true,
|
||||||
|
api_url: provider.api_url,
|
||||||
|
description: "",
|
||||||
|
index_name: "",
|
||||||
|
pricePerMillion: 0,
|
||||||
|
mtebScore: 0,
|
||||||
|
maxContext: 4096,
|
||||||
|
max_tokens: 1024,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
validationSchema={Yup.object().shape({
|
||||||
|
model_name: Yup.string().required("Model name is required"),
|
||||||
|
model_dim: Yup.number().required("Model dimension is required"),
|
||||||
|
normalize: Yup.boolean().required(),
|
||||||
|
query_prefix: Yup.string(),
|
||||||
|
passage_prefix: Yup.string(),
|
||||||
|
provider_type: Yup.string().required("Provider type is required"),
|
||||||
|
api_key: Yup.string().optional(),
|
||||||
|
enabled: Yup.boolean(),
|
||||||
|
api_url: Yup.string().required("API base URL is required"),
|
||||||
|
description: Yup.string(),
|
||||||
|
index_name: Yup.string().nullable(),
|
||||||
|
pricePerMillion: Yup.number(),
|
||||||
|
mtebScore: Yup.number(),
|
||||||
|
maxContext: Yup.number(),
|
||||||
|
max_tokens: Yup.number(),
|
||||||
|
})}
|
||||||
|
onSubmit={async (values) => {
|
||||||
|
setShowTentativeModel(values as CloudEmbeddingModel);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{({ isSubmitting }) => (
|
||||||
|
<Form>
|
||||||
|
<Text className="text-xl text-text-900 font-bold mb-4">
|
||||||
|
Add a new model to LiteLLM proxy at {provider.api_url}
|
||||||
|
</Text>
|
||||||
|
<TextFormField
|
||||||
|
name="model_name"
|
||||||
|
label="Model Name:"
|
||||||
|
subtext="The name of the LiteLLM model"
|
||||||
|
placeholder="e.g. 'all-MiniLM-L6-v2'"
|
||||||
|
autoCompleteDisabled={true}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<TextFormField
|
||||||
|
name="model_dim"
|
||||||
|
label="Model Dimension:"
|
||||||
|
subtext="The dimension of the model's embeddings"
|
||||||
|
placeholder="e.g. '1536'"
|
||||||
|
type="number"
|
||||||
|
autoCompleteDisabled={true}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<BooleanFormField
|
||||||
|
removeIndent
|
||||||
|
name="normalize"
|
||||||
|
label="Normalize"
|
||||||
|
subtext="Whether to normalize the embeddings"
|
||||||
|
/>
|
||||||
|
|
||||||
|
<TextFormField
|
||||||
|
name="query_prefix"
|
||||||
|
label="Query Prefix:"
|
||||||
|
subtext="Prefix for query embeddings"
|
||||||
|
autoCompleteDisabled={true}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<TextFormField
|
||||||
|
name="passage_prefix"
|
||||||
|
label="Passage Prefix:"
|
||||||
|
subtext="Prefix for passage embeddings"
|
||||||
|
autoCompleteDisabled={true}
|
||||||
|
/>
|
||||||
|
|
||||||
|
<Button
|
||||||
|
type="submit"
|
||||||
|
disabled={isSubmitting}
|
||||||
|
className="w-64 mx-auto"
|
||||||
|
>
|
||||||
|
Configure LiteLLM Model
|
||||||
|
</Button>
|
||||||
|
</Form>
|
||||||
|
)}
|
||||||
|
</Formik>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
@ -2,6 +2,7 @@ import {
|
|||||||
CohereIcon,
|
CohereIcon,
|
||||||
GoogleIcon,
|
GoogleIcon,
|
||||||
IconProps,
|
IconProps,
|
||||||
|
LiteLLMIcon,
|
||||||
MicrosoftIcon,
|
MicrosoftIcon,
|
||||||
NomicIcon,
|
NomicIcon,
|
||||||
OpenAIIcon,
|
OpenAIIcon,
|
||||||
@ -14,11 +15,13 @@ export enum EmbeddingProvider {
|
|||||||
COHERE = "Cohere",
|
COHERE = "Cohere",
|
||||||
VOYAGE = "Voyage",
|
VOYAGE = "Voyage",
|
||||||
GOOGLE = "Google",
|
GOOGLE = "Google",
|
||||||
|
LITELLM = "LiteLLM",
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface CloudEmbeddingProvider {
|
export interface CloudEmbeddingProvider {
|
||||||
provider_type: EmbeddingProvider;
|
provider_type: EmbeddingProvider;
|
||||||
api_key?: string;
|
api_key?: string;
|
||||||
|
api_url?: string;
|
||||||
custom_config?: Record<string, string>;
|
custom_config?: Record<string, string>;
|
||||||
docsLink?: string;
|
docsLink?: string;
|
||||||
|
|
||||||
@ -44,6 +47,7 @@ export interface EmbeddingModelDescriptor {
|
|||||||
provider_type: string | null;
|
provider_type: string | null;
|
||||||
description: string;
|
description: string;
|
||||||
api_key: string | null;
|
api_key: string | null;
|
||||||
|
api_url: string | null;
|
||||||
index_name: string | null;
|
index_name: string | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -70,7 +74,7 @@ export interface FullEmbeddingModelResponse {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export interface CloudEmbeddingProviderFull extends CloudEmbeddingProvider {
|
export interface CloudEmbeddingProviderFull extends CloudEmbeddingProvider {
|
||||||
configured: boolean;
|
configured?: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
||||||
@ -87,6 +91,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
|||||||
index_name: "",
|
index_name: "",
|
||||||
provider_type: null,
|
provider_type: null,
|
||||||
api_key: null,
|
api_key: null,
|
||||||
|
api_url: null,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
model_name: "intfloat/e5-base-v2",
|
model_name: "intfloat/e5-base-v2",
|
||||||
@ -99,6 +104,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
|||||||
passage_prefix: "passage: ",
|
passage_prefix: "passage: ",
|
||||||
index_name: "",
|
index_name: "",
|
||||||
provider_type: null,
|
provider_type: null,
|
||||||
|
api_url: null,
|
||||||
api_key: null,
|
api_key: null,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -113,6 +119,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
|||||||
index_name: "",
|
index_name: "",
|
||||||
provider_type: null,
|
provider_type: null,
|
||||||
api_key: null,
|
api_key: null,
|
||||||
|
api_url: null,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
model_name: "intfloat/multilingual-e5-base",
|
model_name: "intfloat/multilingual-e5-base",
|
||||||
@ -126,6 +133,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
|||||||
index_name: "",
|
index_name: "",
|
||||||
provider_type: null,
|
provider_type: null,
|
||||||
api_key: null,
|
api_key: null,
|
||||||
|
api_url: null,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
model_name: "intfloat/multilingual-e5-small",
|
model_name: "intfloat/multilingual-e5-small",
|
||||||
@ -139,9 +147,19 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
|||||||
index_name: "",
|
index_name: "",
|
||||||
provider_type: null,
|
provider_type: null,
|
||||||
api_key: null,
|
api_key: null,
|
||||||
|
api_url: null,
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
|
export const LITELLM_CLOUD_PROVIDER: CloudEmbeddingProvider = {
|
||||||
|
provider_type: EmbeddingProvider.LITELLM,
|
||||||
|
website: "https://github.com/BerriAI/litellm",
|
||||||
|
icon: LiteLLMIcon,
|
||||||
|
description: "Open-source library to call LLM APIs using OpenAI format",
|
||||||
|
apiLink: "https://docs.litellm.ai/docs/proxy/quick_start",
|
||||||
|
embedding_models: [], // No default embedding models
|
||||||
|
};
|
||||||
|
|
||||||
export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||||
{
|
{
|
||||||
provider_type: EmbeddingProvider.COHERE,
|
provider_type: EmbeddingProvider.COHERE,
|
||||||
@ -169,6 +187,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
|||||||
passage_prefix: "",
|
passage_prefix: "",
|
||||||
index_name: "",
|
index_name: "",
|
||||||
api_key: null,
|
api_key: null,
|
||||||
|
api_url: null,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
model_name: "embed-english-light-v3.0",
|
model_name: "embed-english-light-v3.0",
|
||||||
@ -185,6 +204,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
|||||||
passage_prefix: "",
|
passage_prefix: "",
|
||||||
index_name: "",
|
index_name: "",
|
||||||
api_key: null,
|
api_key: null,
|
||||||
|
api_url: null,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
@ -213,6 +233,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
|||||||
enabled: false,
|
enabled: false,
|
||||||
index_name: "",
|
index_name: "",
|
||||||
api_key: null,
|
api_key: null,
|
||||||
|
api_url: null,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
provider_type: EmbeddingProvider.OPENAI,
|
provider_type: EmbeddingProvider.OPENAI,
|
||||||
@ -229,6 +250,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
|||||||
maxContext: 8191,
|
maxContext: 8191,
|
||||||
index_name: "",
|
index_name: "",
|
||||||
api_key: null,
|
api_key: null,
|
||||||
|
api_url: null,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
@ -258,6 +280,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
|||||||
passage_prefix: "",
|
passage_prefix: "",
|
||||||
index_name: "",
|
index_name: "",
|
||||||
api_key: null,
|
api_key: null,
|
||||||
|
api_url: null,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
provider_type: EmbeddingProvider.GOOGLE,
|
provider_type: EmbeddingProvider.GOOGLE,
|
||||||
@ -273,6 +296,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
|||||||
passage_prefix: "",
|
passage_prefix: "",
|
||||||
index_name: "",
|
index_name: "",
|
||||||
api_key: null,
|
api_key: null,
|
||||||
|
api_url: null,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
@ -301,6 +325,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
|||||||
passage_prefix: "",
|
passage_prefix: "",
|
||||||
index_name: "",
|
index_name: "",
|
||||||
api_key: null,
|
api_key: null,
|
||||||
|
api_url: null,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
provider_type: EmbeddingProvider.VOYAGE,
|
provider_type: EmbeddingProvider.VOYAGE,
|
||||||
@ -317,6 +342,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
|||||||
passage_prefix: "",
|
passage_prefix: "",
|
||||||
index_name: "",
|
index_name: "",
|
||||||
api_key: null,
|
api_key: null,
|
||||||
|
api_url: null,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
@ -48,6 +48,7 @@ import jiraSVG from "../../../public/Jira.svg";
|
|||||||
import confluenceSVG from "../../../public/Confluence.svg";
|
import confluenceSVG from "../../../public/Confluence.svg";
|
||||||
import openAISVG from "../../../public/Openai.svg";
|
import openAISVG from "../../../public/Openai.svg";
|
||||||
import openSourceIcon from "../../../public/OpenSource.png";
|
import openSourceIcon from "../../../public/OpenSource.png";
|
||||||
|
import litellmIcon from "../../../public/LiteLLM.jpg";
|
||||||
|
|
||||||
import awsWEBP from "../../../public/Amazon.webp";
|
import awsWEBP from "../../../public/Amazon.webp";
|
||||||
import azureIcon from "../../../public/Azure.png";
|
import azureIcon from "../../../public/Azure.png";
|
||||||
@ -267,6 +268,20 @@ export const ColorSlackIcon = ({
|
|||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const LiteLLMIcon = ({
|
||||||
|
size = 16,
|
||||||
|
className = defaultTailwindCSS,
|
||||||
|
}: IconProps) => {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
style={{ width: `${size + 4}px`, height: `${size + 4}px` }}
|
||||||
|
className={`w-[${size + 4}px] h-[${size + 4}px] -m-0.5 ` + className}
|
||||||
|
>
|
||||||
|
<Image src={litellmIcon} alt="Logo" width="96" height="96" />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
export const OpenSourceIcon = ({
|
export const OpenSourceIcon = ({
|
||||||
size = 16,
|
size = 16,
|
||||||
className = defaultTailwindCSS,
|
className = defaultTailwindCSS,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user