mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-10-06 18:14:35 +02:00
Add Litellm Rerank proxy (#2346)
* add ability ot set reranking litellm proxy * add fully functional rerank litellm cards * minor formatting enforcement * remove logs
This commit is contained in:
@@ -7,7 +7,11 @@ import {
|
||||
rerankingModels,
|
||||
} from "./interfaces";
|
||||
import { FiExternalLink } from "react-icons/fi";
|
||||
import { CohereIcon, MixedBreadIcon } from "@/components/icons/icons";
|
||||
import {
|
||||
CohereIcon,
|
||||
LiteLLMIcon,
|
||||
MixedBreadIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Button } from "@tremor/react";
|
||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||
@@ -35,6 +39,8 @@ const RerankingDetailsForm = forwardRef<
|
||||
ref
|
||||
) => {
|
||||
const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false);
|
||||
const [showLiteLLMConfigurationModal, setShowLiteLLMConfigurationModal] =
|
||||
useState(false);
|
||||
|
||||
return (
|
||||
<Formik
|
||||
@@ -48,13 +54,17 @@ const RerankingDetailsForm = forwardRef<
|
||||
.optional(),
|
||||
api_key: Yup.string().nullable(),
|
||||
num_rerank: Yup.number().min(1, "Must be at least 1"),
|
||||
rerank_api_url: Yup.string()
|
||||
.url("Must be a valid URL")
|
||||
.matches(/^https?:\/\//, "URL must start with http:// or https://")
|
||||
.nullable(),
|
||||
})}
|
||||
onSubmit={async (_, { setSubmitting }) => {
|
||||
setSubmitting(false);
|
||||
}}
|
||||
enableReinitialize={true}
|
||||
>
|
||||
{({ values, setFieldValue }) => {
|
||||
{({ values, setFieldValue, resetForm }) => {
|
||||
const resetRerankingValues = () => {
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
@@ -131,14 +141,22 @@ const RerankingDetailsForm = forwardRef<
|
||||
)
|
||||
: rerankingModels.filter(
|
||||
(modelCard) =>
|
||||
modelCard.modelName ==
|
||||
originalRerankingDetails.rerank_model_name
|
||||
(modelCard.modelName ==
|
||||
originalRerankingDetails.rerank_model_name &&
|
||||
modelCard.rerank_provider_type ==
|
||||
originalRerankingDetails.rerank_provider_type) ||
|
||||
(modelCard.rerank_provider_type ==
|
||||
RerankerProvider.LITELLM &&
|
||||
originalRerankingDetails.rerank_provider_type ==
|
||||
RerankerProvider.LITELLM)
|
||||
)
|
||||
).map((card) => {
|
||||
const isSelected =
|
||||
values.rerank_provider_type ===
|
||||
card.rerank_provider_type &&
|
||||
values.rerank_model_name === card.modelName;
|
||||
(card.modelName == null ||
|
||||
values.rerank_model_name === card.modelName);
|
||||
|
||||
return (
|
||||
<div
|
||||
key={`${card.rerank_provider_type}-${card.modelName}`}
|
||||
@@ -148,26 +166,39 @@ const RerankingDetailsForm = forwardRef<
|
||||
: "border-gray-200 hover:border-blue-300 hover:shadow-sm"
|
||||
}`}
|
||||
onClick={() => {
|
||||
if (card.rerank_provider_type) {
|
||||
if (
|
||||
card.rerank_provider_type == RerankerProvider.COHERE
|
||||
) {
|
||||
setIsApiKeyModalOpen(true);
|
||||
} else if (
|
||||
card.rerank_provider_type ==
|
||||
RerankerProvider.LITELLM
|
||||
) {
|
||||
setShowLiteLLMConfigurationModal(true);
|
||||
}
|
||||
|
||||
if (!isSelected) {
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
rerank_provider_type: card.rerank_provider_type!,
|
||||
rerank_model_name: card.modelName || null,
|
||||
rerank_api_key: null,
|
||||
rerank_api_url: null,
|
||||
});
|
||||
setFieldValue(
|
||||
"rerank_provider_type",
|
||||
card.rerank_provider_type
|
||||
);
|
||||
setFieldValue("rerank_model_name", card.modelName);
|
||||
}
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
rerank_provider_type: card.rerank_provider_type!,
|
||||
rerank_model_name: card.modelName,
|
||||
rerank_api_key: null,
|
||||
});
|
||||
setFieldValue(
|
||||
"rerank_provider_type",
|
||||
card.rerank_provider_type
|
||||
);
|
||||
setFieldValue("rerank_model_name", card.modelName);
|
||||
}}
|
||||
>
|
||||
<div className="flex items-center justify-between mb-3">
|
||||
<div className="flex items-center">
|
||||
{card.rerank_provider_type ===
|
||||
RerankerProvider.COHERE ? (
|
||||
RerankerProvider.LITELLM ? (
|
||||
<LiteLLMIcon size={24} className="mr-2" />
|
||||
) : RerankerProvider.COHERE ? (
|
||||
<CohereIcon size={24} className="mr-2" />
|
||||
) : (
|
||||
<MixedBreadIcon size={24} className="mr-2" />
|
||||
@@ -199,6 +230,88 @@ const RerankingDetailsForm = forwardRef<
|
||||
})}
|
||||
</div>
|
||||
|
||||
{showLiteLLMConfigurationModal && (
|
||||
<Modal
|
||||
onOutsideClick={() => {
|
||||
resetForm();
|
||||
setShowLiteLLMConfigurationModal(false);
|
||||
}}
|
||||
width="w-[800px]"
|
||||
title="API Key Configuration"
|
||||
>
|
||||
<div className="w-full flex flex-col gap-y-4 px-4">
|
||||
<TextFormField
|
||||
subtext="Set the URL at which your LiteLLM Proxy is hosted"
|
||||
placeholder={values.rerank_api_url || undefined}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const value = e.target.value;
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
rerank_api_url: value,
|
||||
});
|
||||
setFieldValue("rerank_api_url", value);
|
||||
}}
|
||||
type="text"
|
||||
label="LiteLLM Proxy URL"
|
||||
name="rerank_api_url"
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
subtext="Set the key to access your LiteLLM Proxy"
|
||||
placeholder={
|
||||
values.rerank_api_key
|
||||
? "*".repeat(values.rerank_api_key.length)
|
||||
: undefined
|
||||
}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const value = e.target.value;
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
rerank_api_key: value,
|
||||
});
|
||||
setFieldValue("rerank_api_key", value);
|
||||
}}
|
||||
type="password"
|
||||
label="LiteLLM Proxy Key"
|
||||
name="rerank_api_key"
|
||||
optional
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
subtext="Set the model name to use for LiteLLM Proxy"
|
||||
placeholder={
|
||||
values.rerank_model_name
|
||||
? "*".repeat(values.rerank_model_name.length)
|
||||
: undefined
|
||||
}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const value = e.target.value;
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
rerank_model_name: value,
|
||||
});
|
||||
setFieldValue("rerank_model_name", value);
|
||||
}}
|
||||
label="LiteLLM Model Name"
|
||||
name="rerank_model_name"
|
||||
optional
|
||||
/>
|
||||
|
||||
<div className="flex w-full justify-end mt-4">
|
||||
<Button
|
||||
onClick={() => {
|
||||
setShowLiteLLMConfigurationModal(false);
|
||||
}}
|
||||
color="blue"
|
||||
size="xs"
|
||||
>
|
||||
Update
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
)}
|
||||
|
||||
{isApiKeyModalOpen && (
|
||||
<Modal
|
||||
onOutsideClick={() => {
|
||||
@@ -218,7 +331,11 @@ const RerankingDetailsForm = forwardRef<
|
||||
>
|
||||
<div className="w-full px-4">
|
||||
<TextFormField
|
||||
placeholder={values.rerank_api_key || undefined}
|
||||
placeholder={
|
||||
values.rerank_api_key
|
||||
? "*".repeat(values.rerank_api_key.length)
|
||||
: undefined
|
||||
}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const value = e.target.value;
|
||||
setRerankingDetails({
|
||||
|
@@ -5,11 +5,13 @@ export interface RerankingDetails {
|
||||
rerank_model_name: string | null;
|
||||
rerank_provider_type: RerankerProvider | null;
|
||||
rerank_api_key: string | null;
|
||||
rerank_api_url: string | null;
|
||||
num_rerank: number;
|
||||
}
|
||||
|
||||
export enum RerankerProvider {
|
||||
COHERE = "cohere",
|
||||
LITELLM = "litellm",
|
||||
}
|
||||
export interface AdvancedSearchConfiguration {
|
||||
model_name: string;
|
||||
@@ -40,7 +42,7 @@ export interface SavedSearchSettings extends RerankingDetails {
|
||||
|
||||
export interface RerankingModel {
|
||||
rerank_provider_type: RerankerProvider | null;
|
||||
modelName: string;
|
||||
modelName?: string;
|
||||
displayName: string;
|
||||
description: string;
|
||||
link: string;
|
||||
@@ -48,6 +50,13 @@ export interface RerankingModel {
|
||||
}
|
||||
|
||||
export const rerankingModels: RerankingModel[] = [
|
||||
{
|
||||
rerank_provider_type: RerankerProvider.LITELLM,
|
||||
cloud: true,
|
||||
displayName: "LiteLLM",
|
||||
description: "Host your own reranker or router with LiteLLM proxy",
|
||||
link: "https://docs.litellm.ai/docs/proxy",
|
||||
},
|
||||
{
|
||||
rerank_provider_type: null,
|
||||
cloud: false,
|
||||
|
@@ -4,7 +4,7 @@ import * as Yup from "yup";
|
||||
import CredentialSubText from "@/components/credentials/CredentialFields";
|
||||
import { TrashIcon } from "@/components/icons/icons";
|
||||
import { FaPlus } from "react-icons/fa";
|
||||
import { AdvancedSearchConfiguration, RerankingDetails } from "../interfaces";
|
||||
import { AdvancedSearchConfiguration } from "../interfaces";
|
||||
import { BooleanFormField } from "@/components/admin/connectors/Field";
|
||||
import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput";
|
||||
|
||||
|
@@ -10,7 +10,7 @@ import {
|
||||
CloudEmbeddingModel,
|
||||
EmbeddingProvider,
|
||||
HostedEmbeddingModel,
|
||||
} from "../../../../components/embedding/interfaces";
|
||||
} from "@/components/embedding/interfaces";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import useSWR, { mutate } from "swr";
|
||||
@@ -18,7 +18,6 @@ import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import AdvancedEmbeddingFormPage from "./AdvancedEmbeddingFormPage";
|
||||
import {
|
||||
AdvancedSearchConfiguration,
|
||||
RerankerProvider,
|
||||
RerankingDetails,
|
||||
SavedSearchSettings,
|
||||
} from "../interfaces";
|
||||
@@ -49,6 +48,7 @@ export default function EmbeddingForm() {
|
||||
num_rerank: 0,
|
||||
rerank_provider_type: null,
|
||||
rerank_model_name: "",
|
||||
rerank_api_url: null,
|
||||
});
|
||||
|
||||
const updateAdvancedEmbeddingDetails = (
|
||||
@@ -124,6 +124,7 @@ export default function EmbeddingForm() {
|
||||
num_rerank: searchSettings.num_rerank,
|
||||
rerank_provider_type: searchSettings.rerank_provider_type,
|
||||
rerank_model_name: searchSettings.rerank_model_name,
|
||||
rerank_api_url: searchSettings.rerank_api_url,
|
||||
});
|
||||
}
|
||||
}, [searchSettings]);
|
||||
@@ -134,12 +135,14 @@ export default function EmbeddingForm() {
|
||||
num_rerank: searchSettings.num_rerank,
|
||||
rerank_provider_type: searchSettings.rerank_provider_type,
|
||||
rerank_model_name: searchSettings.rerank_model_name,
|
||||
rerank_api_url: searchSettings.rerank_api_url,
|
||||
}
|
||||
: {
|
||||
rerank_api_key: "",
|
||||
num_rerank: 0,
|
||||
rerank_provider_type: null,
|
||||
rerank_model_name: "",
|
||||
rerank_api_url: null,
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
|
Reference in New Issue
Block a user