Simpler azure embedding (#2751)

* functional but janky

* nit

* adapt for azure

* nit

* minor updates

* nits

* nit

* nit

* ensure access to litellm

* k
This commit is contained in:
pablodanswer 2024-10-15 16:23:11 -07:00 committed by GitHub
parent 02cc211e91
commit e022e77b6d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 524 additions and 214 deletions

View File

@ -0,0 +1,30 @@
"""add api_version and deployment_name to search settings

Note: the columns are added to the ``embedding_provider`` table (see
``upgrade``); both are nullable strings.

Revision ID: 5d12a446f5c0
Revises: e4334d5b33ba
Create Date: 2024-10-08 15:56:07.975636
"""
from alembic import op
import sqlalchemy as sa
# Revision identifiers, used by Alembic.
# `revision` is this migration's ID; `down_revision` is the migration it
# builds on (the "Revises" entry in the docstring above).
revision = "5d12a446f5c0"
down_revision = "e4334d5b33ba"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the nullable string columns ``api_version`` and ``deployment_name``
    to the ``embedding_provider`` table."""
    # Order matters only for symmetry with downgrade(), which drops them in reverse.
    for column_name in ("api_version", "deployment_name"):
        new_column = sa.Column(column_name, sa.String(), nullable=True)
        op.add_column("embedding_provider", new_column)
def downgrade() -> None:
    """Drop the columns added by ``upgrade`` from ``embedding_provider``,
    in the reverse of the order they were added."""
    for column_name in ("deployment_name", "api_version"):
        op.drop_column("embedding_provider", column_name)

View File

@ -53,7 +53,6 @@ MASK_CREDENTIAL_PREFIX = (
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
)
SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days

View File

@ -615,6 +615,7 @@ class SearchSettings(Base):
normalize: Mapped[bool] = mapped_column(Boolean)
query_prefix: Mapped[str | None] = mapped_column(String, nullable=True)
passage_prefix: Mapped[str | None] = mapped_column(String, nullable=True)
status: Mapped[IndexModelStatus] = mapped_column(
Enum(IndexModelStatus, native_enum=False)
)
@ -670,6 +671,20 @@ class SearchSettings(Base):
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
cloud_provider='{self.cloud_provider.provider_type if self.cloud_provider else 'None'}')>"
@property
def api_version(self) -> str | None:
return (
self.cloud_provider.api_version if self.cloud_provider is not None else None
)
@property
def deployment_name(self) -> str | None:
return (
self.cloud_provider.deployment_name
if self.cloud_provider is not None
else None
)
@property
def api_url(self) -> str | None:
return self.cloud_provider.api_url if self.cloud_provider is not None else None
@ -1164,6 +1179,9 @@ class CloudEmbeddingProvider(Base):
)
api_url: Mapped[str | None] = mapped_column(String, nullable=True)
api_key: Mapped[str | None] = mapped_column(EncryptedString())
api_version: Mapped[str | None] = mapped_column(String, nullable=True)
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
search_settings: Mapped[list["SearchSettings"]] = relationship(
"SearchSettings",
back_populates="cloud_provider",

View File

@ -32,6 +32,8 @@ class IndexingEmbedder(ABC):
provider_type: EmbeddingProvider | None,
api_key: str | None,
api_url: str | None,
api_version: str | None,
deployment_name: str | None,
heartbeat: Heartbeat | None,
):
self.model_name = model_name
@ -41,6 +43,8 @@ class IndexingEmbedder(ABC):
self.provider_type = provider_type
self.api_key = api_key
self.api_url = api_url
self.api_version = api_version
self.deployment_name = deployment_name
self.embedding_model = EmbeddingModel(
model_name=model_name,
@ -50,6 +54,8 @@ class IndexingEmbedder(ABC):
api_key=api_key,
provider_type=provider_type,
api_url=api_url,
api_version=api_version,
deployment_name=deployment_name,
# The below are globally set, this flow always uses the indexing one
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
@ -75,6 +81,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type: EmbeddingProvider | None = None,
api_key: str | None = None,
api_url: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
heartbeat: Heartbeat | None = None,
):
super().__init__(
@ -85,6 +93,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type,
api_key,
api_url,
api_version,
deployment_name,
heartbeat,
)
@ -193,5 +203,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
heartbeat=heartbeat,
)

View File

@ -97,6 +97,8 @@ class EmbeddingModel:
provider_type: EmbeddingProvider | None,
retrim_content: bool = False,
heartbeat: Heartbeat | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
@ -106,6 +108,8 @@ class EmbeddingModel:
self.model_name = model_name
self.retrim_content = retrim_content
self.api_url = api_url
self.api_version = api_version
self.deployment_name = deployment_name
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
@ -157,6 +161,8 @@ class EmbeddingModel:
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
api_version=self.api_version,
deployment_name=self.deployment_name,
max_context_length=max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
@ -239,6 +245,8 @@ class EmbeddingModel:
provider_type=search_settings.provider_type,
api_url=search_settings.api_url,
retrim_content=retrim_content,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
)

View File

@ -43,6 +43,8 @@ def test_embedding_configuration(
api_url=test_llm_request.api_url,
provider_type=test_llm_request.provider_type,
model_name=test_llm_request.model_name,
api_version=test_llm_request.api_version,
deployment_name=test_llm_request.deployment_name,
normalize=False,
query_prefix=None,
passage_prefix=None,

View File

@ -17,6 +17,8 @@ class TestEmbeddingRequest(BaseModel):
api_key: str | None = None
api_url: str | None = None
model_name: str | None = None
api_version: str | None = None
deployment_name: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
@ -26,6 +28,8 @@ class CloudEmbeddingProvider(BaseModel):
provider_type: EmbeddingProvider
api_key: str | None = None
api_url: str | None = None
api_version: str | None = None
deployment_name: str | None = None
@classmethod
def from_request(
@ -35,6 +39,8 @@ class CloudEmbeddingProvider(BaseModel):
provider_type=cloud_provider_model.provider_type,
api_key=cloud_provider_model.api_key,
api_url=cloud_provider_model.api_url,
api_version=cloud_provider_model.api_version,
deployment_name=cloud_provider_model.deployment_name,
)
@ -42,3 +48,5 @@ class CloudEmbeddingProviderCreationRequest(BaseModel):
provider_type: EmbeddingProvider
api_key: str | None = None
api_url: str | None = None
api_version: str | None = None
deployment_name: str | None = None

View File

@ -10,6 +10,7 @@ from cohere import Client as CohereClient
from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account # type: ignore
from litellm import embedding
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
@ -54,7 +55,11 @@ _COHERE_MAX_INPUT_LEN = 96
def _initialize_client(
api_key: str, provider: EmbeddingProvider, model: str | None = None
api_key: str,
provider: EmbeddingProvider,
model: str | None = None,
api_url: str | None = None,
api_version: str | None = None,
) -> Any:
if provider == EmbeddingProvider.OPENAI:
return openai.OpenAI(api_key=api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
@ -69,6 +74,8 @@ def _initialize_client(
project_id = json.loads(api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
elif provider == EmbeddingProvider.AZURE:
return {"api_key": api_key, "api_url": api_url, "api_version": api_version}
else:
raise ValueError(f"Unsupported provider: {provider}")
@ -78,11 +85,15 @@ class CloudEmbedding:
self,
api_key: str,
provider: EmbeddingProvider,
api_url: str | None = None,
api_version: str | None = None,
# Only for Google as is needed on client setup
model: str | None = None,
) -> None:
self.provider = provider
self.client = _initialize_client(api_key, self.provider, model)
self.client = _initialize_client(
api_key, self.provider, model, api_url, api_version
)
def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
if not model:
@ -144,6 +155,18 @@ class CloudEmbedding:
)
return response.embeddings
def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]:
    """Embed ``texts`` through litellm's ``embedding`` call using the Azure settings.

    For the Azure provider ``self.client`` is the plain dict built by
    ``_initialize_client`` (keys ``api_key``, ``api_url``, ``api_version``),
    not an SDK client object.

    Args:
        texts: Strings to embed.
        model: Model identifier handed to litellm; the caller in ``embed``
            passes ``f"azure/{deployment_name}"``.

    Returns:
        One embedding per entry of ``response.data``.
    """
    response = embedding(
        model=model,
        input=texts,
        api_key=self.client["api_key"],
        api_base=self.client["api_url"],
        api_version=self.client["api_version"],
    )
    # Use a distinct loop name so the imported `embedding` function is not shadowed.
    return [item["embedding"] for item in response.data]
def _embed_vertex(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
@ -169,10 +192,13 @@ class CloudEmbedding:
texts: list[str],
text_type: EmbedTextType,
model_name: str | None = None,
deployment_name: str | None = None,
) -> list[Embedding]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return self._embed_azure(texts, f"azure/{deployment_name}")
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
@ -190,10 +216,14 @@ class CloudEmbedding:
@staticmethod
def create(
api_key: str, provider: EmbeddingProvider, model: str | None = None
api_key: str,
provider: EmbeddingProvider,
model: str | None = None,
api_url: str | None = None,
api_version: str | None = None,
) -> "CloudEmbedding":
logger.debug(f"Creating Embedding instance for provider: {provider}")
return CloudEmbedding(api_key, provider, model)
return CloudEmbedding(api_key, provider, model, api_url, api_version)
def get_embedding_model(
@ -260,12 +290,14 @@ def embed_text(
texts: list[str],
text_type: EmbedTextType,
model_name: str | None,
deployment_name: str | None,
max_context_length: int,
normalize_embeddings: bool,
api_key: str | None,
provider_type: EmbeddingProvider | None,
prefix: str | None,
api_url: str | None,
api_version: str | None,
) -> list[Embedding]:
logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")
@ -307,11 +339,16 @@ def embed_text(
)
cloud_model = CloudEmbedding(
api_key=api_key, provider=provider_type, model=model_name
api_key=api_key,
provider=provider_type,
model=model_name,
api_url=api_url,
api_version=api_version,
)
embeddings = cloud_model.embed(
texts=texts,
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
)
@ -405,12 +442,14 @@ async def process_embed_request(
embeddings = embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
deployment_name=embed_request.deployment_name,
max_context_length=embed_request.max_context_length,
normalize_embeddings=embed_request.normalize_embeddings,
api_key=embed_request.api_key,
provider_type=embed_request.provider_type,
text_type=embed_request.text_type,
api_url=embed_request.api_url,
api_version=embed_request.api_version,
prefix=prefix,
)
return EmbedResponse(embeddings=embeddings)

View File

@ -12,3 +12,4 @@ torch==2.2.0
transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
litellm==1.48.7

View File

@ -7,6 +7,7 @@ class EmbeddingProvider(str, Enum):
VOYAGE = "voyage"
GOOGLE = "google"
LITELLM = "litellm"
AZURE = "azure"
class RerankerProvider(str, Enum):

View File

@ -20,6 +20,7 @@ class EmbedRequest(BaseModel):
texts: list[str]
# Can be None for cloud embedding model requests; error-handling logic exists for other cases
model_name: str | None = None
deployment_name: str | None = None
max_context_length: int
normalize_embeddings: bool
api_key: str | None = None
@ -28,7 +29,7 @@ class EmbedRequest(BaseModel):
manual_query_prefix: str | None = None
manual_passage_prefix: str | None = None
api_url: str | None = None
api_version: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}

View File

@ -27,10 +27,13 @@ import {
EMBEDDING_MODELS_ADMIN_URL,
EMBEDDING_PROVIDERS_ADMIN_URL,
} from "../configuration/llm/constants";
import { AdvancedSearchConfiguration } from "./interfaces";
// Details of one configured embedding provider, as fetched from
// EMBEDDING_PROVIDERS_ADMIN_URL in this file.
export interface EmbeddingDetails {
api_key?: string;
api_url?: string;
// Shown for the Azure provider in the cloud embedding page; optional here.
api_version?: string;
deployment_name?: string;
// NOTE(review): untyped; shape of custom_config is not visible from this file — confirm.
custom_config: any;
provider_type: EmbeddingProvider;
}
@ -41,6 +44,8 @@ export function EmbeddingModelSelection({
updateSelectedProvider,
modelTab,
setModelTab,
updateCurrentModel,
advancedEmbeddingDetails,
}: {
modelTab: "open" | "cloud" | null;
setModelTab: Dispatch<SetStateAction<"open" | "cloud" | null>>;
@ -49,6 +54,11 @@ export function EmbeddingModelSelection({
updateSelectedProvider: (
model: CloudEmbeddingModel | HostedEmbeddingModel
) => void;
updateCurrentModel: (
newModel: string,
provider_type: EmbeddingProvider
) => void;
advancedEmbeddingDetails: AdvancedSearchConfiguration;
}) {
// Cloud Provider based modals
const [showTentativeProvider, setShowTentativeProvider] =
@ -72,12 +82,6 @@ export function EmbeddingModelSelection({
const [showTentativeOpenProvider, setShowTentativeOpenProvider] =
useState<HostedEmbeddingModel | null>(null);
// Enabled / unenabled providers
const [newEnabledProviders, setNewEnabledProviders] = useState<string[]>([]);
const [newUnenabledProviders, setNewUnenabledProviders] = useState<string[]>(
[]
);
const [showDeleteCredentialsModal, setShowDeleteCredentialsModal] =
useState<boolean>(false);
@ -90,7 +94,10 @@ export function EmbeddingModelSelection({
{ refreshInterval: 5000 } // 5 seconds
);
const { data: embeddingProviderDetails } = useSWR<EmbeddingDetails[]>(
const {
data: embeddingProviderDetails,
mutate: mutateEmbeddingProviderDetails,
} = useSWR<EmbeddingDetails[]>(
EMBEDDING_PROVIDERS_ADMIN_URL,
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
@ -132,32 +139,6 @@ export function EmbeddingModelSelection({
}
};
const clientsideAddProvider = (provider: CloudEmbeddingProvider) => {
const providerType = provider.provider_type;
setNewEnabledProviders((newEnabledProviders) => [
...newEnabledProviders,
providerType,
]);
setNewUnenabledProviders((newUnenabledProviders) =>
newUnenabledProviders.filter(
(givenProviderType) => givenProviderType != providerType
)
);
};
const clientsideRemoveProvider = (provider: CloudEmbeddingProvider) => {
const providerType = provider.provider_type;
setNewEnabledProviders((newEnabledProviders) =>
newEnabledProviders.filter(
(givenProviderType) => givenProviderType != providerType
)
);
setNewUnenabledProviders((newUnenabledProviders) => [
...newUnenabledProviders,
providerType,
]);
};
return (
<div className="p-2">
{alreadySelectedModel && (
@ -186,14 +167,16 @@ export function EmbeddingModelSelection({
{showTentativeProvider && (
<ProviderCreationModal
updateCurrentModel={updateCurrentModel}
isProxy={showTentativeProvider.provider_type == "LiteLLM"}
isAzure={showTentativeProvider.provider_type == "Azure"}
selectedProvider={showTentativeProvider}
onConfirm={() => {
setShowTentativeProvider(showUnconfiguredProvider);
clientsideAddProvider(showTentativeProvider);
if (showModelInQueue) {
setShowTentativeModel(showModelInQueue);
}
mutateEmbeddingProviderDetails();
}}
onCancel={() => {
setShowModelInQueue(null);
@ -205,10 +188,11 @@ export function EmbeddingModelSelection({
{changeCredentialsProvider && (
<ChangeCredentialsModal
isProxy={changeCredentialsProvider.provider_type == "LiteLLM"}
isAzure={changeCredentialsProvider.provider_type == "Azure"}
useFileUpload={changeCredentialsProvider.provider_type == "Google"}
onDeleted={() => {
clientsideRemoveProvider(changeCredentialsProvider);
setChangeCredentialsProvider(null);
mutateEmbeddingProviderDetails();
}}
provider={changeCredentialsProvider}
onConfirm={() => setChangeCredentialsProvider(null)}
@ -236,12 +220,13 @@ export function EmbeddingModelSelection({
modelProvider={showTentativeProvider!}
onConfirm={() => {
setShowDeleteCredentialsModal(false);
mutateEmbeddingProviderDetails();
}}
onCancel={() => setShowDeleteCredentialsModal(false)}
/>
)}
<p className="t mb-4">
<p className="mb-4">
Select from cloud, self-hosted models, or continue with your current
embedding model.
</p>
@ -291,14 +276,13 @@ export function EmbeddingModelSelection({
{modelTab == "cloud" && (
<CloudEmbeddingPage
advancedEmbeddingDetails={advancedEmbeddingDetails}
embeddingModelDetails={embeddingModelDetails}
setShowModelInQueue={setShowModelInQueue}
setShowTentativeModel={setShowTentativeModel}
currentModel={selectedProvider || currentEmbeddingModel}
setAlreadySelectedModel={setAlreadySelectedModel}
embeddingProviderDetails={embeddingProviderDetails}
newEnabledProviders={newEnabledProviders}
newUnenabledProviders={newUnenabledProviders}
setShowTentativeProvider={setShowTentativeProvider}
setChangeCredentialsProvider={setChangeCredentialsProvider}
/>

View File

@ -16,6 +16,7 @@ export function ChangeCredentialsModal({
onDeleted,
useFileUpload,
isProxy = false,
isAzure = false,
}: {
provider: CloudEmbeddingProvider;
onConfirm: () => void;
@ -23,6 +24,7 @@ export function ChangeCredentialsModal({
onDeleted: () => void;
useFileUpload: boolean;
isProxy?: boolean;
isAzure?: boolean;
}) {
const [apiKey, setApiKey] = useState("");
const [apiUrl, setApiUrl] = useState("");
@ -151,7 +153,6 @@ export function ChangeCredentialsModal({
);
}
};
return (
<Modal
width="max-w-3xl"
@ -160,133 +161,131 @@ export function ChangeCredentialsModal({
onOutsideClick={onCancel}
>
<>
<p className="mb-4">
You can modify your configuration by providing a new API key
{isProxy ? " or API URL." : "."}
</p>
{!isAzure && (
<>
<p className="mb-4">
You can modify your configuration by providing a new API key
{isProxy ? " or API URL." : "."}
</p>
<div className="mb-4 flex flex-col gap-y-2">
<Label className="mt-2">API Key</Label>
{useFileUpload ? (
<>
<Label className="mt-2">Upload JSON File</Label>
<input
ref={fileInputRef}
type="file"
accept=".json"
onChange={handleFileUpload}
className="text-lg w-full p-1"
/>
{fileName && <p>Uploaded file: {fileName}</p>}
</>
) : (
<>
<input
className={`
border
border-border
rounded
w-full
py-2
px-3
bg-background-emphasis
`}
value={apiKey}
onChange={(e: any) => setApiKey(e.target.value)}
placeholder="Paste your API key here"
/>
</>
)}
<div className="mb-4 flex flex-col gap-y-2">
<Label className="mt-2">API Key</Label>
{useFileUpload ? (
<>
<Label className="mt-2">Upload JSON File</Label>
<input
ref={fileInputRef}
type="file"
accept=".json"
onChange={handleFileUpload}
className="text-lg w-full p-1"
/>
{fileName && <p>Uploaded file: {fileName}</p>}
</>
) : (
<>
<input
className={`
border
border-border
rounded
w-full
py-2
px-3
bg-background-emphasis
`}
value={apiKey}
onChange={(e: any) => setApiKey(e.target.value)}
placeholder="Paste your API key here"
/>
</>
)}
{isProxy && (
<>
<Label className="mt-2">API URL</Label>
{isProxy && (
<>
<Label className="mt-2">API URL</Label>
<input
className={`
border
border-border
rounded
w-full
py-2
px-3
bg-background-emphasis
`}
value={apiUrl}
onChange={(e: any) => setApiUrl(e.target.value)}
placeholder="Paste your API URL here"
/>
<input
className={`
border
border-border
rounded
w-full
py-2
px-3
bg-background-emphasis
`}
value={apiUrl}
onChange={(e: any) => setApiUrl(e.target.value)}
placeholder="Paste your API URL here"
/>
{deletionError && (
<Callout title="Error" color="red" className="mt-4">
{deletionError}
{deletionError && (
<Callout title="Error" color="red" className="mt-4">
{deletionError}
</Callout>
)}
<div>
<Label className="mt-2">Test Model</Label>
<p>
Since you are using a liteLLM proxy, we&apos;ll need a
model name to test the connection with.
</p>
</div>
<input
className={`
border
border-border
rounded
w-full
py-2
px-3
bg-background-emphasis
`}
value={modelName}
onChange={(e: any) => setModelName(e.target.value)}
placeholder="Paste your model name here"
/>
</>
)}
{testError && (
<Callout title="Error" color="red" className="my-4">
{testError}
</Callout>
)}
<div>
<Label className="mt-2">Test Model</Label>
<p>
Since you are using a liteLLM proxy, we&apos;ll need a model
name to test the connection with.
</p>
</div>
<input
className={`
border
border-border
rounded
w-full
py-2
px-3
bg-background-emphasis
`}
value={modelName}
onChange={(e: any) => setModelName(e.target.value)}
placeholder="Paste your API URL here"
/>
<Button
className="mr-auto mt-4"
color="blue"
onClick={() => handleSubmit()}
disabled={!apiKey}
>
Update Configuration
</Button>
{deletionError && (
<Callout title="Error" color="red" className="mt-4">
{deletionError}
</Callout>
)}
</>
)}
<Divider />
</div>
</>
)}
{testError && (
<Callout title="Error" color="red" className="my-4">
{testError}
</Callout>
)}
<Subtitle className="mt-4 font-bold text-lg mb-2">
You can delete your configuration.
</Subtitle>
<Text className="mb-2">
This is only possible if you have already switched to a different
embedding type!
</Text>
<Button
className="mr-auto mt-4"
color="blue"
onClick={() => handleSubmit()}
disabled={!apiKey}
>
Update Configuration
</Button>
<Divider />
<Subtitle className="mt-4 font-bold text-lg mb-2">
You can also delete your configuration.
</Subtitle>
<Text className="mb-2">
This is only possible if you have already switched to a different
embedding type!
</Text>
<Button className="mr-auto" onClick={handleDelete} color="red">
Delete Configuration
</Button>
{deletionError && (
<Callout title="Error" color="red" className="mt-4">
{deletionError}
</Callout>
)}
</div>
<Button className="mr-auto" onClick={handleDelete} color="red">
Delete Configuration
</Button>
{deletionError && (
<Callout title="Error" color="red" className="mt-4">
{deletionError}
</Callout>
)}
</>
</Modal>
);

View File

@ -4,7 +4,10 @@ import { Formik, Form } from "formik";
import * as Yup from "yup";
import { Label, TextFormField } from "@/components/admin/connectors/Field";
import { LoadingAnimation } from "@/components/Loading";
import { CloudEmbeddingProvider } from "../../../../components/embedding/interfaces";
import {
CloudEmbeddingProvider,
EmbeddingProvider,
} from "../../../../components/embedding/interfaces";
import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../../configuration/llm/constants";
import { Modal } from "@/components/Modal";
@ -14,12 +17,19 @@ export function ProviderCreationModal({
onCancel,
existingProvider,
isProxy,
isAzure,
updateCurrentModel,
}: {
updateCurrentModel: (
newModel: string,
provider_type: EmbeddingProvider
) => void;
selectedProvider: CloudEmbeddingProvider;
onConfirm: () => void;
onCancel: () => void;
existingProvider?: CloudEmbeddingProvider;
isProxy?: boolean;
isAzure?: boolean;
}) {
const useFileUpload = selectedProvider.provider_type == "Google";
@ -41,16 +51,24 @@ export function ProviderCreationModal({
const validationSchema = Yup.object({
provider_type: Yup.string().required("Provider type is required"),
api_key: isProxy
? Yup.string()
: useFileUpload
api_key:
isProxy || isAzure
? Yup.string()
: Yup.string().required("API Key is required"),
: useFileUpload
? Yup.string()
: Yup.string().required("API Key is required"),
model_name: isProxy
? Yup.string().required("Model name is required")
: Yup.string().nullable(),
api_url: isProxy
? Yup.string().required("API URL is required")
api_url:
isProxy || isAzure
? Yup.string().required("API URL is required")
: Yup.string(),
deployment_name: isAzure
? Yup.string().required("Deployment name is required")
: Yup.string(),
api_version: isAzure
? Yup.string().required("API Version is required")
: Yup.string(),
custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)),
});
@ -101,6 +119,8 @@ export function ProviderCreationModal({
api_key: values.api_key,
api_url: values.api_url,
model_name: values.model_name,
api_version: values.api_version,
deployment_name: values.deployment_name,
}),
}
);
@ -118,6 +138,8 @@ export function ProviderCreationModal({
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
...values,
api_version: values.api_version,
deployment_name: values.deployment_name,
provider_type: values.provider_type.toLowerCase().split(" ")[0],
custom_config: customConfig,
is_default_provider: false,
@ -125,6 +147,10 @@ export function ProviderCreationModal({
}),
});
if (isAzure) {
updateCurrentModel(values.model_name, EmbeddingProvider.AZURE);
}
if (!response.ok) {
const errorData = await response.json();
throw new Error(
@ -178,26 +204,45 @@ export function ProviderCreationModal({
href={selectedProvider.apiLink}
rel="noreferrer"
>
{isProxy ? "API URL" : "API KEY"}
{isProxy || isAzure ? "API URL" : "API KEY"}
</a>
</Text>
<div className="flex w-full flex-col gap-y-6">
{(isProxy || isAzure) && (
<TextFormField
name="api_url"
label="API URL"
placeholder="API URL"
type="text"
/>
)}
{isProxy && (
<>
<TextFormField
name="api_url"
label="API URL"
placeholder="API URL"
type="text"
/>
<TextFormField
name="model_name"
label="Model Name (for testing)"
placeholder="Model Name"
type="text"
/>
</>
<TextFormField
name="model_name"
label={`Model Name ${isProxy ? "(for testing)" : ""}`}
placeholder="Model Name"
type="text"
/>
)}
{isAzure && (
<TextFormField
name="deployment_name"
label="Deployment Name"
placeholder="Deployment Name"
type="text"
/>
)}
{isAzure && (
<TextFormField
name="api_version"
label="API Version"
placeholder="API Version"
type="text"
/>
)}
{useFileUpload ? (

View File

@ -10,42 +10,42 @@ import {
EmbeddingModelDescriptor,
EmbeddingProvider,
LITELLM_CLOUD_PROVIDER,
AZURE_CLOUD_PROVIDER,
} from "../../../../components/embedding/interfaces";
import { EmbeddingDetails } from "../EmbeddingModelSelectionForm";
import { FiExternalLink, FiInfo, FiTrash } from "react-icons/fi";
import { HoverPopup } from "@/components/HoverPopup";
import { Dispatch, SetStateAction, useEffect, useState } from "react";
import { LiteLLMModelForm } from "@/components/embedding/LiteLLMModelForm";
import { CustomEmbeddingModelForm } from "@/components/embedding/CustomEmbeddingModelForm";
import { deleteSearchSettings } from "./utils";
import { usePopup } from "@/components/admin/connectors/Popup";
import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal";
import { AdvancedSearchConfiguration } from "../interfaces";
export default function CloudEmbeddingPage({
currentModel,
embeddingProviderDetails,
embeddingModelDetails,
newEnabledProviders,
newUnenabledProviders,
setShowTentativeProvider,
setChangeCredentialsProvider,
setAlreadySelectedModel,
setShowTentativeModel,
setShowModelInQueue,
advancedEmbeddingDetails,
}: {
setShowModelInQueue: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
setShowTentativeModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
currentModel: EmbeddingModelDescriptor | CloudEmbeddingModel;
setAlreadySelectedModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
newUnenabledProviders: string[];
embeddingModelDetails?: CloudEmbeddingModel[];
embeddingProviderDetails?: EmbeddingDetails[];
newEnabledProviders: string[];
setShowTentativeProvider: React.Dispatch<
React.SetStateAction<CloudEmbeddingProvider | null>
>;
setChangeCredentialsProvider: React.Dispatch<
React.SetStateAction<CloudEmbeddingProvider | null>
>;
advancedEmbeddingDetails: AdvancedSearchConfiguration;
}) {
function hasProviderTypeinArray(
arr: Array<{ provider_type: string }>,
@ -60,27 +60,38 @@ export default function CloudEmbeddingPage({
(model) => ({
...model,
configured:
!newUnenabledProviders.includes(model.provider_type) &&
(newEnabledProviders.includes(model.provider_type) ||
(embeddingProviderDetails &&
hasProviderTypeinArray(
embeddingProviderDetails,
model.provider_type
))!),
embeddingProviderDetails &&
hasProviderTypeinArray(embeddingProviderDetails, model.provider_type),
})
);
const [liteLLMProvider, setLiteLLMProvider] = useState<
EmbeddingDetails | undefined
>(undefined);
const [azureProvider, setAzureProvider] = useState<
EmbeddingDetails | undefined
>(undefined);
useEffect(() => {
const foundProvider = embeddingProviderDetails?.find(
const liteLLMProvider = embeddingProviderDetails?.find(
(provider) =>
provider.provider_type === EmbeddingProvider.LITELLM.toLowerCase()
);
setLiteLLMProvider(foundProvider);
setLiteLLMProvider(liteLLMProvider);
const azureProvider = embeddingProviderDetails?.find(
(provider) =>
provider.provider_type === EmbeddingProvider.AZURE.toLowerCase()
);
setAzureProvider(azureProvider);
}, [embeddingProviderDetails]);
const isAzureConfigured = azureProvider !== undefined;
// Get details of the configured Azure provider
const azureProviderDetails = embeddingProviderDetails?.find(
(provider) => provider.provider_type.toLowerCase() === "azure"
);
return (
<div>
<Title className="mt-8">
@ -248,7 +259,8 @@ export default function CloudEmbeddingPage({
: ""
}`}
>
<LiteLLMModelForm
<CustomEmbeddingModelForm
embeddingType={EmbeddingProvider.LITELLM}
provider={liteLLMProvider}
currentValues={
currentModel.provider_type === EmbeddingProvider.LITELLM
@ -262,6 +274,126 @@ export default function CloudEmbeddingPage({
)}
</div>
</div>
<Text className="mt-6">
You can also use Azure OpenAI models for embeddings. Azure requires
separate configuration for each model.
</Text>
<div key={AZURE_CLOUD_PROVIDER.provider_type} className="mt-4 w-full">
<div className="flex items-center mb-2">
{AZURE_CLOUD_PROVIDER.icon({ size: 40 })}
<h2 className="ml-2 mt-2 text-xl font-bold">
{AZURE_CLOUD_PROVIDER.provider_type}{" "}
</h2>
<HoverPopup
mainContent={
<FiInfo className="ml-2 mt-2 cursor-pointer" size={18} />
}
popupContent={
<div className="text-sm text-text-800 w-52">
<div className="my-auto">
{AZURE_CLOUD_PROVIDER.description}
</div>
</div>
}
style="dark"
/>
</div>
</div>
<div className="w-full flex flex-col items-start">
{!isAzureConfigured ? (
<>
<button
onClick={() => setShowTentativeProvider(AZURE_CLOUD_PROVIDER)}
className="mb-2 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 text-sm cursor-pointer"
>
Configure Azure OpenAI
</button>
<div className="mt-2 w-full max-w-4xl">
<Card className="p-4 border border-gray-200 rounded-lg shadow-sm">
<Text className="text-base font-medium mb-2">
Configure Azure OpenAI for Embeddings
</Text>
<Text className="text-sm text-gray-600 mb-3">
Click &quot;Configure Azure OpenAI&quot; to set up Azure
OpenAI for embeddings.
</Text>
<div className="flex items-center text-sm text-gray-700">
<FiInfo className="text-gray-400 mr-2" size={16} />
<Text>
You&apos;ll need: API version, base URL, API key, model
name, and deployment name.
</Text>
</div>
</Card>
</div>
</>
) : (
<>
<div className="mb-6 w-full">
<Text className="text-lg font-semibold mb-3">
Current Azure Configuration
</Text>
{azureProviderDetails ? (
<Card className="bg-white shadow-sm border border-gray-200 rounded-lg">
<div className="p-4 space-y-3">
<div className="flex justify-between">
<span className="font-medium">API Version:</span>
<span>{azureProviderDetails.api_version}</span>
</div>
<div className="flex justify-between">
<span className="font-medium">Base URL:</span>
<span>{azureProviderDetails.api_url}</span>
</div>
<div className="flex justify-between">
<span className="font-medium">Deployment Name:</span>
<span>{azureProviderDetails.deployment_name}</span>
</div>
</div>
<button
onClick={() =>
setChangeCredentialsProvider(AZURE_CLOUD_PROVIDER)
}
className="mt-2 px-4 py-2 bg-red-500 text-white rounded hover:bg-red-600 text-sm"
>
Delete Current Azure Provider
</button>
</Card>
) : (
<Card className="bg-gray-50 border border-gray-200 rounded-lg">
<div className="p-4 text-gray-500 text-center">
No Azure provider has been configured yet.
</div>
</Card>
)}
</div>
<Card
className={`mt-2 w-full max-w-4xl ${
currentModel.provider_type === EmbeddingProvider.AZURE
? "border-2 border-blue-500"
: ""
}`}
>
{azureProvider && (
<CustomEmbeddingModelForm
embeddingType={EmbeddingProvider.AZURE}
provider={azureProvider}
currentValues={
currentModel.provider_type === EmbeddingProvider.AZURE
? (currentModel as CloudEmbeddingModel)
: null
}
setShowTentativeModel={setShowTentativeModel}
/>
)}
</Card>
</>
)}
</div>
</div>
</div>
);

View File

@ -152,11 +152,6 @@ export default function EmbeddingForm() {
}
}, [currentEmbeddingModel]);
useEffect(() => {
if (currentEmbeddingModel) {
setSelectedProvider(currentEmbeddingModel);
}
}, [currentEmbeddingModel]);
if (!selectedProvider) {
return <ThreeDotsLoader />;
}
@ -164,10 +159,18 @@ export default function EmbeddingForm() {
return <ErrorCallout errorTitle="Failed to fetch embedding model status" />;
}
// Stores the newly chosen model name in the advanced embedding details,
// carrying over every other field from the previous state untouched.
const updateCurrentModel = (newModel: string) => {
  setAdvancedEmbeddingDetails((previousDetails) => {
    return { ...previousDetails, model_name: newModel };
  });
};
const updateSearch = async () => {
const values: SavedSearchSettings = {
...rerankingDetails,
...advancedEmbeddingDetails,
...selectedProvider,
provider_type:
selectedProvider.provider_type?.toLowerCase() as EmbeddingProvider | null,
};
@ -311,11 +314,13 @@ export default function EmbeddingForm() {
</Text>
<Card>
<EmbeddingModelSelection
updateCurrentModel={updateCurrentModel}
setModelTab={setModelTab}
modelTab={modelTab}
selectedProvider={selectedProvider}
currentEmbeddingModel={currentEmbeddingModel}
updateSelectedProvider={updateSelectedProvider}
advancedEmbeddingDetails={advancedEmbeddingDetails}
/>
</Card>
<div className="mt-4 flex w-full justify-end">

View File

@ -27,7 +27,7 @@ export async function Layout({ children }: { children: React.ReactNode }) {
const authTypeMetadata = results[0] as AuthTypeMetadata | null;
const user = results[1] as User | null;
console.log("authTypeMetadata", authTypeMetadata);
const authDisabled = authTypeMetadata?.authType === "disabled";
const requiresVerification = authTypeMetadata?.requiresVerification;

View File

@ -1,4 +1,4 @@
import { CloudEmbeddingModel, CloudEmbeddingProvider } from "./interfaces";
import { CloudEmbeddingModel, EmbeddingProvider } from "./interfaces";
import { Formik, Form } from "formik";
import * as Yup from "yup";
import { TextFormField, BooleanFormField } from "../admin/connectors/Field";
@ -6,14 +6,16 @@ import { Dispatch, SetStateAction } from "react";
import { Button, Text } from "@tremor/react";
import { EmbeddingDetails } from "@/app/admin/embeddings/EmbeddingModelSelectionForm";
export function LiteLLMModelForm({
export function CustomEmbeddingModelForm({
setShowTentativeModel,
currentValues,
provider,
embeddingType,
}: {
setShowTentativeModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
currentValues: CloudEmbeddingModel | null;
provider: EmbeddingDetails;
embeddingType: EmbeddingProvider;
}) {
return (
<div>
@ -25,7 +27,7 @@ export function LiteLLMModelForm({
normalize: false,
query_prefix: "",
passage_prefix: "",
provider_type: "LiteLLM",
provider_type: embeddingType,
api_key: "",
enabled: true,
api_url: provider.api_url,
@ -55,18 +57,21 @@ export function LiteLLMModelForm({
max_tokens: Yup.number(),
})}
onSubmit={async (values) => {
console.log(values);
setShowTentativeModel(values as CloudEmbeddingModel);
}}
>
{({ isSubmitting }) => (
{({ isSubmitting, submitForm, errors }) => (
<Form>
<Text className="text-xl text-text-900 font-bold mb-4">
Add a new model to LiteLLM proxy at {provider.api_url}
Specify details for your{" "}
{embeddingType === EmbeddingProvider.AZURE ? "Azure" : "LiteLLM"}{" "}
Provider&apos;s model
</Text>
<TextFormField
name="model_name"
label="Model Name:"
subtext="The name of the LiteLLM model"
subtext={`The name of the ${embeddingType === EmbeddingProvider.AZURE ? "Azure" : "LiteLLM"} model`}
placeholder="e.g. 'all-MiniLM-L6-v2'"
autoCompleteDisabled={true}
/>
@ -103,10 +108,13 @@ export function LiteLLMModelForm({
<Button
type="submit"
onClick={() => console.log(errors)}
disabled={isSubmitting}
className="w-64 mx-auto"
>
Configure LiteLLM Model
Configure{" "}
{embeddingType === EmbeddingProvider.AZURE ? "Azure" : "LiteLLM"}{" "}
Model
</Button>
</Form>
)}

View File

@ -31,7 +31,7 @@ export default function EmbeddingSidebar() {
w-[250px]
`}
>
<div className="fixed h-full left-0 top-0 w-[250px]">
<div className="fixed h-full left-0 top-0 bg-background-100 w-[250px]">
<div className="ml-4 mr-3 flex flex gap-x-1 items-center mt-2 my-auto text-text-700 text-xl">
<div className="mr-1 my-auto h-6 w-6">
<Logo height={24} width={24} />

View File

@ -1,4 +1,5 @@
import {
AzureIcon,
CohereIcon,
GoogleIcon,
IconProps,
@ -16,6 +17,7 @@ export enum EmbeddingProvider {
VOYAGE = "Voyage",
GOOGLE = "Google",
LITELLM = "LiteLLM",
AZURE = "Azure",
}
export interface CloudEmbeddingProvider {
@ -49,6 +51,8 @@ export interface EmbeddingModelDescriptor {
description: string;
api_key: string | null;
api_url: string | null;
api_version?: string | null;
deployment_name?: string | null;
index_name: string | null;
}
@ -161,6 +165,20 @@ export const LITELLM_CLOUD_PROVIDER: CloudEmbeddingProvider = {
embedding_models: [], // No default embedding models
};
/**
 * Static descriptor for the Azure OpenAI embedding provider.
 *
 * Carries the provider type, the icon, a short description and the external
 * links (product page, resource-creation docs, pricing) for this provider.
 * No embedding models are pre-registered here.
 * NOTE(review): presumably the model is supplied by the user through the
 * custom embedding model form — confirm against the form component.
 */
export const AZURE_CLOUD_PROVIDER: CloudEmbeddingProvider = {
provider_type: EmbeddingProvider.AZURE,
// Product landing page for Azure OpenAI.
website:
"https://azure.microsoft.com/en-us/products/cognitive-services/openai/",
icon: AzureIcon,
description:
"Azure OpenAI is a cloud-based AI service that provides access to OpenAI models.",
// Documentation on creating an Azure OpenAI resource.
apiLink:
"https://docs.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource",
// Pricing page for Azure OpenAI.
costslink:
"https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai/",
embedding_models: [], // No default embedding models
};
export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
{
provider_type: EmbeddingProvider.COHERE,