Add Litellm Rerank proxy (#2346)

* add ability ot set reranking litellm proxy

* add fully functional rerank litellm cards

* minor formatting enforcement

* remove logs
This commit is contained in:
pablodanswer
2024-09-09 08:57:01 -07:00
committed by GitHub
parent f04ecbf87a
commit 3a9b964d5c
13 changed files with 231 additions and 26 deletions

View File

@@ -7,7 +7,11 @@ import {
rerankingModels,
} from "./interfaces";
import { FiExternalLink } from "react-icons/fi";
import { CohereIcon, MixedBreadIcon } from "@/components/icons/icons";
import {
CohereIcon,
LiteLLMIcon,
MixedBreadIcon,
} from "@/components/icons/icons";
import { Modal } from "@/components/Modal";
import { Button } from "@tremor/react";
import { TextFormField } from "@/components/admin/connectors/Field";
@@ -35,6 +39,8 @@ const RerankingDetailsForm = forwardRef<
ref
) => {
const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false);
const [showLiteLLMConfigurationModal, setShowLiteLLMConfigurationModal] =
useState(false);
return (
<Formik
@@ -48,13 +54,17 @@ const RerankingDetailsForm = forwardRef<
.optional(),
api_key: Yup.string().nullable(),
num_rerank: Yup.number().min(1, "Must be at least 1"),
rerank_api_url: Yup.string()
.url("Must be a valid URL")
.matches(/^https?:\/\//, "URL must start with http:// or https://")
.nullable(),
})}
onSubmit={async (_, { setSubmitting }) => {
setSubmitting(false);
}}
enableReinitialize={true}
>
{({ values, setFieldValue }) => {
{({ values, setFieldValue, resetForm }) => {
const resetRerankingValues = () => {
setRerankingDetails({
...values,
@@ -131,14 +141,22 @@ const RerankingDetailsForm = forwardRef<
)
: rerankingModels.filter(
(modelCard) =>
modelCard.modelName ==
originalRerankingDetails.rerank_model_name
(modelCard.modelName ==
originalRerankingDetails.rerank_model_name &&
modelCard.rerank_provider_type ==
originalRerankingDetails.rerank_provider_type) ||
(modelCard.rerank_provider_type ==
RerankerProvider.LITELLM &&
originalRerankingDetails.rerank_provider_type ==
RerankerProvider.LITELLM)
)
).map((card) => {
const isSelected =
values.rerank_provider_type ===
card.rerank_provider_type &&
values.rerank_model_name === card.modelName;
(card.modelName == null ||
values.rerank_model_name === card.modelName);
return (
<div
key={`${card.rerank_provider_type}-${card.modelName}`}
@@ -148,26 +166,39 @@ const RerankingDetailsForm = forwardRef<
: "border-gray-200 hover:border-blue-300 hover:shadow-sm"
}`}
onClick={() => {
if (card.rerank_provider_type) {
if (
card.rerank_provider_type == RerankerProvider.COHERE
) {
setIsApiKeyModalOpen(true);
} else if (
card.rerank_provider_type ==
RerankerProvider.LITELLM
) {
setShowLiteLLMConfigurationModal(true);
}
if (!isSelected) {
setRerankingDetails({
...values,
rerank_provider_type: card.rerank_provider_type!,
rerank_model_name: card.modelName || null,
rerank_api_key: null,
rerank_api_url: null,
});
setFieldValue(
"rerank_provider_type",
card.rerank_provider_type
);
setFieldValue("rerank_model_name", card.modelName);
}
setRerankingDetails({
...values,
rerank_provider_type: card.rerank_provider_type!,
rerank_model_name: card.modelName,
rerank_api_key: null,
});
setFieldValue(
"rerank_provider_type",
card.rerank_provider_type
);
setFieldValue("rerank_model_name", card.modelName);
}}
>
<div className="flex items-center justify-between mb-3">
<div className="flex items-center">
{card.rerank_provider_type ===
RerankerProvider.COHERE ? (
RerankerProvider.LITELLM ? (
<LiteLLMIcon size={24} className="mr-2" />
) : RerankerProvider.COHERE ? (
<CohereIcon size={24} className="mr-2" />
) : (
<MixedBreadIcon size={24} className="mr-2" />
@@ -199,6 +230,88 @@ const RerankingDetailsForm = forwardRef<
})}
</div>
{showLiteLLMConfigurationModal && (
<Modal
onOutsideClick={() => {
resetForm();
setShowLiteLLMConfigurationModal(false);
}}
width="w-[800px]"
title="API Key Configuration"
>
<div className="w-full flex flex-col gap-y-4 px-4">
<TextFormField
subtext="Set the URL at which your LiteLLM Proxy is hosted"
placeholder={values.rerank_api_url || undefined}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
const value = e.target.value;
setRerankingDetails({
...values,
rerank_api_url: value,
});
setFieldValue("rerank_api_url", value);
}}
type="text"
label="LiteLLM Proxy URL"
name="rerank_api_url"
/>
<TextFormField
subtext="Set the key to access your LiteLLM Proxy"
placeholder={
values.rerank_api_key
? "*".repeat(values.rerank_api_key.length)
: undefined
}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
const value = e.target.value;
setRerankingDetails({
...values,
rerank_api_key: value,
});
setFieldValue("rerank_api_key", value);
}}
type="password"
label="LiteLLM Proxy Key"
name="rerank_api_key"
optional
/>
<TextFormField
subtext="Set the model name to use for LiteLLM Proxy"
placeholder={
values.rerank_model_name
? "*".repeat(values.rerank_model_name.length)
: undefined
}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
const value = e.target.value;
setRerankingDetails({
...values,
rerank_model_name: value,
});
setFieldValue("rerank_model_name", value);
}}
label="LiteLLM Model Name"
name="rerank_model_name"
optional
/>
<div className="flex w-full justify-end mt-4">
<Button
onClick={() => {
setShowLiteLLMConfigurationModal(false);
}}
color="blue"
size="xs"
>
Update
</Button>
</div>
</div>
</Modal>
)}
{isApiKeyModalOpen && (
<Modal
onOutsideClick={() => {
@@ -218,7 +331,11 @@ const RerankingDetailsForm = forwardRef<
>
<div className="w-full px-4">
<TextFormField
placeholder={values.rerank_api_key || undefined}
placeholder={
values.rerank_api_key
? "*".repeat(values.rerank_api_key.length)
: undefined
}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
const value = e.target.value;
setRerankingDetails({

View File

@@ -5,11 +5,13 @@ export interface RerankingDetails {
rerank_model_name: string | null;
rerank_provider_type: RerankerProvider | null;
rerank_api_key: string | null;
rerank_api_url: string | null;
num_rerank: number;
}
export enum RerankerProvider {
COHERE = "cohere",
LITELLM = "litellm",
}
export interface AdvancedSearchConfiguration {
model_name: string;
@@ -40,7 +42,7 @@ export interface SavedSearchSettings extends RerankingDetails {
export interface RerankingModel {
rerank_provider_type: RerankerProvider | null;
modelName: string;
modelName?: string;
displayName: string;
description: string;
link: string;
@@ -48,6 +50,13 @@ export interface RerankingModel {
}
export const rerankingModels: RerankingModel[] = [
{
rerank_provider_type: RerankerProvider.LITELLM,
cloud: true,
displayName: "LiteLLM",
description: "Host your own reranker or router with LiteLLM proxy",
link: "https://docs.litellm.ai/docs/proxy",
},
{
rerank_provider_type: null,
cloud: false,

View File

@@ -4,7 +4,7 @@ import * as Yup from "yup";
import CredentialSubText from "@/components/credentials/CredentialFields";
import { TrashIcon } from "@/components/icons/icons";
import { FaPlus } from "react-icons/fa";
import { AdvancedSearchConfiguration, RerankingDetails } from "../interfaces";
import { AdvancedSearchConfiguration } from "../interfaces";
import { BooleanFormField } from "@/components/admin/connectors/Field";
import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput";

View File

@@ -10,7 +10,7 @@ import {
CloudEmbeddingModel,
EmbeddingProvider,
HostedEmbeddingModel,
} from "../../../../components/embedding/interfaces";
} from "@/components/embedding/interfaces";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { ErrorCallout } from "@/components/ErrorCallout";
import useSWR, { mutate } from "swr";
@@ -18,7 +18,6 @@ import { ThreeDotsLoader } from "@/components/Loading";
import AdvancedEmbeddingFormPage from "./AdvancedEmbeddingFormPage";
import {
AdvancedSearchConfiguration,
RerankerProvider,
RerankingDetails,
SavedSearchSettings,
} from "../interfaces";
@@ -49,6 +48,7 @@ export default function EmbeddingForm() {
num_rerank: 0,
rerank_provider_type: null,
rerank_model_name: "",
rerank_api_url: null,
});
const updateAdvancedEmbeddingDetails = (
@@ -124,6 +124,7 @@ export default function EmbeddingForm() {
num_rerank: searchSettings.num_rerank,
rerank_provider_type: searchSettings.rerank_provider_type,
rerank_model_name: searchSettings.rerank_model_name,
rerank_api_url: searchSettings.rerank_api_url,
});
}
}, [searchSettings]);
@@ -134,12 +135,14 @@ export default function EmbeddingForm() {
num_rerank: searchSettings.num_rerank,
rerank_provider_type: searchSettings.rerank_provider_type,
rerank_model_name: searchSettings.rerank_model_name,
rerank_api_url: searchSettings.rerank_api_url,
}
: {
rerank_api_key: "",
num_rerank: 0,
rerank_provider_type: null,
rerank_model_name: "",
rerank_api_url: null,
};
useEffect(() => {