* k

* update enum imports

* add functional types + model swaps

* remove a log

* remove kv

* fully functional + robustified for kv swap

* validated with hosted + cloud

* ensure not updating current search settings when reindexing

* add instance check

* revert back to updating search settings (will need a slight refactor for endpoint)

* protect advanced config override1

* run pretty

* fix typing

* update typing

* remove unnecessary function

* update model name

* clearer interface names

* validated foreign key constaint

* proper migration

* squash

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
This commit is contained in:
pablodanswer
2024-08-26 21:26:51 -07:00
committed by GitHub
parent 5f12b7ad58
commit 97ba71e1b3
54 changed files with 1078 additions and 673 deletions

View File

@@ -40,7 +40,7 @@ export default function UpgradingPage({
method: "POST",
});
if (response.ok) {
mutate("/api/search-settings/get-secondary-embedding-model");
mutate("/api/search-settings/get-secondary-search-settings");
} else {
alert(
`Failed to cancel embedding model update - ${await response.text()}`

View File

@@ -36,14 +36,14 @@ function Main() {
isLoading: isLoadingCurrentModel,
error: currentEmeddingModelError,
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
"/api/search-settings/get-current-embedding-model",
"/api/search-settings/get-current-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
const { data: searchSettings, isLoading: isLoadingSearchSettings } =
useSWR<SavedSearchSettings | null>(
"/api/search-settings/get-search-settings",
"/api/search-settings/get-current-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
@@ -53,7 +53,7 @@ function Main() {
isLoading: isLoadingFutureModel,
error: futureEmeddingModelError,
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
"/api/search-settings/get-secondary-embedding-model",
"/api/search-settings/get-secondary-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);

View File

@@ -93,10 +93,10 @@ export function EmbeddingModelSelection({
const onConfirmSelection = async (model: EmbeddingModelDescriptor) => {
const response = await fetch(
"/api/search-settings/set-new-embedding-model",
"/api/search-settings/set-new-search-settings",
{
method: "POST",
body: JSON.stringify(model),
body: JSON.stringify({ ...model, index_name: null }),
headers: {
"Content-Type": "application/json",
},
@@ -104,7 +104,7 @@ export function EmbeddingModelSelection({
);
if (response.ok) {
setShowTentativeModel(null);
mutate("/api/search-settings/get-secondary-embedding-model");
mutate("/api/search-settings/get-secondary-search-settings");
if (!connectors || !connectors.length) {
setShowAddConnectorPopup(true);
}

View File

@@ -114,32 +114,36 @@ const RerankingDetailsForm = forwardRef<
)
).map((card) => {
const isSelected =
values.provider_type === card.provider &&
values.rerank_provider_type === card.rerank_provider_type &&
values.rerank_model_name === card.modelName;
return (
<div
key={`${card.provider}-${card.modelName}`}
key={`${card.rerank_provider_type}-${card.modelName}`}
className={`p-4 border rounded-lg cursor-pointer transition-all duration-200 ${
isSelected
? "border-blue-500 bg-blue-50 shadow-md"
: "border-gray-200 hover:border-blue-300 hover:shadow-sm"
}`}
onClick={() => {
if (card.provider) {
if (card.rerank_provider_type) {
setIsApiKeyModalOpen(true);
}
setRerankingDetails({
...values,
provider_type: card.provider!,
rerank_provider_type: card.rerank_provider_type!,
rerank_model_name: card.modelName,
});
setFieldValue("provider_type", card.provider);
setFieldValue(
"rerank_provider_type",
card.rerank_provider_type
);
setFieldValue("rerank_model_name", card.modelName);
}}
>
<div className="flex items-center justify-between mb-3">
<div className="flex items-center">
{card.provider === RerankerProvider.COHERE ? (
{card.rerank_provider_type ===
RerankerProvider.COHERE ? (
<CohereIcon size={24} className="mr-2" />
) : (
<MixedBreadIcon size={24} className="mr-2" />

View File

@@ -1,6 +1,9 @@
import { EmbeddingProvider } from "@/components/embedding/interfaces";
import { NonNullChain } from "typescript";
export interface RerankingDetails {
rerank_model_name: string | null;
provider_type: RerankerProvider | null;
rerank_provider_type: RerankerProvider | null;
api_key: string | null;
num_rerank: number;
}
@@ -8,20 +11,33 @@ export interface RerankingDetails {
export enum RerankerProvider {
COHERE = "cohere",
}
export interface AdvancedDetails {
multilingual_expansion: string[];
export interface AdvancedSearchConfiguration {
model_name: string;
model_dim: number;
normalize: boolean;
query_prefix: string;
passage_prefix: string;
index_name: string | null;
multipass_indexing: boolean;
multilingual_expansion: string[];
disable_rerank_for_streaming: boolean;
}
export interface SavedSearchSettings extends RerankingDetails {
multilingual_expansion: string[];
model_name: string;
model_dim: number;
normalize: boolean;
query_prefix: string;
passage_prefix: string;
index_name: string | null;
multipass_indexing: boolean;
multilingual_expansion: string[];
disable_rerank_for_streaming: boolean;
provider_type: EmbeddingProvider | null;
}
export interface RerankingModel {
provider?: RerankerProvider;
rerank_provider_type: RerankerProvider | null;
modelName: string;
displayName: string;
description: string;
@@ -31,6 +47,7 @@ export interface RerankingModel {
export const rerankingModels: RerankingModel[] = [
{
rerank_provider_type: null,
cloud: false,
modelName: "mixedbread-ai/mxbai-rerank-xsmall-v1",
displayName: "MixedBread XSmall",
@@ -38,6 +55,7 @@ export const rerankingModels: RerankingModel[] = [
link: "https://huggingface.co/mixedbread-ai/mxbai-rerank-xsmall-v1",
},
{
rerank_provider_type: null,
cloud: false,
modelName: "mixedbread-ai/mxbai-rerank-base-v1",
displayName: "MixedBread Base",
@@ -45,6 +63,7 @@ export const rerankingModels: RerankingModel[] = [
link: "https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1",
},
{
rerank_provider_type: null,
cloud: false,
modelName: "mixedbread-ai/mxbai-rerank-large-v1",
displayName: "MixedBread Large",
@@ -53,7 +72,7 @@ export const rerankingModels: RerankingModel[] = [
},
{
cloud: true,
provider: RerankerProvider.COHERE,
rerank_provider_type: RerankerProvider.COHERE,
modelName: "rerank-english-v3.0",
displayName: "Cohere English",
description: "High-performance English-focused reranking model.",
@@ -61,7 +80,7 @@ export const rerankingModels: RerankingModel[] = [
},
{
cloud: true,
provider: RerankerProvider.COHERE,
rerank_provider_type: RerankerProvider.COHERE,
modelName: "rerank-multilingual-v3.0",
displayName: "Cohere Multilingual",
description: "Powerful multilingual reranking model.",

View File

@@ -5,14 +5,14 @@ import { EditingValue } from "@/components/credentials/EditingValue";
import CredentialSubText from "@/components/credentials/CredentialFields";
import { TrashIcon } from "@/components/icons/icons";
import { FaPlus } from "react-icons/fa";
import { AdvancedDetails, RerankingDetails } from "../interfaces";
import { AdvancedSearchConfiguration, RerankingDetails } from "../interfaces";
interface AdvancedEmbeddingFormPageProps {
updateAdvancedEmbeddingDetails: (
key: keyof AdvancedDetails,
key: keyof AdvancedSearchConfiguration,
value: any
) => void;
advancedEmbeddingDetails: AdvancedDetails;
advancedEmbeddingDetails: AdvancedSearchConfiguration;
setRerankingDetails: Dispatch<SetStateAction<RerankingDetails>>;
numRerank: number;
}

View File

@@ -8,7 +8,7 @@ import { Button, Card, Text } from "@tremor/react";
import { ArrowLeft, ArrowRight, WarningCircle } from "@phosphor-icons/react";
import {
CloudEmbeddingModel,
EmbeddingModelDescriptor,
EmbeddingProvider,
HostedEmbeddingModel,
} from "../../../../components/embedding/interfaces";
import { errorHandlingFetcher } from "@/lib/fetcher";
@@ -17,7 +17,8 @@ import useSWR, { mutate } from "swr";
import { ThreeDotsLoader } from "@/components/Loading";
import AdvancedEmbeddingFormPage from "./AdvancedEmbeddingFormPage";
import {
AdvancedDetails,
AdvancedSearchConfiguration,
RerankerProvider,
RerankingDetails,
SavedSearchSettings,
} from "../interfaces";
@@ -30,21 +31,27 @@ export default function EmbeddingForm() {
const { popup, setPopup } = usePopup();
const [advancedEmbeddingDetails, setAdvancedEmbeddingDetails] =
useState<AdvancedDetails>({
disable_rerank_for_streaming: false,
multilingual_expansion: [],
useState<AdvancedSearchConfiguration>({
model_name: "",
model_dim: 0,
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
multipass_indexing: true,
multilingual_expansion: [],
disable_rerank_for_streaming: false,
});
const [rerankingDetails, setRerankingDetails] = useState<RerankingDetails>({
api_key: "",
num_rerank: 0,
provider_type: null,
rerank_provider_type: null,
rerank_model_name: "",
});
const updateAdvancedEmbeddingDetails = (
key: keyof AdvancedDetails,
key: keyof AdvancedSearchConfiguration,
value: any
) => {
setAdvancedEmbeddingDetails((values) => ({ ...values, [key]: value }));
@@ -52,7 +59,7 @@ export default function EmbeddingForm() {
async function updateSearchSettings(searchSettings: SavedSearchSettings) {
const response = await fetch(
"/api/search-settings/update-search-settings",
"/api/search-settings/update-inference-settings",
{
method: "POST",
headers: {
@@ -80,7 +87,7 @@ export default function EmbeddingForm() {
isLoading: isLoadingCurrentModel,
error: currentEmbeddingModelError,
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
"/api/search-settings/get-current-embedding-model",
"/api/search-settings/get-current-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
@@ -91,7 +98,7 @@ export default function EmbeddingForm() {
const { data: searchSettings, isLoading: isLoadingSearchSettings } =
useSWR<SavedSearchSettings | null>(
"/api/search-settings/get-search-settings",
"/api/search-settings/get-current-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
@@ -99,31 +106,37 @@ export default function EmbeddingForm() {
useEffect(() => {
if (searchSettings) {
setAdvancedEmbeddingDetails({
model_name: searchSettings.model_name,
model_dim: searchSettings.model_dim,
normalize: searchSettings.normalize,
query_prefix: searchSettings.query_prefix,
passage_prefix: searchSettings.passage_prefix,
index_name: searchSettings.index_name,
multipass_indexing: searchSettings.multipass_indexing,
multilingual_expansion: searchSettings.multilingual_expansion,
disable_rerank_for_streaming:
searchSettings.disable_rerank_for_streaming,
multilingual_expansion: searchSettings.multilingual_expansion,
multipass_indexing: searchSettings.multipass_indexing,
});
setRerankingDetails({
api_key: searchSettings.api_key,
num_rerank: searchSettings.num_rerank,
provider_type: searchSettings.provider_type,
rerank_provider_type: searchSettings.rerank_provider_type,
rerank_model_name: searchSettings.rerank_model_name,
});
}
}, [searchSettings]);
const originalRerankingDetails = searchSettings
const originalRerankingDetails: RerankingDetails = searchSettings
? {
api_key: searchSettings.api_key,
num_rerank: searchSettings.num_rerank,
provider_type: searchSettings.provider_type,
rerank_provider_type: searchSettings.rerank_provider_type,
rerank_model_name: searchSettings.rerank_model_name,
}
: {
api_key: "",
num_rerank: 0,
provider_type: null,
rerank_provider_type: null,
rerank_model_name: "",
};
@@ -149,14 +162,17 @@ export default function EmbeddingForm() {
let values: SavedSearchSettings = {
...rerankingDetails,
...advancedEmbeddingDetails,
provider_type:
selectedProvider.provider_type?.toLowerCase() as EmbeddingProvider | null,
};
const response = await updateSearchSettings(values);
if (response.ok) {
setPopup({
message: "Updated search settings succesffuly",
type: "success",
});
mutate("/api/search-settings/get-search-settings");
mutate("/api/search-settings/get-current-search-settings");
return true;
} else {
setPopup({ message: "Failed to update search settings", type: "error" });
@@ -165,29 +181,37 @@ export default function EmbeddingForm() {
};
const onConfirm = async () => {
let newModel: EmbeddingModelDescriptor;
if (!selectedProvider) {
return;
}
let newModel: SavedSearchSettings;
if ("provider_type" in selectedProvider) {
// This is a CloudEmbeddingModel
if (selectedProvider.provider_type != null) {
// This is a cloud model
newModel = {
...advancedEmbeddingDetails,
...selectedProvider,
...rerankingDetails,
model_name: selectedProvider.model_name,
provider_type: selectedProvider.provider_type
?.toLowerCase()
.split(" ")[0],
provider_type:
(selectedProvider.provider_type
?.toLowerCase()
.split(" ")[0] as EmbeddingProvider) || null,
};
} else {
// This is an EmbeddingModelDescriptor
// This is a locally hosted model
newModel = {
...advancedEmbeddingDetails,
...selectedProvider,
...rerankingDetails,
model_name: selectedProvider.model_name!,
description: "",
provider_type: null,
};
}
newModel.index_name = null;
const response = await fetch(
"/api/search-settings/set-new-embedding-model",
"/api/search-settings/set-new-search-settings",
{
method: "POST",
body: JSON.stringify(newModel),
@@ -201,7 +225,7 @@ export default function EmbeddingForm() {
message: "Changed provider suceessfully. Redirecing to embedding page",
type: "success",
});
mutate("/api/search-settings/get-secondary-embedding-model");
mutate("/api/search-settings/get-secondary-search-settings");
setTimeout(() => {
window.open("/admin/configuration/search", "_self");
}, 2000);
@@ -217,14 +241,14 @@ export default function EmbeddingForm() {
searchSettings?.multipass_indexing !=
advancedEmbeddingDetails.multipass_indexing;
const ReIndxingButton = () => {
return (
const ReIndexingButton = ({ needsReIndex }: { needsReIndex: boolean }) => {
return needsReIndex ? (
<div className="flex mx-auto gap-x-1 ml-auto items-center">
<button
className="enabled:cursor-pointer disabled:bg-accent/50 disabled:cursor-not-allowed bg-accent flex gap-x-1 items-center text-white py-2.5 px-3.5 text-sm font-regular rounded-sm"
onClick={async () => {
const updated = await updateSearch();
if (updated) {
const update = await updateSearch();
if (update) {
await onConfirm();
}
}}
@@ -251,6 +275,15 @@ export default function EmbeddingForm() {
</div>
</div>
</div>
) : (
<button
className="enabled:cursor-pointer ml-auto disabled:bg-accent/50 disabled:cursor-not-allowed bg-accent flex mx-auto gap-x-1 items-center text-white py-2.5 px-3.5 text-sm font-regular rounded-sm"
onClick={async () => {
updateSearch();
}}
>
Update Search
</button>
);
};
@@ -361,18 +394,7 @@ export default function EmbeddingForm() {
Previous
</button>
{needsReIndex ? (
<ReIndxingButton />
) : (
<button
className="enabled:cursor-pointer ml-auto disabled:bg-accent/50 disabled:cursor-not-allowed bg-accent flex mx-auto gap-x-1 items-center text-white py-2.5 px-3.5 text-sm font-regular rounded-sm"
onClick={async () => {
updateSearch();
}}
>
Update Search
</button>
)}
<ReIndexingButton needsReIndex={needsReIndex} />
<div className="flex w-full justify-end">
<button
@@ -410,20 +432,7 @@ export default function EmbeddingForm() {
Previous
</button>
{needsReIndex ? (
<ReIndxingButton />
) : (
<button
className="enabled:cursor-pointer ml-auto disabled:bg-accent/50
disabled:cursor-not-allowed bg-accent flex mx-auto gap-x-1 items-center
text-white py-2.5 px-3.5 text-sm font-regular rounded-sm"
onClick={async () => {
updateSearch();
}}
>
Update Search
</button>
)}
<ReIndexingButton needsReIndex={needsReIndex} />
</div>
</>
)}

View File

@@ -49,7 +49,7 @@ export default async function Home() {
fetchSS("/manage/document-set"),
fetchAssistantsSS(),
fetchSS("/query/valid-tags"),
fetchSS("/search-settings/get-embedding-models"),
fetchSS("/search-settings/get-all-search-settings"),
fetchSS("/query/user-searches"),
];

View File

@@ -35,7 +35,13 @@ export function CustomModelForm({
normalize: Yup.boolean().required(),
})}
onSubmit={async (values, formikHelpers) => {
onSubmit({ ...values, model_dim: parseInt(values.model_dim) });
onSubmit({
...values,
model_dim: parseInt(values.model_dim),
api_key: null,
provider_type: null,
index_name: null,
});
}}
>
{({ isSubmitting, setFieldValue }) => (

View File

@@ -41,8 +41,10 @@ export interface EmbeddingModelDescriptor {
normalize: boolean;
query_prefix: string;
passage_prefix: string;
provider_type?: string | null;
provider_type: string | null;
description: string;
api_key: string | null;
index_name: string | null;
}
export interface CloudEmbeddingModel extends EmbeddingModelDescriptor {
@@ -82,6 +84,9 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
link: "https://huggingface.co/nomic-ai/nomic-embed-text-v1",
query_prefix: "search_query: ",
passage_prefix: "search_document: ",
index_name: "",
provider_type: null,
api_key: null,
},
{
model_name: "intfloat/e5-base-v2",
@@ -92,6 +97,9 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
link: "https://huggingface.co/intfloat/e5-base-v2",
query_prefix: "query: ",
passage_prefix: "passage: ",
index_name: "",
provider_type: null,
api_key: null,
},
{
model_name: "intfloat/e5-small-v2",
@@ -102,6 +110,9 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
link: "https://huggingface.co/intfloat/e5-small-v2",
query_prefix: "query: ",
passage_prefix: "passage: ",
index_name: "",
provider_type: null,
api_key: null,
},
{
model_name: "intfloat/multilingual-e5-base",
@@ -112,6 +123,9 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
link: "https://huggingface.co/intfloat/multilingual-e5-base",
query_prefix: "query: ",
passage_prefix: "passage: ",
index_name: "",
provider_type: null,
api_key: null,
},
{
model_name: "intfloat/multilingual-e5-small",
@@ -122,6 +136,9 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
link: "https://huggingface.co/intfloat/multilingual-e5-base",
query_prefix: "query: ",
passage_prefix: "passage: ",
index_name: "",
provider_type: null,
api_key: null,
},
];
@@ -150,6 +167,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
{
model_name: "embed-english-light-v3.0",
@@ -164,6 +183,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
],
},
@@ -190,6 +211,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
mtebScore: 64.6,
maxContext: 8191,
enabled: false,
index_name: "",
api_key: null,
},
{
provider_type: EmbeddingProvider.OPENAI,
@@ -204,6 +227,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
enabled: false,
mtebScore: 62.3,
maxContext: 8191,
index_name: "",
api_key: null,
},
],
},
@@ -231,6 +256,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
{
provider_type: EmbeddingProvider.GOOGLE,
@@ -244,6 +271,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
],
},
@@ -270,6 +299,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
{
provider_type: EmbeddingProvider.VOYAGE,
@@ -284,6 +315,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
],
},