mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 20:24:32 +02:00
Update embedding interface (#2205)
* squash
* simplify interface
* some updates to typing
* cloud provider type
* update typing to be even clearer
* push local commit (squash)
* cleaner interfaces
* another quick pass
* squash
* cleaner alembic
* cleaner
* remove trailing whitespace
* add sequence
* quick circle back to double check
* update
* update naming
* update naming
This commit is contained in:
@@ -3,7 +3,6 @@ import { Modal } from "@/components/Modal";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { ConnectorIndexingStatus } from "@/lib/types";
|
||||
import { Button, Text, Title } from "@tremor/react";
|
||||
import Link from "next/link";
|
||||
import { useState } from "react";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import { ReindexingProgressTable } from "../../../../components/embedding/ReindexingProgressTable";
|
||||
|
@@ -79,7 +79,7 @@ function Main() {
|
||||
(provider) =>
|
||||
provider.embedding_models.map((model) => ({
|
||||
...model,
|
||||
cloud_provider_id: provider.id,
|
||||
provider_type: provider.provider_type,
|
||||
model_name: model.model_name, // Ensure model_name is set for consistency
|
||||
}))
|
||||
);
|
||||
|
@@ -11,6 +11,7 @@ import {
|
||||
INVALID_OLD_MODEL,
|
||||
HostedEmbeddingModel,
|
||||
EmbeddingModelDescriptor,
|
||||
EmbeddingProvider,
|
||||
} from "../../../components/embedding/interfaces";
|
||||
import { Connector } from "@/lib/connectors/connectors";
|
||||
import OpenEmbeddingPage from "./pages/OpenEmbeddingPage";
|
||||
@@ -28,8 +29,7 @@ import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../configuration/llm/constants";
|
||||
export interface EmbeddingDetails {
|
||||
api_key: string;
|
||||
custom_config: any;
|
||||
default_model_id?: number;
|
||||
name: string;
|
||||
provider_type: EmbeddingProvider;
|
||||
}
|
||||
|
||||
export function EmbeddingModelSelection({
|
||||
@@ -122,28 +122,28 @@ export function EmbeddingModelSelection({
|
||||
};
|
||||
|
||||
const clientsideAddProvider = (provider: CloudEmbeddingProvider) => {
|
||||
const providerName = provider.name;
|
||||
const providerType = provider.provider_type;
|
||||
setNewEnabledProviders((newEnabledProviders) => [
|
||||
...newEnabledProviders,
|
||||
providerName,
|
||||
providerType,
|
||||
]);
|
||||
setNewUnenabledProviders((newUnenabledProviders) =>
|
||||
newUnenabledProviders.filter(
|
||||
(givenProvidername) => givenProvidername != providerName
|
||||
(givenProviderType) => givenProviderType != providerType
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
const clientsideRemoveProvider = (provider: CloudEmbeddingProvider) => {
|
||||
const providerName = provider.name;
|
||||
const providerType = provider.provider_type;
|
||||
setNewEnabledProviders((newEnabledProviders) =>
|
||||
newEnabledProviders.filter(
|
||||
(givenProvidername) => givenProvidername != providerName
|
||||
(givenProviderType) => givenProviderType != providerType
|
||||
)
|
||||
);
|
||||
setNewUnenabledProviders((newUnenabledProviders) => [
|
||||
...newUnenabledProviders,
|
||||
providerName,
|
||||
providerType,
|
||||
]);
|
||||
};
|
||||
|
||||
@@ -191,7 +191,7 @@ export function EmbeddingModelSelection({
|
||||
)}
|
||||
{changeCredentialsProvider && (
|
||||
<ChangeCredentialsModal
|
||||
useFileUpload={changeCredentialsProvider.name == "Google"}
|
||||
useFileUpload={changeCredentialsProvider.provider_type == "Google"}
|
||||
onDeleted={() => {
|
||||
clientsideRemoveProvider(changeCredentialsProvider);
|
||||
setChangeCredentialsProvider(null);
|
||||
|
@@ -74,7 +74,7 @@ export function ChangeCredentialsModal({
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.name}`,
|
||||
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type}`,
|
||||
{
|
||||
method: "DELETE",
|
||||
}
|
||||
@@ -99,19 +99,12 @@ export function ChangeCredentialsModal({
|
||||
|
||||
const handleSubmit = async () => {
|
||||
setTestError("");
|
||||
|
||||
try {
|
||||
const body = JSON.stringify({
|
||||
api_key: apiKey,
|
||||
provider: provider.name.toLowerCase().split(" ")[0],
|
||||
default_model_id: provider.name,
|
||||
});
|
||||
|
||||
const testResponse = await fetch("/api/admin/embedding/test-embedding", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider: provider.name.toLowerCase().split(" ")[0],
|
||||
provider_type: provider.provider_type.toLowerCase().split(" ")[0],
|
||||
api_key: apiKey,
|
||||
}),
|
||||
});
|
||||
@@ -125,7 +118,7 @@ export function ChangeCredentialsModal({
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
name: provider.name,
|
||||
provider_type: provider.provider_type.toLowerCase().split(" ")[0],
|
||||
api_key: apiKey,
|
||||
is_default_provider: false,
|
||||
is_configured: true,
|
||||
@@ -151,7 +144,7 @@ export function ChangeCredentialsModal({
|
||||
<Modal
|
||||
width="max-w-3xl"
|
||||
icon={provider.icon}
|
||||
title={`Modify your ${provider.name} key`}
|
||||
title={`Modify your ${provider.provider_type} key`}
|
||||
onOutsideClick={onCancel}
|
||||
>
|
||||
<div className="mb-4">
|
||||
|
@@ -15,13 +15,13 @@ export function DeleteCredentialsModal({
|
||||
return (
|
||||
<Modal
|
||||
width="max-w-3xl"
|
||||
title={`Nuke ${modelProvider.name} Credentials?`}
|
||||
title={`Delete ${modelProvider.provider_type} Credentials?`}
|
||||
onOutsideClick={onCancel}
|
||||
>
|
||||
<div className="mb-4">
|
||||
<Text className="text-lg mb-2">
|
||||
You're about to delete your {modelProvider.name} credentials. Are
|
||||
you sure?
|
||||
You're about to delete your {modelProvider.provider_type}{" "}
|
||||
credentials. Are you sure?
|
||||
</Text>
|
||||
<Callout
|
||||
title="Point of No Return"
|
||||
|
@@ -19,24 +19,24 @@ export function ProviderCreationModal({
|
||||
onCancel: () => void;
|
||||
existingProvider?: CloudEmbeddingProvider;
|
||||
}) {
|
||||
const useFileUpload = selectedProvider.name == "Google";
|
||||
const useFileUpload = selectedProvider.provider_type == "Google";
|
||||
|
||||
const [isProcessing, setIsProcessing] = useState(false);
|
||||
const [errorMsg, setErrorMsg] = useState<string>("");
|
||||
const [fileName, setFileName] = useState<string>("");
|
||||
|
||||
const initialValues = {
|
||||
name: existingProvider?.name || selectedProvider.name,
|
||||
provider_type:
|
||||
existingProvider?.provider_type || selectedProvider.provider_type,
|
||||
api_key: existingProvider?.api_key || "",
|
||||
custom_config: existingProvider?.custom_config
|
||||
? Object.entries(existingProvider.custom_config)
|
||||
: [],
|
||||
default_model_name: "",
|
||||
model_id: 0,
|
||||
};
|
||||
|
||||
const validationSchema = Yup.object({
|
||||
name: Yup.string().required("Name is required"),
|
||||
provider_type: Yup.string().required("Provider type is required"),
|
||||
api_key: useFileUpload
|
||||
? Yup.string()
|
||||
: Yup.string().required("API Key is required"),
|
||||
@@ -76,7 +76,6 @@ export function ProviderCreationModal({
|
||||
) => {
|
||||
setIsProcessing(true);
|
||||
setErrorMsg("");
|
||||
|
||||
try {
|
||||
const customConfig = Object.fromEntries(values.custom_config);
|
||||
|
||||
@@ -86,7 +85,7 @@ export function ProviderCreationModal({
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider: values.name.toLowerCase().split(" ")[0],
|
||||
provider_type: values.provider_type.toLowerCase().split(" ")[0],
|
||||
api_key: values.api_key,
|
||||
}),
|
||||
}
|
||||
@@ -105,6 +104,7 @@ export function ProviderCreationModal({
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
...values,
|
||||
provider_type: values.provider_type.toLowerCase().split(" ")[0],
|
||||
custom_config: customConfig,
|
||||
is_default_provider: false,
|
||||
is_configured: true,
|
||||
@@ -134,7 +134,7 @@ export function ProviderCreationModal({
|
||||
return (
|
||||
<Modal
|
||||
width="max-w-3xl"
|
||||
title={`Configure ${selectedProvider.name}`}
|
||||
title={`Configure ${selectedProvider.provider_type}`}
|
||||
onOutsideClick={onCancel}
|
||||
icon={selectedProvider.icon}
|
||||
>
|
||||
|
@@ -39,12 +39,12 @@ export default function CloudEmbeddingPage({
|
||||
React.SetStateAction<CloudEmbeddingProvider | null>
|
||||
>;
|
||||
}) {
|
||||
function hasNameInArray(
|
||||
arr: Array<{ name: string }>,
|
||||
function hasProviderTypeinArray(
|
||||
arr: Array<{ provider_type: string }>,
|
||||
searchName: string
|
||||
): boolean {
|
||||
return arr.some(
|
||||
(item) => item.name.toLowerCase() === searchName.toLowerCase()
|
||||
(item) => item.provider_type.toLowerCase() === searchName.toLowerCase()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -52,10 +52,13 @@ export default function CloudEmbeddingPage({
|
||||
(model) => ({
|
||||
...model,
|
||||
configured:
|
||||
!newUnenabledProviders.includes(model.name) &&
|
||||
(newEnabledProviders.includes(model.name) ||
|
||||
!newUnenabledProviders.includes(model.provider_type) &&
|
||||
(newEnabledProviders.includes(model.provider_type) ||
|
||||
(embeddingProviderDetails &&
|
||||
hasNameInArray(embeddingProviderDetails, model.name))!),
|
||||
hasProviderTypeinArray(
|
||||
embeddingProviderDetails,
|
||||
model.provider_type
|
||||
))!),
|
||||
})
|
||||
);
|
||||
|
||||
@@ -71,11 +74,12 @@ export default function CloudEmbeddingPage({
|
||||
|
||||
<div className="gap-4 mt-2 pb-10 flex content-start flex-wrap">
|
||||
{providers.map((provider) => (
|
||||
<div key={provider.name} className="mt-4 w-full">
|
||||
<div key={provider.provider_type} className="mt-4 w-full">
|
||||
<div className="flex items-center mb-2">
|
||||
{provider.icon({ size: 40 })}
|
||||
<h2 className="ml-2 mt-2 text-xl font-bold">
|
||||
{provider.name} {provider.name == "Cohere" && "(recommended)"}
|
||||
{provider.provider_type}{" "}
|
||||
{provider.provider_type == "Cohere" && "(recommended)"}
|
||||
</h2>
|
||||
<HoverPopup
|
||||
mainContent={
|
||||
|
@@ -167,12 +167,14 @@ export default function EmbeddingForm() {
|
||||
const onConfirm = async () => {
|
||||
let newModel: EmbeddingModelDescriptor;
|
||||
|
||||
if ("cloud_provider_name" in selectedProvider) {
|
||||
if ("provider_type" in selectedProvider) {
|
||||
// This is a CloudEmbeddingModel
|
||||
newModel = {
|
||||
...selectedProvider,
|
||||
model_name: selectedProvider.model_name,
|
||||
cloud_provider_name: selectedProvider.cloud_provider_name,
|
||||
provider_type: selectedProvider.provider_type
|
||||
?.toLowerCase()
|
||||
.split(" ")[0],
|
||||
};
|
||||
} else {
|
||||
// This is an EmbeddingModelDescriptor
|
||||
@@ -180,7 +182,7 @@ export default function EmbeddingForm() {
|
||||
...selectedProvider,
|
||||
model_name: selectedProvider.model_name!,
|
||||
description: "",
|
||||
cloud_provider_name: null,
|
||||
provider_type: null,
|
||||
};
|
||||
}
|
||||
|
||||
|
@@ -9,11 +9,15 @@ import {
|
||||
VoyageIcon,
|
||||
} from "@/components/icons/icons";
|
||||
|
||||
// Cloud Provider (not needed for hosted ones)
|
||||
export enum EmbeddingProvider {
|
||||
OPENAI = "OpenAI",
|
||||
COHERE = "Cohere",
|
||||
VOYAGE = "Voyage",
|
||||
GOOGLE = "Google",
|
||||
}
|
||||
|
||||
export interface CloudEmbeddingProvider {
|
||||
id: number;
|
||||
name: string;
|
||||
provider_type: EmbeddingProvider;
|
||||
api_key?: string;
|
||||
custom_config?: Record<string, string>;
|
||||
docsLink?: string;
|
||||
@@ -37,12 +41,11 @@ export interface EmbeddingModelDescriptor {
|
||||
normalize: boolean;
|
||||
query_prefix: string;
|
||||
passage_prefix: string;
|
||||
cloud_provider_name?: string | null;
|
||||
provider_type?: string | null;
|
||||
description: string;
|
||||
}
|
||||
|
||||
export interface CloudEmbeddingModel extends EmbeddingModelDescriptor {
|
||||
cloud_provider_name: string | null;
|
||||
pricePerMillion: number;
|
||||
enabled?: boolean;
|
||||
mtebScore: number;
|
||||
@@ -124,8 +127,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
|
||||
|
||||
export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
{
|
||||
id: 1,
|
||||
name: "Cohere",
|
||||
provider_type: EmbeddingProvider.COHERE,
|
||||
website: "https://cohere.ai",
|
||||
icon: CohereIcon,
|
||||
docsLink:
|
||||
@@ -136,8 +138,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
costslink: "https://cohere.com/pricing",
|
||||
embedding_models: [
|
||||
{
|
||||
provider_type: EmbeddingProvider.COHERE,
|
||||
model_name: "embed-english-v3.0",
|
||||
cloud_provider_name: "Cohere",
|
||||
description:
|
||||
"Cohere's English embedding model. Good performance for English-language tasks.",
|
||||
pricePerMillion: 0.1,
|
||||
@@ -151,7 +153,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
},
|
||||
{
|
||||
model_name: "embed-english-light-v3.0",
|
||||
cloud_provider_name: "Cohere",
|
||||
provider_type: EmbeddingProvider.COHERE,
|
||||
description:
|
||||
"Cohere's lightweight English embedding model. Faster and more efficient for simpler tasks.",
|
||||
pricePerMillion: 0.1,
|
||||
@@ -166,8 +168,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
],
|
||||
},
|
||||
{
|
||||
id: 0,
|
||||
name: "OpenAI",
|
||||
provider_type: EmbeddingProvider.OPENAI,
|
||||
website: "https://openai.com",
|
||||
icon: OpenAIIcon,
|
||||
description: "AI industry leader known for ChatGPT and DALL-E",
|
||||
@@ -177,8 +178,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
costslink: "https://openai.com/pricing",
|
||||
embedding_models: [
|
||||
{
|
||||
provider_type: EmbeddingProvider.OPENAI,
|
||||
model_name: "text-embedding-3-large",
|
||||
cloud_provider_name: "OpenAI",
|
||||
description:
|
||||
"OpenAI's large embedding model. Best performance, but more expensive.",
|
||||
pricePerMillion: 0.13,
|
||||
@@ -191,8 +192,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
enabled: false,
|
||||
},
|
||||
{
|
||||
provider_type: EmbeddingProvider.OPENAI,
|
||||
model_name: "text-embedding-3-small",
|
||||
cloud_provider_name: "OpenAI",
|
||||
model_dim: 1536,
|
||||
normalize: false,
|
||||
query_prefix: "",
|
||||
@@ -208,8 +209,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
},
|
||||
|
||||
{
|
||||
id: 2,
|
||||
name: "Google",
|
||||
provider_type: EmbeddingProvider.GOOGLE,
|
||||
website: "https://ai.google",
|
||||
icon: GoogleIcon,
|
||||
docsLink:
|
||||
@@ -220,7 +220,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
costslink: "https://cloud.google.com/vertex-ai/pricing",
|
||||
embedding_models: [
|
||||
{
|
||||
cloud_provider_name: "Google",
|
||||
provider_type: EmbeddingProvider.GOOGLE,
|
||||
model_name: "text-embedding-004",
|
||||
description: "Google's most recent text embedding model.",
|
||||
pricePerMillion: 0.025,
|
||||
@@ -233,7 +233,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
passage_prefix: "",
|
||||
},
|
||||
{
|
||||
cloud_provider_name: "Google",
|
||||
provider_type: EmbeddingProvider.GOOGLE,
|
||||
model_name: "textembedding-gecko@003",
|
||||
description: "Google's Gecko embedding model. Powerful and efficient.",
|
||||
pricePerMillion: 0.025,
|
||||
@@ -248,8 +248,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
],
|
||||
},
|
||||
{
|
||||
id: 3,
|
||||
name: "Voyage",
|
||||
provider_type: EmbeddingProvider.VOYAGE,
|
||||
website: "https://www.voyageai.com",
|
||||
icon: VoyageIcon,
|
||||
description: "Advanced NLP research startup born from Stanford AI Labs",
|
||||
@@ -259,7 +258,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
costslink: "https://www.voyageai.com/pricing",
|
||||
embedding_models: [
|
||||
{
|
||||
cloud_provider_name: "Voyage",
|
||||
provider_type: EmbeddingProvider.VOYAGE,
|
||||
model_name: "voyage-large-2-instruct",
|
||||
description:
|
||||
"Voyage's large embedding model. High performance with instruction fine-tuning.",
|
||||
@@ -273,7 +272,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
passage_prefix: "",
|
||||
},
|
||||
{
|
||||
cloud_provider_name: "Voyage",
|
||||
provider_type: EmbeddingProvider.VOYAGE,
|
||||
model_name: "voyage-light-2-instruct",
|
||||
description:
|
||||
"Voyage's lightweight embedding model. Good balance of performance and efficiency.",
|
||||
|
Reference in New Issue
Block a user