Update embedding interface (#2205)

* squash

* simplify interface

* some updates to typing

* cloud provider type

* update typing to be even clearer

* push local commit (squash)

* cleaner interfaces

* another quick pass

* squash

* cleaner alembic

* cleaner

* remove trailing whitespace

* add sequence

* quick circle back to double check

* update

* update naming

* update naming
This commit is contained in:
pablodanswer
2024-08-22 20:52:02 -07:00
committed by GitHub
parent 7da6d33451
commit e89dc67e5d
20 changed files with 295 additions and 147 deletions

View File

@@ -3,7 +3,6 @@ import { Modal } from "@/components/Modal";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { ConnectorIndexingStatus } from "@/lib/types";
import { Button, Text, Title } from "@tremor/react";
import Link from "next/link";
import { useState } from "react";
import useSWR, { mutate } from "swr";
import { ReindexingProgressTable } from "../../../../components/embedding/ReindexingProgressTable";

View File

@@ -79,7 +79,7 @@ function Main() {
(provider) =>
provider.embedding_models.map((model) => ({
...model,
cloud_provider_id: provider.id,
provider_type: provider.provider_type,
model_name: model.model_name, // Ensure model_name is set for consistency
}))
);

View File

@@ -11,6 +11,7 @@ import {
INVALID_OLD_MODEL,
HostedEmbeddingModel,
EmbeddingModelDescriptor,
EmbeddingProvider,
} from "../../../components/embedding/interfaces";
import { Connector } from "@/lib/connectors/connectors";
import OpenEmbeddingPage from "./pages/OpenEmbeddingPage";
@@ -28,8 +29,7 @@ import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../configuration/llm/constants";
export interface EmbeddingDetails {
api_key: string;
custom_config: any;
default_model_id?: number;
name: string;
provider_type: EmbeddingProvider;
}
export function EmbeddingModelSelection({
@@ -122,28 +122,28 @@ export function EmbeddingModelSelection({
};
const clientsideAddProvider = (provider: CloudEmbeddingProvider) => {
const providerName = provider.name;
const providerType = provider.provider_type;
setNewEnabledProviders((newEnabledProviders) => [
...newEnabledProviders,
providerName,
providerType,
]);
setNewUnenabledProviders((newUnenabledProviders) =>
newUnenabledProviders.filter(
(givenProvidername) => givenProvidername != providerName
(givenProviderType) => givenProviderType != providerType
)
);
};
const clientsideRemoveProvider = (provider: CloudEmbeddingProvider) => {
const providerName = provider.name;
const providerType = provider.provider_type;
setNewEnabledProviders((newEnabledProviders) =>
newEnabledProviders.filter(
(givenProvidername) => givenProvidername != providerName
(givenProviderType) => givenProviderType != providerType
)
);
setNewUnenabledProviders((newUnenabledProviders) => [
...newUnenabledProviders,
providerName,
providerType,
]);
};
@@ -191,7 +191,7 @@ export function EmbeddingModelSelection({
)}
{changeCredentialsProvider && (
<ChangeCredentialsModal
useFileUpload={changeCredentialsProvider.name == "Google"}
useFileUpload={changeCredentialsProvider.provider_type == "Google"}
onDeleted={() => {
clientsideRemoveProvider(changeCredentialsProvider);
setChangeCredentialsProvider(null);

View File

@@ -74,7 +74,7 @@ export function ChangeCredentialsModal({
try {
const response = await fetch(
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.name}`,
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type}`,
{
method: "DELETE",
}
@@ -99,19 +99,12 @@ export function ChangeCredentialsModal({
const handleSubmit = async () => {
setTestError("");
try {
const body = JSON.stringify({
api_key: apiKey,
provider: provider.name.toLowerCase().split(" ")[0],
default_model_id: provider.name,
});
const testResponse = await fetch("/api/admin/embedding/test-embedding", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider: provider.name.toLowerCase().split(" ")[0],
provider_type: provider.provider_type.toLowerCase().split(" ")[0],
api_key: apiKey,
}),
});
@@ -125,7 +118,7 @@ export function ChangeCredentialsModal({
method: "PUT",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
name: provider.name,
provider_type: provider.provider_type.toLowerCase().split(" ")[0],
api_key: apiKey,
is_default_provider: false,
is_configured: true,
@@ -151,7 +144,7 @@ export function ChangeCredentialsModal({
<Modal
width="max-w-3xl"
icon={provider.icon}
title={`Modify your ${provider.name} key`}
title={`Modify your ${provider.provider_type} key`}
onOutsideClick={onCancel}
>
<div className="mb-4">

View File

@@ -15,13 +15,13 @@ export function DeleteCredentialsModal({
return (
<Modal
width="max-w-3xl"
title={`Nuke ${modelProvider.name} Credentials?`}
title={`Delete ${modelProvider.provider_type} Credentials?`}
onOutsideClick={onCancel}
>
<div className="mb-4">
<Text className="text-lg mb-2">
You&apos;re about to delete your {modelProvider.name} credentials. Are
you sure?
You&apos;re about to delete your {modelProvider.provider_type}{" "}
credentials. Are you sure?
</Text>
<Callout
title="Point of No Return"

View File

@@ -19,24 +19,24 @@ export function ProviderCreationModal({
onCancel: () => void;
existingProvider?: CloudEmbeddingProvider;
}) {
const useFileUpload = selectedProvider.name == "Google";
const useFileUpload = selectedProvider.provider_type == "Google";
const [isProcessing, setIsProcessing] = useState(false);
const [errorMsg, setErrorMsg] = useState<string>("");
const [fileName, setFileName] = useState<string>("");
const initialValues = {
name: existingProvider?.name || selectedProvider.name,
provider_type:
existingProvider?.provider_type || selectedProvider.provider_type,
api_key: existingProvider?.api_key || "",
custom_config: existingProvider?.custom_config
? Object.entries(existingProvider.custom_config)
: [],
default_model_name: "",
model_id: 0,
};
const validationSchema = Yup.object({
name: Yup.string().required("Name is required"),
provider_type: Yup.string().required("Provider type is required"),
api_key: useFileUpload
? Yup.string()
: Yup.string().required("API Key is required"),
@@ -76,7 +76,6 @@ export function ProviderCreationModal({
) => {
setIsProcessing(true);
setErrorMsg("");
try {
const customConfig = Object.fromEntries(values.custom_config);
@@ -86,7 +85,7 @@ export function ProviderCreationModal({
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider: values.name.toLowerCase().split(" ")[0],
provider_type: values.provider_type.toLowerCase().split(" ")[0],
api_key: values.api_key,
}),
}
@@ -105,6 +104,7 @@ export function ProviderCreationModal({
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
...values,
provider_type: values.provider_type.toLowerCase().split(" ")[0],
custom_config: customConfig,
is_default_provider: false,
is_configured: true,
@@ -134,7 +134,7 @@ export function ProviderCreationModal({
return (
<Modal
width="max-w-3xl"
title={`Configure ${selectedProvider.name}`}
title={`Configure ${selectedProvider.provider_type}`}
onOutsideClick={onCancel}
icon={selectedProvider.icon}
>

View File

@@ -39,12 +39,12 @@ export default function CloudEmbeddingPage({
React.SetStateAction<CloudEmbeddingProvider | null>
>;
}) {
function hasNameInArray(
arr: Array<{ name: string }>,
function hasProviderTypeinArray(
arr: Array<{ provider_type: string }>,
searchName: string
): boolean {
return arr.some(
(item) => item.name.toLowerCase() === searchName.toLowerCase()
(item) => item.provider_type.toLowerCase() === searchName.toLowerCase()
);
}
@@ -52,10 +52,13 @@ export default function CloudEmbeddingPage({
(model) => ({
...model,
configured:
!newUnenabledProviders.includes(model.name) &&
(newEnabledProviders.includes(model.name) ||
!newUnenabledProviders.includes(model.provider_type) &&
(newEnabledProviders.includes(model.provider_type) ||
(embeddingProviderDetails &&
hasNameInArray(embeddingProviderDetails, model.name))!),
hasProviderTypeinArray(
embeddingProviderDetails,
model.provider_type
))!),
})
);
@@ -71,11 +74,12 @@ export default function CloudEmbeddingPage({
<div className="gap-4 mt-2 pb-10 flex content-start flex-wrap">
{providers.map((provider) => (
<div key={provider.name} className="mt-4 w-full">
<div key={provider.provider_type} className="mt-4 w-full">
<div className="flex items-center mb-2">
{provider.icon({ size: 40 })}
<h2 className="ml-2 mt-2 text-xl font-bold">
{provider.name} {provider.name == "Cohere" && "(recommended)"}
{provider.provider_type}{" "}
{provider.provider_type == "Cohere" && "(recommended)"}
</h2>
<HoverPopup
mainContent={

View File

@@ -167,12 +167,14 @@ export default function EmbeddingForm() {
const onConfirm = async () => {
let newModel: EmbeddingModelDescriptor;
if ("cloud_provider_name" in selectedProvider) {
if ("provider_type" in selectedProvider) {
// This is a CloudEmbeddingModel
newModel = {
...selectedProvider,
model_name: selectedProvider.model_name,
cloud_provider_name: selectedProvider.cloud_provider_name,
provider_type: selectedProvider.provider_type
?.toLowerCase()
.split(" ")[0],
};
} else {
// This is an EmbeddingModelDescriptor
@@ -180,7 +182,7 @@ export default function EmbeddingForm() {
...selectedProvider,
model_name: selectedProvider.model_name!,
description: "",
cloud_provider_name: null,
provider_type: null,
};
}

View File

@@ -9,11 +9,15 @@ import {
VoyageIcon,
} from "@/components/icons/icons";
// Cloud Provider (not needed for hosted ones)
export enum EmbeddingProvider {
OPENAI = "OpenAI",
COHERE = "Cohere",
VOYAGE = "Voyage",
GOOGLE = "Google",
}
export interface CloudEmbeddingProvider {
id: number;
name: string;
provider_type: EmbeddingProvider;
api_key?: string;
custom_config?: Record<string, string>;
docsLink?: string;
@@ -37,12 +41,11 @@ export interface EmbeddingModelDescriptor {
normalize: boolean;
query_prefix: string;
passage_prefix: string;
cloud_provider_name?: string | null;
provider_type?: string | null;
description: string;
}
export interface CloudEmbeddingModel extends EmbeddingModelDescriptor {
cloud_provider_name: string | null;
pricePerMillion: number;
enabled?: boolean;
mtebScore: number;
@@ -124,8 +127,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
{
id: 1,
name: "Cohere",
provider_type: EmbeddingProvider.COHERE,
website: "https://cohere.ai",
icon: CohereIcon,
docsLink:
@@ -136,8 +138,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
costslink: "https://cohere.com/pricing",
embedding_models: [
{
provider_type: EmbeddingProvider.COHERE,
model_name: "embed-english-v3.0",
cloud_provider_name: "Cohere",
description:
"Cohere's English embedding model. Good performance for English-language tasks.",
pricePerMillion: 0.1,
@@ -151,7 +153,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
},
{
model_name: "embed-english-light-v3.0",
cloud_provider_name: "Cohere",
provider_type: EmbeddingProvider.COHERE,
description:
"Cohere's lightweight English embedding model. Faster and more efficient for simpler tasks.",
pricePerMillion: 0.1,
@@ -166,8 +168,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
],
},
{
id: 0,
name: "OpenAI",
provider_type: EmbeddingProvider.OPENAI,
website: "https://openai.com",
icon: OpenAIIcon,
description: "AI industry leader known for ChatGPT and DALL-E",
@@ -177,8 +178,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
costslink: "https://openai.com/pricing",
embedding_models: [
{
provider_type: EmbeddingProvider.OPENAI,
model_name: "text-embedding-3-large",
cloud_provider_name: "OpenAI",
description:
"OpenAI's large embedding model. Best performance, but more expensive.",
pricePerMillion: 0.13,
@@ -191,8 +192,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
enabled: false,
},
{
provider_type: EmbeddingProvider.OPENAI,
model_name: "text-embedding-3-small",
cloud_provider_name: "OpenAI",
model_dim: 1536,
normalize: false,
query_prefix: "",
@@ -208,8 +209,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
},
{
id: 2,
name: "Google",
provider_type: EmbeddingProvider.GOOGLE,
website: "https://ai.google",
icon: GoogleIcon,
docsLink:
@@ -220,7 +220,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
costslink: "https://cloud.google.com/vertex-ai/pricing",
embedding_models: [
{
cloud_provider_name: "Google",
provider_type: EmbeddingProvider.GOOGLE,
model_name: "text-embedding-004",
description: "Google's most recent text embedding model.",
pricePerMillion: 0.025,
@@ -233,7 +233,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
passage_prefix: "",
},
{
cloud_provider_name: "Google",
provider_type: EmbeddingProvider.GOOGLE,
model_name: "textembedding-gecko@003",
description: "Google's Gecko embedding model. Powerful and efficient.",
pricePerMillion: 0.025,
@@ -248,8 +248,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
],
},
{
id: 3,
name: "Voyage",
provider_type: EmbeddingProvider.VOYAGE,
website: "https://www.voyageai.com",
icon: VoyageIcon,
description: "Advanced NLP research startup born from Stanford AI Labs",
@@ -259,7 +258,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
costslink: "https://www.voyageai.com/pricing",
embedding_models: [
{
cloud_provider_name: "Voyage",
provider_type: EmbeddingProvider.VOYAGE,
model_name: "voyage-large-2-instruct",
description:
"Voyage's large embedding model. High performance with instruction fine-tuning.",
@@ -273,7 +272,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
passage_prefix: "",
},
{
cloud_provider_name: "Voyage",
provider_type: EmbeddingProvider.VOYAGE,
model_name: "voyage-light-2-instruct",
description:
"Voyage's lightweight embedding model. Good balance of performance and efficiency.",