Simpler azure embedding (#2751)

* functional but janky

* nit

* adapt for azure

* nit

* minor updates

* nits

* nit

* nit

* ensure access to litellm

* k
This commit is contained in:
pablodanswer
2024-10-15 16:23:11 -07:00
committed by GitHub
parent 02cc211e91
commit e022e77b6d
20 changed files with 524 additions and 214 deletions

View File

@@ -0,0 +1,30 @@
"""add api_version and deployment_name to search settings
Revision ID: 5d12a446f5c0
Revises: e4334d5b33ba
Create Date: 2024-10-08 15:56:07.975636
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5d12a446f5c0"
down_revision = "e4334d5b33ba"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"embedding_provider", sa.Column("api_version", sa.String(), nullable=True)
)
op.add_column(
"embedding_provider", sa.Column("deployment_name", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("embedding_provider", "deployment_name")
op.drop_column("embedding_provider", "api_version")

View File

@@ -53,7 +53,6 @@ MASK_CREDENTIAL_PREFIX = (
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false" os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
) )
SESSION_EXPIRE_TIME_SECONDS = int( SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7 os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days ) # 7 days

View File

@@ -615,6 +615,7 @@ class SearchSettings(Base):
normalize: Mapped[bool] = mapped_column(Boolean) normalize: Mapped[bool] = mapped_column(Boolean)
query_prefix: Mapped[str | None] = mapped_column(String, nullable=True) query_prefix: Mapped[str | None] = mapped_column(String, nullable=True)
passage_prefix: Mapped[str | None] = mapped_column(String, nullable=True) passage_prefix: Mapped[str | None] = mapped_column(String, nullable=True)
status: Mapped[IndexModelStatus] = mapped_column( status: Mapped[IndexModelStatus] = mapped_column(
Enum(IndexModelStatus, native_enum=False) Enum(IndexModelStatus, native_enum=False)
) )
@@ -670,6 +671,20 @@ class SearchSettings(Base):
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\ return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
cloud_provider='{self.cloud_provider.provider_type if self.cloud_provider else 'None'}')>" cloud_provider='{self.cloud_provider.provider_type if self.cloud_provider else 'None'}')>"
@property
def api_version(self) -> str | None:
return (
self.cloud_provider.api_version if self.cloud_provider is not None else None
)
@property
def deployment_name(self) -> str | None:
return (
self.cloud_provider.deployment_name
if self.cloud_provider is not None
else None
)
@property @property
def api_url(self) -> str | None: def api_url(self) -> str | None:
return self.cloud_provider.api_url if self.cloud_provider is not None else None return self.cloud_provider.api_url if self.cloud_provider is not None else None
@@ -1164,6 +1179,9 @@ class CloudEmbeddingProvider(Base):
) )
api_url: Mapped[str | None] = mapped_column(String, nullable=True) api_url: Mapped[str | None] = mapped_column(String, nullable=True)
api_key: Mapped[str | None] = mapped_column(EncryptedString()) api_key: Mapped[str | None] = mapped_column(EncryptedString())
api_version: Mapped[str | None] = mapped_column(String, nullable=True)
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
search_settings: Mapped[list["SearchSettings"]] = relationship( search_settings: Mapped[list["SearchSettings"]] = relationship(
"SearchSettings", "SearchSettings",
back_populates="cloud_provider", back_populates="cloud_provider",

View File

@@ -32,6 +32,8 @@ class IndexingEmbedder(ABC):
provider_type: EmbeddingProvider | None, provider_type: EmbeddingProvider | None,
api_key: str | None, api_key: str | None,
api_url: str | None, api_url: str | None,
api_version: str | None,
deployment_name: str | None,
heartbeat: Heartbeat | None, heartbeat: Heartbeat | None,
): ):
self.model_name = model_name self.model_name = model_name
@@ -41,6 +43,8 @@ class IndexingEmbedder(ABC):
self.provider_type = provider_type self.provider_type = provider_type
self.api_key = api_key self.api_key = api_key
self.api_url = api_url self.api_url = api_url
self.api_version = api_version
self.deployment_name = deployment_name
self.embedding_model = EmbeddingModel( self.embedding_model = EmbeddingModel(
model_name=model_name, model_name=model_name,
@@ -50,6 +54,8 @@ class IndexingEmbedder(ABC):
api_key=api_key, api_key=api_key,
provider_type=provider_type, provider_type=provider_type,
api_url=api_url, api_url=api_url,
api_version=api_version,
deployment_name=deployment_name,
# The below are globally set, this flow always uses the indexing one # The below are globally set, this flow always uses the indexing one
server_host=INDEXING_MODEL_SERVER_HOST, server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT, server_port=INDEXING_MODEL_SERVER_PORT,
@@ -75,6 +81,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type: EmbeddingProvider | None = None, provider_type: EmbeddingProvider | None = None,
api_key: str | None = None, api_key: str | None = None,
api_url: str | None = None, api_url: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
heartbeat: Heartbeat | None = None, heartbeat: Heartbeat | None = None,
): ):
super().__init__( super().__init__(
@@ -85,6 +93,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type, provider_type,
api_key, api_key,
api_url, api_url,
api_version,
deployment_name,
heartbeat, heartbeat,
) )
@@ -193,5 +203,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type=search_settings.provider_type, provider_type=search_settings.provider_type,
api_key=search_settings.api_key, api_key=search_settings.api_key,
api_url=search_settings.api_url, api_url=search_settings.api_url,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
heartbeat=heartbeat, heartbeat=heartbeat,
) )

View File

@@ -97,6 +97,8 @@ class EmbeddingModel:
provider_type: EmbeddingProvider | None, provider_type: EmbeddingProvider | None,
retrim_content: bool = False, retrim_content: bool = False,
heartbeat: Heartbeat | None = None, heartbeat: Heartbeat | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
) -> None: ) -> None:
self.api_key = api_key self.api_key = api_key
self.provider_type = provider_type self.provider_type = provider_type
@@ -106,6 +108,8 @@ class EmbeddingModel:
self.model_name = model_name self.model_name = model_name
self.retrim_content = retrim_content self.retrim_content = retrim_content
self.api_url = api_url self.api_url = api_url
self.api_version = api_version
self.deployment_name = deployment_name
self.tokenizer = get_tokenizer( self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type model_name=model_name, provider_type=provider_type
) )
@@ -157,6 +161,8 @@ class EmbeddingModel:
embed_request = EmbedRequest( embed_request = EmbedRequest(
model_name=self.model_name, model_name=self.model_name,
texts=text_batch, texts=text_batch,
api_version=self.api_version,
deployment_name=self.deployment_name,
max_context_length=max_seq_length, max_context_length=max_seq_length,
normalize_embeddings=self.normalize, normalize_embeddings=self.normalize,
api_key=self.api_key, api_key=self.api_key,
@@ -239,6 +245,8 @@ class EmbeddingModel:
provider_type=search_settings.provider_type, provider_type=search_settings.provider_type,
api_url=search_settings.api_url, api_url=search_settings.api_url,
retrim_content=retrim_content, retrim_content=retrim_content,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
) )

View File

@@ -43,6 +43,8 @@ def test_embedding_configuration(
api_url=test_llm_request.api_url, api_url=test_llm_request.api_url,
provider_type=test_llm_request.provider_type, provider_type=test_llm_request.provider_type,
model_name=test_llm_request.model_name, model_name=test_llm_request.model_name,
api_version=test_llm_request.api_version,
deployment_name=test_llm_request.deployment_name,
normalize=False, normalize=False,
query_prefix=None, query_prefix=None,
passage_prefix=None, passage_prefix=None,

View File

@@ -17,6 +17,8 @@ class TestEmbeddingRequest(BaseModel):
api_key: str | None = None api_key: str | None = None
api_url: str | None = None api_url: str | None = None
model_name: str | None = None model_name: str | None = None
api_version: str | None = None
deployment_name: str | None = None
# This disables the "model_" protected namespace for pydantic # This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()} model_config = {"protected_namespaces": ()}
@@ -26,6 +28,8 @@ class CloudEmbeddingProvider(BaseModel):
provider_type: EmbeddingProvider provider_type: EmbeddingProvider
api_key: str | None = None api_key: str | None = None
api_url: str | None = None api_url: str | None = None
api_version: str | None = None
deployment_name: str | None = None
@classmethod @classmethod
def from_request( def from_request(
@@ -35,6 +39,8 @@ class CloudEmbeddingProvider(BaseModel):
provider_type=cloud_provider_model.provider_type, provider_type=cloud_provider_model.provider_type,
api_key=cloud_provider_model.api_key, api_key=cloud_provider_model.api_key,
api_url=cloud_provider_model.api_url, api_url=cloud_provider_model.api_url,
api_version=cloud_provider_model.api_version,
deployment_name=cloud_provider_model.deployment_name,
) )
@@ -42,3 +48,5 @@ class CloudEmbeddingProviderCreationRequest(BaseModel):
provider_type: EmbeddingProvider provider_type: EmbeddingProvider
api_key: str | None = None api_key: str | None = None
api_url: str | None = None api_url: str | None = None
api_version: str | None = None
deployment_name: str | None = None

View File

@@ -10,6 +10,7 @@ from cohere import Client as CohereClient
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import HTTPException from fastapi import HTTPException
from google.oauth2 import service_account # type: ignore from google.oauth2 import service_account # type: ignore
from litellm import embedding
from retry import retry from retry import retry
from sentence_transformers import CrossEncoder # type: ignore from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore from sentence_transformers import SentenceTransformer # type: ignore
@@ -54,7 +55,11 @@ _COHERE_MAX_INPUT_LEN = 96
def _initialize_client( def _initialize_client(
api_key: str, provider: EmbeddingProvider, model: str | None = None api_key: str,
provider: EmbeddingProvider,
model: str | None = None,
api_url: str | None = None,
api_version: str | None = None,
) -> Any: ) -> Any:
if provider == EmbeddingProvider.OPENAI: if provider == EmbeddingProvider.OPENAI:
return openai.OpenAI(api_key=api_key, timeout=OPENAI_EMBEDDING_TIMEOUT) return openai.OpenAI(api_key=api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
@@ -69,6 +74,8 @@ def _initialize_client(
project_id = json.loads(api_key)["project_id"] project_id = json.loads(api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials) vertexai.init(project=project_id, credentials=credentials)
return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL) return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
elif provider == EmbeddingProvider.AZURE:
return {"api_key": api_key, "api_url": api_url, "api_version": api_version}
else: else:
raise ValueError(f"Unsupported provider: {provider}") raise ValueError(f"Unsupported provider: {provider}")
@@ -78,11 +85,15 @@ class CloudEmbedding:
self, self,
api_key: str, api_key: str,
provider: EmbeddingProvider, provider: EmbeddingProvider,
api_url: str | None = None,
api_version: str | None = None,
# Only for Google as is needed on client setup # Only for Google as is needed on client setup
model: str | None = None, model: str | None = None,
) -> None: ) -> None:
self.provider = provider self.provider = provider
self.client = _initialize_client(api_key, self.provider, model) self.client = _initialize_client(
api_key, self.provider, model, api_url, api_version
)
def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]: def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
if not model: if not model:
@@ -144,6 +155,18 @@ class CloudEmbedding:
) )
return response.embeddings return response.embeddings
def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]:
response = embedding(
model=model,
input=texts,
api_key=self.client["api_key"],
api_base=self.client["api_url"],
api_version=self.client["api_version"],
)
embeddings = [embedding["embedding"] for embedding in response.data]
return embeddings
def _embed_vertex( def _embed_vertex(
self, texts: list[str], model: str | None, embedding_type: str self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]: ) -> list[Embedding]:
@@ -169,10 +192,13 @@ class CloudEmbedding:
texts: list[str], texts: list[str],
text_type: EmbedTextType, text_type: EmbedTextType,
model_name: str | None = None, model_name: str | None = None,
deployment_name: str | None = None,
) -> list[Embedding]: ) -> list[Embedding]:
try: try:
if self.provider == EmbeddingProvider.OPENAI: if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name) return self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return self._embed_azure(texts, f"azure/{deployment_name}")
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type) embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE: if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type) return self._embed_cohere(texts, model_name, embedding_type)
@@ -190,10 +216,14 @@ class CloudEmbedding:
@staticmethod @staticmethod
def create( def create(
api_key: str, provider: EmbeddingProvider, model: str | None = None api_key: str,
provider: EmbeddingProvider,
model: str | None = None,
api_url: str | None = None,
api_version: str | None = None,
) -> "CloudEmbedding": ) -> "CloudEmbedding":
logger.debug(f"Creating Embedding instance for provider: {provider}") logger.debug(f"Creating Embedding instance for provider: {provider}")
return CloudEmbedding(api_key, provider, model) return CloudEmbedding(api_key, provider, model, api_url, api_version)
def get_embedding_model( def get_embedding_model(
@@ -260,12 +290,14 @@ def embed_text(
texts: list[str], texts: list[str],
text_type: EmbedTextType, text_type: EmbedTextType,
model_name: str | None, model_name: str | None,
deployment_name: str | None,
max_context_length: int, max_context_length: int,
normalize_embeddings: bool, normalize_embeddings: bool,
api_key: str | None, api_key: str | None,
provider_type: EmbeddingProvider | None, provider_type: EmbeddingProvider | None,
prefix: str | None, prefix: str | None,
api_url: str | None, api_url: str | None,
api_version: str | None,
) -> list[Embedding]: ) -> list[Embedding]:
logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}") logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")
@@ -307,11 +339,16 @@ def embed_text(
) )
cloud_model = CloudEmbedding( cloud_model = CloudEmbedding(
api_key=api_key, provider=provider_type, model=model_name api_key=api_key,
provider=provider_type,
model=model_name,
api_url=api_url,
api_version=api_version,
) )
embeddings = cloud_model.embed( embeddings = cloud_model.embed(
texts=texts, texts=texts,
model_name=model_name, model_name=model_name,
deployment_name=deployment_name,
text_type=text_type, text_type=text_type,
) )
@@ -405,12 +442,14 @@ async def process_embed_request(
embeddings = embed_text( embeddings = embed_text(
texts=embed_request.texts, texts=embed_request.texts,
model_name=embed_request.model_name, model_name=embed_request.model_name,
deployment_name=embed_request.deployment_name,
max_context_length=embed_request.max_context_length, max_context_length=embed_request.max_context_length,
normalize_embeddings=embed_request.normalize_embeddings, normalize_embeddings=embed_request.normalize_embeddings,
api_key=embed_request.api_key, api_key=embed_request.api_key,
provider_type=embed_request.provider_type, provider_type=embed_request.provider_type,
text_type=embed_request.text_type, text_type=embed_request.text_type,
api_url=embed_request.api_url, api_url=embed_request.api_url,
api_version=embed_request.api_version,
prefix=prefix, prefix=prefix,
) )
return EmbedResponse(embeddings=embeddings) return EmbedResponse(embeddings=embeddings)

View File

@@ -12,3 +12,4 @@ torch==2.2.0
transformers==4.39.2 transformers==4.39.2
uvicorn==0.21.1 uvicorn==0.21.1
voyageai==0.2.3 voyageai==0.2.3
litellm==1.48.7

View File

@@ -7,6 +7,7 @@ class EmbeddingProvider(str, Enum):
VOYAGE = "voyage" VOYAGE = "voyage"
GOOGLE = "google" GOOGLE = "google"
LITELLM = "litellm" LITELLM = "litellm"
AZURE = "azure"
class RerankerProvider(str, Enum): class RerankerProvider(str, Enum):

View File

@@ -20,6 +20,7 @@ class EmbedRequest(BaseModel):
texts: list[str] texts: list[str]
# Can be none for cloud embedding model requests, error handling logic exists for other cases # Can be none for cloud embedding model requests, error handling logic exists for other cases
model_name: str | None = None model_name: str | None = None
deployment_name: str | None = None
max_context_length: int max_context_length: int
normalize_embeddings: bool normalize_embeddings: bool
api_key: str | None = None api_key: str | None = None
@@ -28,7 +29,7 @@ class EmbedRequest(BaseModel):
manual_query_prefix: str | None = None manual_query_prefix: str | None = None
manual_passage_prefix: str | None = None manual_passage_prefix: str | None = None
api_url: str | None = None api_url: str | None = None
api_version: str | None = None
# This disables the "model_" protected namespace for pydantic # This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()} model_config = {"protected_namespaces": ()}

View File

@@ -27,10 +27,13 @@ import {
EMBEDDING_MODELS_ADMIN_URL, EMBEDDING_MODELS_ADMIN_URL,
EMBEDDING_PROVIDERS_ADMIN_URL, EMBEDDING_PROVIDERS_ADMIN_URL,
} from "../configuration/llm/constants"; } from "../configuration/llm/constants";
import { AdvancedSearchConfiguration } from "./interfaces";
export interface EmbeddingDetails { export interface EmbeddingDetails {
api_key?: string; api_key?: string;
api_url?: string; api_url?: string;
api_version?: string;
deployment_name?: string;
custom_config: any; custom_config: any;
provider_type: EmbeddingProvider; provider_type: EmbeddingProvider;
} }
@@ -41,6 +44,8 @@ export function EmbeddingModelSelection({
updateSelectedProvider, updateSelectedProvider,
modelTab, modelTab,
setModelTab, setModelTab,
updateCurrentModel,
advancedEmbeddingDetails,
}: { }: {
modelTab: "open" | "cloud" | null; modelTab: "open" | "cloud" | null;
setModelTab: Dispatch<SetStateAction<"open" | "cloud" | null>>; setModelTab: Dispatch<SetStateAction<"open" | "cloud" | null>>;
@@ -49,6 +54,11 @@ export function EmbeddingModelSelection({
updateSelectedProvider: ( updateSelectedProvider: (
model: CloudEmbeddingModel | HostedEmbeddingModel model: CloudEmbeddingModel | HostedEmbeddingModel
) => void; ) => void;
updateCurrentModel: (
newModel: string,
provider_type: EmbeddingProvider
) => void;
advancedEmbeddingDetails: AdvancedSearchConfiguration;
}) { }) {
// Cloud Provider based modals // Cloud Provider based modals
const [showTentativeProvider, setShowTentativeProvider] = const [showTentativeProvider, setShowTentativeProvider] =
@@ -72,12 +82,6 @@ export function EmbeddingModelSelection({
const [showTentativeOpenProvider, setShowTentativeOpenProvider] = const [showTentativeOpenProvider, setShowTentativeOpenProvider] =
useState<HostedEmbeddingModel | null>(null); useState<HostedEmbeddingModel | null>(null);
// Enabled / unenabled providers
const [newEnabledProviders, setNewEnabledProviders] = useState<string[]>([]);
const [newUnenabledProviders, setNewUnenabledProviders] = useState<string[]>(
[]
);
const [showDeleteCredentialsModal, setShowDeleteCredentialsModal] = const [showDeleteCredentialsModal, setShowDeleteCredentialsModal] =
useState<boolean>(false); useState<boolean>(false);
@@ -90,7 +94,10 @@ export function EmbeddingModelSelection({
{ refreshInterval: 5000 } // 5 seconds { refreshInterval: 5000 } // 5 seconds
); );
const { data: embeddingProviderDetails } = useSWR<EmbeddingDetails[]>( const {
data: embeddingProviderDetails,
mutate: mutateEmbeddingProviderDetails,
} = useSWR<EmbeddingDetails[]>(
EMBEDDING_PROVIDERS_ADMIN_URL, EMBEDDING_PROVIDERS_ADMIN_URL,
errorHandlingFetcher, errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds { refreshInterval: 5000 } // 5 seconds
@@ -132,32 +139,6 @@ export function EmbeddingModelSelection({
} }
}; };
const clientsideAddProvider = (provider: CloudEmbeddingProvider) => {
const providerType = provider.provider_type;
setNewEnabledProviders((newEnabledProviders) => [
...newEnabledProviders,
providerType,
]);
setNewUnenabledProviders((newUnenabledProviders) =>
newUnenabledProviders.filter(
(givenProviderType) => givenProviderType != providerType
)
);
};
const clientsideRemoveProvider = (provider: CloudEmbeddingProvider) => {
const providerType = provider.provider_type;
setNewEnabledProviders((newEnabledProviders) =>
newEnabledProviders.filter(
(givenProviderType) => givenProviderType != providerType
)
);
setNewUnenabledProviders((newUnenabledProviders) => [
...newUnenabledProviders,
providerType,
]);
};
return ( return (
<div className="p-2"> <div className="p-2">
{alreadySelectedModel && ( {alreadySelectedModel && (
@@ -186,14 +167,16 @@ export function EmbeddingModelSelection({
{showTentativeProvider && ( {showTentativeProvider && (
<ProviderCreationModal <ProviderCreationModal
updateCurrentModel={updateCurrentModel}
isProxy={showTentativeProvider.provider_type == "LiteLLM"} isProxy={showTentativeProvider.provider_type == "LiteLLM"}
isAzure={showTentativeProvider.provider_type == "Azure"}
selectedProvider={showTentativeProvider} selectedProvider={showTentativeProvider}
onConfirm={() => { onConfirm={() => {
setShowTentativeProvider(showUnconfiguredProvider); setShowTentativeProvider(showUnconfiguredProvider);
clientsideAddProvider(showTentativeProvider);
if (showModelInQueue) { if (showModelInQueue) {
setShowTentativeModel(showModelInQueue); setShowTentativeModel(showModelInQueue);
} }
mutateEmbeddingProviderDetails();
}} }}
onCancel={() => { onCancel={() => {
setShowModelInQueue(null); setShowModelInQueue(null);
@@ -205,10 +188,11 @@ export function EmbeddingModelSelection({
{changeCredentialsProvider && ( {changeCredentialsProvider && (
<ChangeCredentialsModal <ChangeCredentialsModal
isProxy={changeCredentialsProvider.provider_type == "LiteLLM"} isProxy={changeCredentialsProvider.provider_type == "LiteLLM"}
isAzure={changeCredentialsProvider.provider_type == "Azure"}
useFileUpload={changeCredentialsProvider.provider_type == "Google"} useFileUpload={changeCredentialsProvider.provider_type == "Google"}
onDeleted={() => { onDeleted={() => {
clientsideRemoveProvider(changeCredentialsProvider);
setChangeCredentialsProvider(null); setChangeCredentialsProvider(null);
mutateEmbeddingProviderDetails();
}} }}
provider={changeCredentialsProvider} provider={changeCredentialsProvider}
onConfirm={() => setChangeCredentialsProvider(null)} onConfirm={() => setChangeCredentialsProvider(null)}
@@ -236,12 +220,13 @@ export function EmbeddingModelSelection({
modelProvider={showTentativeProvider!} modelProvider={showTentativeProvider!}
onConfirm={() => { onConfirm={() => {
setShowDeleteCredentialsModal(false); setShowDeleteCredentialsModal(false);
mutateEmbeddingProviderDetails();
}} }}
onCancel={() => setShowDeleteCredentialsModal(false)} onCancel={() => setShowDeleteCredentialsModal(false)}
/> />
)} )}
<p className="t mb-4"> <p className="mb-4">
Select from cloud, self-hosted models, or continue with your current Select from cloud, self-hosted models, or continue with your current
embedding model. embedding model.
</p> </p>
@@ -291,14 +276,13 @@ export function EmbeddingModelSelection({
{modelTab == "cloud" && ( {modelTab == "cloud" && (
<CloudEmbeddingPage <CloudEmbeddingPage
advancedEmbeddingDetails={advancedEmbeddingDetails}
embeddingModelDetails={embeddingModelDetails} embeddingModelDetails={embeddingModelDetails}
setShowModelInQueue={setShowModelInQueue} setShowModelInQueue={setShowModelInQueue}
setShowTentativeModel={setShowTentativeModel} setShowTentativeModel={setShowTentativeModel}
currentModel={selectedProvider || currentEmbeddingModel} currentModel={selectedProvider || currentEmbeddingModel}
setAlreadySelectedModel={setAlreadySelectedModel} setAlreadySelectedModel={setAlreadySelectedModel}
embeddingProviderDetails={embeddingProviderDetails} embeddingProviderDetails={embeddingProviderDetails}
newEnabledProviders={newEnabledProviders}
newUnenabledProviders={newUnenabledProviders}
setShowTentativeProvider={setShowTentativeProvider} setShowTentativeProvider={setShowTentativeProvider}
setChangeCredentialsProvider={setChangeCredentialsProvider} setChangeCredentialsProvider={setChangeCredentialsProvider}
/> />

View File

@@ -16,6 +16,7 @@ export function ChangeCredentialsModal({
onDeleted, onDeleted,
useFileUpload, useFileUpload,
isProxy = false, isProxy = false,
isAzure = false,
}: { }: {
provider: CloudEmbeddingProvider; provider: CloudEmbeddingProvider;
onConfirm: () => void; onConfirm: () => void;
@@ -23,6 +24,7 @@ export function ChangeCredentialsModal({
onDeleted: () => void; onDeleted: () => void;
useFileUpload: boolean; useFileUpload: boolean;
isProxy?: boolean; isProxy?: boolean;
isAzure?: boolean;
}) { }) {
const [apiKey, setApiKey] = useState(""); const [apiKey, setApiKey] = useState("");
const [apiUrl, setApiUrl] = useState(""); const [apiUrl, setApiUrl] = useState("");
@@ -151,7 +153,6 @@ export function ChangeCredentialsModal({
); );
} }
}; };
return ( return (
<Modal <Modal
width="max-w-3xl" width="max-w-3xl"
@@ -160,133 +161,131 @@ export function ChangeCredentialsModal({
onOutsideClick={onCancel} onOutsideClick={onCancel}
> >
<> <>
<p className="mb-4"> {!isAzure && (
You can modify your configuration by providing a new API key <>
{isProxy ? " or API URL." : "."} <p className="mb-4">
</p> You can modify your configuration by providing a new API key
{isProxy ? " or API URL." : "."}
</p>
<div className="mb-4 flex flex-col gap-y-2"> <div className="mb-4 flex flex-col gap-y-2">
<Label className="mt-2">API Key</Label> <Label className="mt-2">API Key</Label>
{useFileUpload ? ( {useFileUpload ? (
<> <>
<Label className="mt-2">Upload JSON File</Label> <Label className="mt-2">Upload JSON File</Label>
<input <input
ref={fileInputRef} ref={fileInputRef}
type="file" type="file"
accept=".json" accept=".json"
onChange={handleFileUpload} onChange={handleFileUpload}
className="text-lg w-full p-1" className="text-lg w-full p-1"
/> />
{fileName && <p>Uploaded file: {fileName}</p>} {fileName && <p>Uploaded file: {fileName}</p>}
</> </>
) : ( ) : (
<> <>
<input <input
className={` className={`
border border
border-border border-border
rounded rounded
w-full w-full
py-2 py-2
px-3 px-3
bg-background-emphasis bg-background-emphasis
`} `}
value={apiKey} value={apiKey}
onChange={(e: any) => setApiKey(e.target.value)} onChange={(e: any) => setApiKey(e.target.value)}
placeholder="Paste your API key here" placeholder="Paste your API key here"
/> />
</> </>
)} )}
{isProxy && ( {isProxy && (
<> <>
<Label className="mt-2">API URL</Label> <Label className="mt-2">API URL</Label>
<input <input
className={` className={`
border border
border-border border-border
rounded rounded
w-full w-full
py-2 py-2
px-3 px-3
bg-background-emphasis bg-background-emphasis
`} `}
value={apiUrl} value={apiUrl}
onChange={(e: any) => setApiUrl(e.target.value)} onChange={(e: any) => setApiUrl(e.target.value)}
placeholder="Paste your API URL here" placeholder="Paste your API URL here"
/> />
{deletionError && ( {deletionError && (
<Callout title="Error" color="red" className="mt-4"> <Callout title="Error" color="red" className="mt-4">
{deletionError} {deletionError}
</Callout>
)}
<div>
<Label className="mt-2">Test Model</Label>
<p>
Since you are using a liteLLM proxy, we&apos;ll need a
model name to test the connection with.
</p>
</div>
<input
className={`
border
border-border
rounded
w-full
py-2
px-3
bg-background-emphasis
`}
value={modelName}
onChange={(e: any) => setModelName(e.target.value)}
placeholder="Paste your model name here"
/>
</>
)}
{testError && (
<Callout title="Error" color="red" className="my-4">
{testError}
</Callout> </Callout>
)} )}
<div> <Button
<Label className="mt-2">Test Model</Label> className="mr-auto mt-4"
<p> color="blue"
Since you are using a liteLLM proxy, we&apos;ll need a model onClick={() => handleSubmit()}
name to test the connection with. disabled={!apiKey}
</p> >
</div> Update Configuration
<input </Button>
className={`
border
border-border
rounded
w-full
py-2
px-3
bg-background-emphasis
`}
value={modelName}
onChange={(e: any) => setModelName(e.target.value)}
placeholder="Paste your API URL here"
/>
{deletionError && ( <Divider />
<Callout title="Error" color="red" className="mt-4"> </div>
{deletionError} </>
</Callout> )}
)}
</>
)}
{testError && ( <Subtitle className="mt-4 font-bold text-lg mb-2">
<Callout title="Error" color="red" className="my-4"> You can delete your configuration.
{testError} </Subtitle>
</Callout> <Text className="mb-2">
)} This is only possible if you have already switched to a different
embedding type!
</Text>
<Button <Button className="mr-auto" onClick={handleDelete} color="red">
className="mr-auto mt-4" Delete Configuration
color="blue" </Button>
onClick={() => handleSubmit()} {deletionError && (
disabled={!apiKey} <Callout title="Error" color="red" className="mt-4">
> {deletionError}
Update Configuration </Callout>
</Button> )}
<Divider />
<Subtitle className="mt-4 font-bold text-lg mb-2">
You can also delete your configuration.
</Subtitle>
<Text className="mb-2">
This is only possible if you have already switched to a different
embedding type!
</Text>
<Button className="mr-auto" onClick={handleDelete} color="red">
Delete Configuration
</Button>
{deletionError && (
<Callout title="Error" color="red" className="mt-4">
{deletionError}
</Callout>
)}
</div>
</> </>
</Modal> </Modal>
); );

View File

@@ -4,7 +4,10 @@ import { Formik, Form } from "formik";
import * as Yup from "yup"; import * as Yup from "yup";
import { Label, TextFormField } from "@/components/admin/connectors/Field"; import { Label, TextFormField } from "@/components/admin/connectors/Field";
import { LoadingAnimation } from "@/components/Loading"; import { LoadingAnimation } from "@/components/Loading";
import { CloudEmbeddingProvider } from "../../../../components/embedding/interfaces"; import {
CloudEmbeddingProvider,
EmbeddingProvider,
} from "../../../../components/embedding/interfaces";
import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../../configuration/llm/constants"; import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../../configuration/llm/constants";
import { Modal } from "@/components/Modal"; import { Modal } from "@/components/Modal";
@@ -14,12 +17,19 @@ export function ProviderCreationModal({
onCancel, onCancel,
existingProvider, existingProvider,
isProxy, isProxy,
isAzure,
updateCurrentModel,
}: { }: {
updateCurrentModel: (
newModel: string,
provider_type: EmbeddingProvider
) => void;
selectedProvider: CloudEmbeddingProvider; selectedProvider: CloudEmbeddingProvider;
onConfirm: () => void; onConfirm: () => void;
onCancel: () => void; onCancel: () => void;
existingProvider?: CloudEmbeddingProvider; existingProvider?: CloudEmbeddingProvider;
isProxy?: boolean; isProxy?: boolean;
isAzure?: boolean;
}) { }) {
const useFileUpload = selectedProvider.provider_type == "Google"; const useFileUpload = selectedProvider.provider_type == "Google";
@@ -41,16 +51,24 @@ export function ProviderCreationModal({
const validationSchema = Yup.object({ const validationSchema = Yup.object({
provider_type: Yup.string().required("Provider type is required"), provider_type: Yup.string().required("Provider type is required"),
api_key: isProxy api_key:
? Yup.string() isProxy || isAzure
: useFileUpload
? Yup.string() ? Yup.string()
: Yup.string().required("API Key is required"), : useFileUpload
? Yup.string()
: Yup.string().required("API Key is required"),
model_name: isProxy model_name: isProxy
? Yup.string().required("Model name is required") ? Yup.string().required("Model name is required")
: Yup.string().nullable(), : Yup.string().nullable(),
api_url: isProxy api_url:
? Yup.string().required("API URL is required") isProxy || isAzure
? Yup.string().required("API URL is required")
: Yup.string(),
deployment_name: isAzure
? Yup.string().required("Deployment name is required")
: Yup.string(),
api_version: isAzure
? Yup.string().required("API Version is required")
: Yup.string(), : Yup.string(),
custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)), custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)),
}); });
@@ -101,6 +119,8 @@ export function ProviderCreationModal({
api_key: values.api_key, api_key: values.api_key,
api_url: values.api_url, api_url: values.api_url,
model_name: values.model_name, model_name: values.model_name,
api_version: values.api_version,
deployment_name: values.deployment_name,
}), }),
} }
); );
@@ -118,6 +138,8 @@ export function ProviderCreationModal({
headers: { "Content-Type": "application/json" }, headers: { "Content-Type": "application/json" },
body: JSON.stringify({ body: JSON.stringify({
...values, ...values,
api_version: values.api_version,
deployment_name: values.deployment_name,
provider_type: values.provider_type.toLowerCase().split(" ")[0], provider_type: values.provider_type.toLowerCase().split(" ")[0],
custom_config: customConfig, custom_config: customConfig,
is_default_provider: false, is_default_provider: false,
@@ -125,6 +147,10 @@ export function ProviderCreationModal({
}), }),
}); });
if (isAzure) {
updateCurrentModel(values.model_name, EmbeddingProvider.AZURE);
}
if (!response.ok) { if (!response.ok) {
const errorData = await response.json(); const errorData = await response.json();
throw new Error( throw new Error(
@@ -178,26 +204,45 @@ export function ProviderCreationModal({
href={selectedProvider.apiLink} href={selectedProvider.apiLink}
rel="noreferrer" rel="noreferrer"
> >
{isProxy ? "API URL" : "API KEY"} {isProxy || isAzure ? "API URL" : "API KEY"}
</a> </a>
</Text> </Text>
<div className="flex w-full flex-col gap-y-6"> <div className="flex w-full flex-col gap-y-6">
{(isProxy || isAzure) && (
<TextFormField
name="api_url"
label="API URL"
placeholder="API URL"
type="text"
/>
)}
{isProxy && ( {isProxy && (
<> <TextFormField
<TextFormField name="model_name"
name="api_url" label={`Model Name ${isProxy ? "(for testing)" : ""}`}
label="API URL" placeholder="Model Name"
placeholder="API URL" type="text"
type="text" />
/> )}
<TextFormField
name="model_name" {isAzure && (
label="Model Name (for testing)" <TextFormField
placeholder="Model Name" name="deployment_name"
type="text" label="Deployment Name"
/> placeholder="Deployment Name"
</> type="text"
/>
)}
{isAzure && (
<TextFormField
name="api_version"
label="API Version"
placeholder="API Version"
type="text"
/>
)} )}
{useFileUpload ? ( {useFileUpload ? (

View File

@@ -10,42 +10,42 @@ import {
EmbeddingModelDescriptor, EmbeddingModelDescriptor,
EmbeddingProvider, EmbeddingProvider,
LITELLM_CLOUD_PROVIDER, LITELLM_CLOUD_PROVIDER,
AZURE_CLOUD_PROVIDER,
} from "../../../../components/embedding/interfaces"; } from "../../../../components/embedding/interfaces";
import { EmbeddingDetails } from "../EmbeddingModelSelectionForm"; import { EmbeddingDetails } from "../EmbeddingModelSelectionForm";
import { FiExternalLink, FiInfo, FiTrash } from "react-icons/fi"; import { FiExternalLink, FiInfo, FiTrash } from "react-icons/fi";
import { HoverPopup } from "@/components/HoverPopup"; import { HoverPopup } from "@/components/HoverPopup";
import { Dispatch, SetStateAction, useEffect, useState } from "react"; import { Dispatch, SetStateAction, useEffect, useState } from "react";
import { LiteLLMModelForm } from "@/components/embedding/LiteLLMModelForm"; import { CustomEmbeddingModelForm } from "@/components/embedding/CustomEmbeddingModelForm";
import { deleteSearchSettings } from "./utils"; import { deleteSearchSettings } from "./utils";
import { usePopup } from "@/components/admin/connectors/Popup"; import { usePopup } from "@/components/admin/connectors/Popup";
import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal"; import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal";
import { AdvancedSearchConfiguration } from "../interfaces";
export default function CloudEmbeddingPage({ export default function CloudEmbeddingPage({
currentModel, currentModel,
embeddingProviderDetails, embeddingProviderDetails,
embeddingModelDetails, embeddingModelDetails,
newEnabledProviders,
newUnenabledProviders,
setShowTentativeProvider, setShowTentativeProvider,
setChangeCredentialsProvider, setChangeCredentialsProvider,
setAlreadySelectedModel, setAlreadySelectedModel,
setShowTentativeModel, setShowTentativeModel,
setShowModelInQueue, setShowModelInQueue,
advancedEmbeddingDetails,
}: { }: {
setShowModelInQueue: Dispatch<SetStateAction<CloudEmbeddingModel | null>>; setShowModelInQueue: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
setShowTentativeModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>; setShowTentativeModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
currentModel: EmbeddingModelDescriptor | CloudEmbeddingModel; currentModel: EmbeddingModelDescriptor | CloudEmbeddingModel;
setAlreadySelectedModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>; setAlreadySelectedModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
newUnenabledProviders: string[];
embeddingModelDetails?: CloudEmbeddingModel[]; embeddingModelDetails?: CloudEmbeddingModel[];
embeddingProviderDetails?: EmbeddingDetails[]; embeddingProviderDetails?: EmbeddingDetails[];
newEnabledProviders: string[];
setShowTentativeProvider: React.Dispatch< setShowTentativeProvider: React.Dispatch<
React.SetStateAction<CloudEmbeddingProvider | null> React.SetStateAction<CloudEmbeddingProvider | null>
>; >;
setChangeCredentialsProvider: React.Dispatch< setChangeCredentialsProvider: React.Dispatch<
React.SetStateAction<CloudEmbeddingProvider | null> React.SetStateAction<CloudEmbeddingProvider | null>
>; >;
advancedEmbeddingDetails: AdvancedSearchConfiguration;
}) { }) {
function hasProviderTypeinArray( function hasProviderTypeinArray(
arr: Array<{ provider_type: string }>, arr: Array<{ provider_type: string }>,
@@ -60,27 +60,38 @@ export default function CloudEmbeddingPage({
(model) => ({ (model) => ({
...model, ...model,
configured: configured:
!newUnenabledProviders.includes(model.provider_type) && embeddingProviderDetails &&
(newEnabledProviders.includes(model.provider_type) || hasProviderTypeinArray(embeddingProviderDetails, model.provider_type),
(embeddingProviderDetails &&
hasProviderTypeinArray(
embeddingProviderDetails,
model.provider_type
))!),
}) })
); );
const [liteLLMProvider, setLiteLLMProvider] = useState< const [liteLLMProvider, setLiteLLMProvider] = useState<
EmbeddingDetails | undefined EmbeddingDetails | undefined
>(undefined); >(undefined);
const [azureProvider, setAzureProvider] = useState<
EmbeddingDetails | undefined
>(undefined);
useEffect(() => { useEffect(() => {
const foundProvider = embeddingProviderDetails?.find( const liteLLMProvider = embeddingProviderDetails?.find(
(provider) => (provider) =>
provider.provider_type === EmbeddingProvider.LITELLM.toLowerCase() provider.provider_type === EmbeddingProvider.LITELLM.toLowerCase()
); );
setLiteLLMProvider(foundProvider); setLiteLLMProvider(liteLLMProvider);
const azureProvider = embeddingProviderDetails?.find(
(provider) =>
provider.provider_type === EmbeddingProvider.AZURE.toLowerCase()
);
setAzureProvider(azureProvider);
}, [embeddingProviderDetails]); }, [embeddingProviderDetails]);
const isAzureConfigured = azureProvider !== undefined;
// Get details of the configured Azure provider
const azureProviderDetails = embeddingProviderDetails?.find(
(provider) => provider.provider_type.toLowerCase() === "azure"
);
return ( return (
<div> <div>
<Title className="mt-8"> <Title className="mt-8">
@@ -248,7 +259,8 @@ export default function CloudEmbeddingPage({
: "" : ""
}`} }`}
> >
<LiteLLMModelForm <CustomEmbeddingModelForm
embeddingType={EmbeddingProvider.LITELLM}
provider={liteLLMProvider} provider={liteLLMProvider}
currentValues={ currentValues={
currentModel.provider_type === EmbeddingProvider.LITELLM currentModel.provider_type === EmbeddingProvider.LITELLM
@@ -262,6 +274,126 @@ export default function CloudEmbeddingPage({
)} )}
</div> </div>
</div> </div>
<Text className="mt-6">
You can also use Azure OpenAI models for embeddings. Azure requires
separate configuration for each model.
</Text>
<div key={AZURE_CLOUD_PROVIDER.provider_type} className="mt-4 w-full">
<div className="flex items-center mb-2">
{AZURE_CLOUD_PROVIDER.icon({ size: 40 })}
<h2 className="ml-2 mt-2 text-xl font-bold">
{AZURE_CLOUD_PROVIDER.provider_type}{" "}
</h2>
<HoverPopup
mainContent={
<FiInfo className="ml-2 mt-2 cursor-pointer" size={18} />
}
popupContent={
<div className="text-sm text-text-800 w-52">
<div className="my-auto">
{AZURE_CLOUD_PROVIDER.description}
</div>
</div>
}
style="dark"
/>
</div>
</div>
<div className="w-full flex flex-col items-start">
{!isAzureConfigured ? (
<>
<button
onClick={() => setShowTentativeProvider(AZURE_CLOUD_PROVIDER)}
className="mb-2 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 text-sm cursor-pointer"
>
Configure Azure OpenAI
</button>
<div className="mt-2 w-full max-w-4xl">
<Card className="p-4 border border-gray-200 rounded-lg shadow-sm">
<Text className="text-base font-medium mb-2">
Configure Azure OpenAI for Embeddings
</Text>
<Text className="text-sm text-gray-600 mb-3">
Click &quot;Configure Azure OpenAI&quot; to set up Azure
OpenAI for embeddings.
</Text>
<div className="flex items-center text-sm text-gray-700">
<FiInfo className="text-gray-400 mr-2" size={16} />
<Text>
You&apos;ll need: API version, base URL, API key, model
name, and deployment name.
</Text>
</div>
</Card>
</div>
</>
) : (
<>
<div className="mb-6 w-full">
<Text className="text-lg font-semibold mb-3">
Current Azure Configuration
</Text>
{azureProviderDetails ? (
<Card className="bg-white shadow-sm border border-gray-200 rounded-lg">
<div className="p-4 space-y-3">
<div className="flex justify-between">
<span className="font-medium">API Version:</span>
<span>{azureProviderDetails.api_version}</span>
</div>
<div className="flex justify-between">
<span className="font-medium">Base URL:</span>
<span>{azureProviderDetails.api_url}</span>
</div>
<div className="flex justify-between">
<span className="font-medium">Deployment Name:</span>
<span>{azureProviderDetails.deployment_name}</span>
</div>
</div>
<button
onClick={() =>
setChangeCredentialsProvider(AZURE_CLOUD_PROVIDER)
}
className="mt-2 px-4 py-2 bg-red-500 text-white rounded hover:bg-red-600 text-sm"
>
Delete Current Azure Provider
</button>
</Card>
) : (
<Card className="bg-gray-50 border border-gray-200 rounded-lg">
<div className="p-4 text-gray-500 text-center">
No Azure provider has been configured yet.
</div>
</Card>
)}
</div>
<Card
className={`mt-2 w-full max-w-4xl ${
currentModel.provider_type === EmbeddingProvider.AZURE
? "border-2 border-blue-500"
: ""
}`}
>
{azureProvider && (
<CustomEmbeddingModelForm
embeddingType={EmbeddingProvider.AZURE}
provider={azureProvider}
currentValues={
currentModel.provider_type === EmbeddingProvider.AZURE
? (currentModel as CloudEmbeddingModel)
: null
}
setShowTentativeModel={setShowTentativeModel}
/>
)}
</Card>
</>
)}
</div>
</div> </div>
</div> </div>
); );

View File

@@ -152,11 +152,6 @@ export default function EmbeddingForm() {
} }
}, [currentEmbeddingModel]); }, [currentEmbeddingModel]);
useEffect(() => {
if (currentEmbeddingModel) {
setSelectedProvider(currentEmbeddingModel);
}
}, [currentEmbeddingModel]);
if (!selectedProvider) { if (!selectedProvider) {
return <ThreeDotsLoader />; return <ThreeDotsLoader />;
} }
@@ -164,10 +159,18 @@ export default function EmbeddingForm() {
return <ErrorCallout errorTitle="Failed to fetch embedding model status" />; return <ErrorCallout errorTitle="Failed to fetch embedding model status" />;
} }
const updateCurrentModel = (newModel: string) => {
setAdvancedEmbeddingDetails((values) => ({
...values,
model_name: newModel,
}));
};
const updateSearch = async () => { const updateSearch = async () => {
const values: SavedSearchSettings = { const values: SavedSearchSettings = {
...rerankingDetails, ...rerankingDetails,
...advancedEmbeddingDetails, ...advancedEmbeddingDetails,
...selectedProvider,
provider_type: provider_type:
selectedProvider.provider_type?.toLowerCase() as EmbeddingProvider | null, selectedProvider.provider_type?.toLowerCase() as EmbeddingProvider | null,
}; };
@@ -311,11 +314,13 @@ export default function EmbeddingForm() {
</Text> </Text>
<Card> <Card>
<EmbeddingModelSelection <EmbeddingModelSelection
updateCurrentModel={updateCurrentModel}
setModelTab={setModelTab} setModelTab={setModelTab}
modelTab={modelTab} modelTab={modelTab}
selectedProvider={selectedProvider} selectedProvider={selectedProvider}
currentEmbeddingModel={currentEmbeddingModel} currentEmbeddingModel={currentEmbeddingModel}
updateSelectedProvider={updateSelectedProvider} updateSelectedProvider={updateSelectedProvider}
advancedEmbeddingDetails={advancedEmbeddingDetails}
/> />
</Card> </Card>
<div className="mt-4 flex w-full justify-end"> <div className="mt-4 flex w-full justify-end">

View File

@@ -27,7 +27,7 @@ export async function Layout({ children }: { children: React.ReactNode }) {
const authTypeMetadata = results[0] as AuthTypeMetadata | null; const authTypeMetadata = results[0] as AuthTypeMetadata | null;
const user = results[1] as User | null; const user = results[1] as User | null;
console.log("authTypeMetadata", authTypeMetadata);
const authDisabled = authTypeMetadata?.authType === "disabled"; const authDisabled = authTypeMetadata?.authType === "disabled";
const requiresVerification = authTypeMetadata?.requiresVerification; const requiresVerification = authTypeMetadata?.requiresVerification;

View File

@@ -1,4 +1,4 @@
import { CloudEmbeddingModel, CloudEmbeddingProvider } from "./interfaces"; import { CloudEmbeddingModel, EmbeddingProvider } from "./interfaces";
import { Formik, Form } from "formik"; import { Formik, Form } from "formik";
import * as Yup from "yup"; import * as Yup from "yup";
import { TextFormField, BooleanFormField } from "../admin/connectors/Field"; import { TextFormField, BooleanFormField } from "../admin/connectors/Field";
@@ -6,14 +6,16 @@ import { Dispatch, SetStateAction } from "react";
import { Button, Text } from "@tremor/react"; import { Button, Text } from "@tremor/react";
import { EmbeddingDetails } from "@/app/admin/embeddings/EmbeddingModelSelectionForm"; import { EmbeddingDetails } from "@/app/admin/embeddings/EmbeddingModelSelectionForm";
export function LiteLLMModelForm({ export function CustomEmbeddingModelForm({
setShowTentativeModel, setShowTentativeModel,
currentValues, currentValues,
provider, provider,
embeddingType,
}: { }: {
setShowTentativeModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>; setShowTentativeModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
currentValues: CloudEmbeddingModel | null; currentValues: CloudEmbeddingModel | null;
provider: EmbeddingDetails; provider: EmbeddingDetails;
embeddingType: EmbeddingProvider;
}) { }) {
return ( return (
<div> <div>
@@ -25,7 +27,7 @@ export function LiteLLMModelForm({
normalize: false, normalize: false,
query_prefix: "", query_prefix: "",
passage_prefix: "", passage_prefix: "",
provider_type: "LiteLLM", provider_type: embeddingType,
api_key: "", api_key: "",
enabled: true, enabled: true,
api_url: provider.api_url, api_url: provider.api_url,
@@ -55,18 +57,21 @@ export function LiteLLMModelForm({
max_tokens: Yup.number(), max_tokens: Yup.number(),
})} })}
onSubmit={async (values) => { onSubmit={async (values) => {
console.log(values);
setShowTentativeModel(values as CloudEmbeddingModel); setShowTentativeModel(values as CloudEmbeddingModel);
}} }}
> >
{({ isSubmitting }) => ( {({ isSubmitting, submitForm, errors }) => (
<Form> <Form>
<Text className="text-xl text-text-900 font-bold mb-4"> <Text className="text-xl text-text-900 font-bold mb-4">
Add a new model to LiteLLM proxy at {provider.api_url} Specify details for your{" "}
{embeddingType === EmbeddingProvider.AZURE ? "Azure" : "LiteLLM"}{" "}
Provider&apos;s model
</Text> </Text>
<TextFormField <TextFormField
name="model_name" name="model_name"
label="Model Name:" label="Model Name:"
subtext="The name of the LiteLLM model" subtext={`The name of the ${embeddingType === EmbeddingProvider.AZURE ? "Azure" : "LiteLLM"} model`}
placeholder="e.g. 'all-MiniLM-L6-v2'" placeholder="e.g. 'all-MiniLM-L6-v2'"
autoCompleteDisabled={true} autoCompleteDisabled={true}
/> />
@@ -103,10 +108,13 @@ export function LiteLLMModelForm({
<Button <Button
type="submit" type="submit"
onClick={() => console.log(errors)}
disabled={isSubmitting} disabled={isSubmitting}
className="w-64 mx-auto" className="w-64 mx-auto"
> >
Configure LiteLLM Model Configure{" "}
{embeddingType === EmbeddingProvider.AZURE ? "Azure" : "LiteLLM"}{" "}
Model
</Button> </Button>
</Form> </Form>
)} )}

View File

@@ -31,7 +31,7 @@ export default function EmbeddingSidebar() {
w-[250px] w-[250px]
`} `}
> >
<div className="fixed h-full left-0 top-0 w-[250px]"> <div className="fixed h-full left-0 top-0 bg-background-100 w-[250px]">
<div className="ml-4 mr-3 flex flex gap-x-1 items-center mt-2 my-auto text-text-700 text-xl"> <div className="ml-4 mr-3 flex flex gap-x-1 items-center mt-2 my-auto text-text-700 text-xl">
<div className="mr-1 my-auto h-6 w-6"> <div className="mr-1 my-auto h-6 w-6">
<Logo height={24} width={24} /> <Logo height={24} width={24} />

View File

@@ -1,4 +1,5 @@
import { import {
AzureIcon,
CohereIcon, CohereIcon,
GoogleIcon, GoogleIcon,
IconProps, IconProps,
@@ -16,6 +17,7 @@ export enum EmbeddingProvider {
VOYAGE = "Voyage", VOYAGE = "Voyage",
GOOGLE = "Google", GOOGLE = "Google",
LITELLM = "LiteLLM", LITELLM = "LiteLLM",
AZURE = "Azure",
} }
export interface CloudEmbeddingProvider { export interface CloudEmbeddingProvider {
@@ -49,6 +51,8 @@ export interface EmbeddingModelDescriptor {
description: string; description: string;
api_key: string | null; api_key: string | null;
api_url: string | null; api_url: string | null;
api_version?: string | null;
deployment_name?: string | null;
index_name: string | null; index_name: string | null;
} }
@@ -161,6 +165,20 @@ export const LITELLM_CLOUD_PROVIDER: CloudEmbeddingProvider = {
embedding_models: [], // No default embedding models embedding_models: [], // No default embedding models
}; };
export const AZURE_CLOUD_PROVIDER: CloudEmbeddingProvider = {
provider_type: EmbeddingProvider.AZURE,
website:
"https://azure.microsoft.com/en-us/products/cognitive-services/openai/",
icon: AzureIcon,
description:
"Azure OpenAI is a cloud-based AI service that provides access to OpenAI models.",
apiLink:
"https://docs.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource",
costslink:
"https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai/",
embedding_models: [], // No default embedding models
};
export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
{ {
provider_type: EmbeddingProvider.COHERE, provider_type: EmbeddingProvider.COHERE,