mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-01 00:18:18 +02:00
Add Litellm Rerank proxy (#2346)
* add ability ot set reranking litellm proxy * add fully functional rerank litellm cards * minor formatting enforcement * remove logs
This commit is contained in:
parent
f04ecbf87a
commit
3a9b964d5c
@ -0,0 +1,26 @@
|
||||
"""add support for litellm proxy in reranking
|
||||
|
||||
Revision ID: ba98eba0f66a
|
||||
Revises: bceb1e139447
|
||||
Create Date: 2024-09-06 10:36:04.507332
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ba98eba0f66a"
|
||||
down_revision = "bceb1e139447"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("search_settings", "rerank_api_url")
|
@ -576,6 +576,8 @@ class SearchSettings(Base):
|
||||
Enum(RerankerProvider, native_enum=False), nullable=True
|
||||
)
|
||||
rerank_api_key: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
rerank_api_url: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
num_rerank: Mapped[int] = mapped_column(Integer, default=NUM_POSTPROCESSED_RESULTS)
|
||||
|
||||
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
|
||||
|
@ -392,8 +392,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
logger.notice(
|
||||
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
|
||||
)
|
||||
|
||||
if search_settings.rerank_model_name and not search_settings.provider_type:
|
||||
if (
|
||||
search_settings.rerank_model_name
|
||||
and not search_settings.provider_type
|
||||
and not search_settings.rerank_provider_type
|
||||
):
|
||||
warm_up_cross_encoder(search_settings.rerank_model_name)
|
||||
|
||||
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
|
||||
|
@ -242,6 +242,7 @@ class RerankingModel:
|
||||
model_name: str,
|
||||
provider_type: RerankerProvider | None,
|
||||
api_key: str | None,
|
||||
api_url: str | None,
|
||||
model_server_host: str = MODEL_SERVER_HOST,
|
||||
model_server_port: int = MODEL_SERVER_PORT,
|
||||
) -> None:
|
||||
@ -250,6 +251,7 @@ class RerankingModel:
|
||||
self.model_name = model_name
|
||||
self.provider_type = provider_type
|
||||
self.api_key = api_key
|
||||
self.api_url = api_url
|
||||
|
||||
def predict(self, query: str, passages: list[str]) -> list[float]:
|
||||
rerank_request = RerankRequest(
|
||||
@ -258,6 +260,7 @@ class RerankingModel:
|
||||
model_name=self.model_name,
|
||||
provider_type=self.provider_type,
|
||||
api_key=self.api_key,
|
||||
api_url=self.api_url,
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
@ -400,6 +403,7 @@ def warm_up_cross_encoder(
|
||||
reranking_model = RerankingModel(
|
||||
model_name=rerank_model_name,
|
||||
provider_type=None,
|
||||
api_url=None,
|
||||
api_key=None,
|
||||
)
|
||||
|
||||
|
@ -26,6 +26,7 @@ MAX_METRICS_CONTENT = (
|
||||
class RerankingDetails(BaseModel):
|
||||
# If model is None (or num_rerank is 0), then reranking is turned off
|
||||
rerank_model_name: str | None
|
||||
rerank_api_url: str | None
|
||||
rerank_provider_type: RerankerProvider | None
|
||||
rerank_api_key: str | None = None
|
||||
|
||||
@ -42,6 +43,7 @@ class RerankingDetails(BaseModel):
|
||||
rerank_provider_type=search_settings.rerank_provider_type,
|
||||
rerank_api_key=search_settings.rerank_api_key,
|
||||
num_rerank=search_settings.num_rerank,
|
||||
rerank_api_url=search_settings.rerank_api_url,
|
||||
)
|
||||
|
||||
|
||||
@ -81,7 +83,7 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
||||
num_rerank=search_settings.num_rerank,
|
||||
# Multilingual Expansion
|
||||
multilingual_expansion=search_settings.multilingual_expansion,
|
||||
api_url=search_settings.api_url,
|
||||
rerank_api_url=search_settings.rerank_api_url,
|
||||
)
|
||||
|
||||
|
||||
|
@ -100,6 +100,7 @@ def semantic_reranking(
|
||||
model_name=rerank_settings.rerank_model_name,
|
||||
provider_type=rerank_settings.rerank_provider_type,
|
||||
api_key=rerank_settings.rerank_api_key,
|
||||
api_url=rerank_settings.rerank_api_url,
|
||||
)
|
||||
|
||||
passages = [
|
||||
|
@ -362,6 +362,28 @@ def cohere_rerank(
|
||||
return [result.relevance_score for result in sorted_results]
|
||||
|
||||
|
||||
def litellm_rerank(
|
||||
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
|
||||
) -> list[float]:
|
||||
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
|
||||
with httpx.Client() as client:
|
||||
response = client.post(
|
||||
api_url,
|
||||
json={
|
||||
"model": model_name,
|
||||
"query": query,
|
||||
"documents": docs,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return [
|
||||
item["relevance_score"]
|
||||
for item in sorted(result["results"], key=lambda x: x["index"])
|
||||
]
|
||||
|
||||
|
||||
@router.post("/bi-encoder-embed")
|
||||
async def process_embed_request(
|
||||
embed_request: EmbedRequest,
|
||||
@ -418,6 +440,20 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
|
||||
model_name=rerank_request.model_name,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
elif rerank_request.provider_type == RerankerProvider.LITELLM:
|
||||
if rerank_request.api_url is None:
|
||||
raise ValueError("API URL is required for LiteLLM reranking.")
|
||||
|
||||
sim_scores = litellm_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
api_url=rerank_request.api_url,
|
||||
model_name=rerank_request.model_name,
|
||||
api_key=rerank_request.api_key,
|
||||
)
|
||||
|
||||
return RerankResponse(scores=sim_scores)
|
||||
|
||||
elif rerank_request.provider_type == RerankerProvider.COHERE:
|
||||
if rerank_request.api_key is None:
|
||||
raise RuntimeError("Cohere Rerank Requires an API Key")
|
||||
|
@ -11,6 +11,7 @@ class EmbeddingProvider(str, Enum):
|
||||
|
||||
class RerankerProvider(str, Enum):
|
||||
COHERE = "cohere"
|
||||
LITELLM = "litellm"
|
||||
|
||||
|
||||
class EmbedTextType(str, Enum):
|
||||
|
@ -43,6 +43,7 @@ class RerankRequest(BaseModel):
|
||||
model_name: str
|
||||
provider_type: RerankerProvider | None = None
|
||||
api_key: str | None = None
|
||||
api_url: str | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
@ -7,7 +7,11 @@ import {
|
||||
rerankingModels,
|
||||
} from "./interfaces";
|
||||
import { FiExternalLink } from "react-icons/fi";
|
||||
import { CohereIcon, MixedBreadIcon } from "@/components/icons/icons";
|
||||
import {
|
||||
CohereIcon,
|
||||
LiteLLMIcon,
|
||||
MixedBreadIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { Button } from "@tremor/react";
|
||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||
@ -35,6 +39,8 @@ const RerankingDetailsForm = forwardRef<
|
||||
ref
|
||||
) => {
|
||||
const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false);
|
||||
const [showLiteLLMConfigurationModal, setShowLiteLLMConfigurationModal] =
|
||||
useState(false);
|
||||
|
||||
return (
|
||||
<Formik
|
||||
@ -48,13 +54,17 @@ const RerankingDetailsForm = forwardRef<
|
||||
.optional(),
|
||||
api_key: Yup.string().nullable(),
|
||||
num_rerank: Yup.number().min(1, "Must be at least 1"),
|
||||
rerank_api_url: Yup.string()
|
||||
.url("Must be a valid URL")
|
||||
.matches(/^https?:\/\//, "URL must start with http:// or https://")
|
||||
.nullable(),
|
||||
})}
|
||||
onSubmit={async (_, { setSubmitting }) => {
|
||||
setSubmitting(false);
|
||||
}}
|
||||
enableReinitialize={true}
|
||||
>
|
||||
{({ values, setFieldValue }) => {
|
||||
{({ values, setFieldValue, resetForm }) => {
|
||||
const resetRerankingValues = () => {
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
@ -131,14 +141,22 @@ const RerankingDetailsForm = forwardRef<
|
||||
)
|
||||
: rerankingModels.filter(
|
||||
(modelCard) =>
|
||||
modelCard.modelName ==
|
||||
originalRerankingDetails.rerank_model_name
|
||||
(modelCard.modelName ==
|
||||
originalRerankingDetails.rerank_model_name &&
|
||||
modelCard.rerank_provider_type ==
|
||||
originalRerankingDetails.rerank_provider_type) ||
|
||||
(modelCard.rerank_provider_type ==
|
||||
RerankerProvider.LITELLM &&
|
||||
originalRerankingDetails.rerank_provider_type ==
|
||||
RerankerProvider.LITELLM)
|
||||
)
|
||||
).map((card) => {
|
||||
const isSelected =
|
||||
values.rerank_provider_type ===
|
||||
card.rerank_provider_type &&
|
||||
values.rerank_model_name === card.modelName;
|
||||
(card.modelName == null ||
|
||||
values.rerank_model_name === card.modelName);
|
||||
|
||||
return (
|
||||
<div
|
||||
key={`${card.rerank_provider_type}-${card.modelName}`}
|
||||
@ -148,26 +166,39 @@ const RerankingDetailsForm = forwardRef<
|
||||
: "border-gray-200 hover:border-blue-300 hover:shadow-sm"
|
||||
}`}
|
||||
onClick={() => {
|
||||
if (card.rerank_provider_type) {
|
||||
if (
|
||||
card.rerank_provider_type == RerankerProvider.COHERE
|
||||
) {
|
||||
setIsApiKeyModalOpen(true);
|
||||
} else if (
|
||||
card.rerank_provider_type ==
|
||||
RerankerProvider.LITELLM
|
||||
) {
|
||||
setShowLiteLLMConfigurationModal(true);
|
||||
}
|
||||
|
||||
if (!isSelected) {
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
rerank_provider_type: card.rerank_provider_type!,
|
||||
rerank_model_name: card.modelName || null,
|
||||
rerank_api_key: null,
|
||||
rerank_api_url: null,
|
||||
});
|
||||
setFieldValue(
|
||||
"rerank_provider_type",
|
||||
card.rerank_provider_type
|
||||
);
|
||||
setFieldValue("rerank_model_name", card.modelName);
|
||||
}
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
rerank_provider_type: card.rerank_provider_type!,
|
||||
rerank_model_name: card.modelName,
|
||||
rerank_api_key: null,
|
||||
});
|
||||
setFieldValue(
|
||||
"rerank_provider_type",
|
||||
card.rerank_provider_type
|
||||
);
|
||||
setFieldValue("rerank_model_name", card.modelName);
|
||||
}}
|
||||
>
|
||||
<div className="flex items-center justify-between mb-3">
|
||||
<div className="flex items-center">
|
||||
{card.rerank_provider_type ===
|
||||
RerankerProvider.COHERE ? (
|
||||
RerankerProvider.LITELLM ? (
|
||||
<LiteLLMIcon size={24} className="mr-2" />
|
||||
) : RerankerProvider.COHERE ? (
|
||||
<CohereIcon size={24} className="mr-2" />
|
||||
) : (
|
||||
<MixedBreadIcon size={24} className="mr-2" />
|
||||
@ -199,6 +230,88 @@ const RerankingDetailsForm = forwardRef<
|
||||
})}
|
||||
</div>
|
||||
|
||||
{showLiteLLMConfigurationModal && (
|
||||
<Modal
|
||||
onOutsideClick={() => {
|
||||
resetForm();
|
||||
setShowLiteLLMConfigurationModal(false);
|
||||
}}
|
||||
width="w-[800px]"
|
||||
title="API Key Configuration"
|
||||
>
|
||||
<div className="w-full flex flex-col gap-y-4 px-4">
|
||||
<TextFormField
|
||||
subtext="Set the URL at which your LiteLLM Proxy is hosted"
|
||||
placeholder={values.rerank_api_url || undefined}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const value = e.target.value;
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
rerank_api_url: value,
|
||||
});
|
||||
setFieldValue("rerank_api_url", value);
|
||||
}}
|
||||
type="text"
|
||||
label="LiteLLM Proxy URL"
|
||||
name="rerank_api_url"
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
subtext="Set the key to access your LiteLLM Proxy"
|
||||
placeholder={
|
||||
values.rerank_api_key
|
||||
? "*".repeat(values.rerank_api_key.length)
|
||||
: undefined
|
||||
}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const value = e.target.value;
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
rerank_api_key: value,
|
||||
});
|
||||
setFieldValue("rerank_api_key", value);
|
||||
}}
|
||||
type="password"
|
||||
label="LiteLLM Proxy Key"
|
||||
name="rerank_api_key"
|
||||
optional
|
||||
/>
|
||||
|
||||
<TextFormField
|
||||
subtext="Set the model name to use for LiteLLM Proxy"
|
||||
placeholder={
|
||||
values.rerank_model_name
|
||||
? "*".repeat(values.rerank_model_name.length)
|
||||
: undefined
|
||||
}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const value = e.target.value;
|
||||
setRerankingDetails({
|
||||
...values,
|
||||
rerank_model_name: value,
|
||||
});
|
||||
setFieldValue("rerank_model_name", value);
|
||||
}}
|
||||
label="LiteLLM Model Name"
|
||||
name="rerank_model_name"
|
||||
optional
|
||||
/>
|
||||
|
||||
<div className="flex w-full justify-end mt-4">
|
||||
<Button
|
||||
onClick={() => {
|
||||
setShowLiteLLMConfigurationModal(false);
|
||||
}}
|
||||
color="blue"
|
||||
size="xs"
|
||||
>
|
||||
Update
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</Modal>
|
||||
)}
|
||||
|
||||
{isApiKeyModalOpen && (
|
||||
<Modal
|
||||
onOutsideClick={() => {
|
||||
@ -218,7 +331,11 @@ const RerankingDetailsForm = forwardRef<
|
||||
>
|
||||
<div className="w-full px-4">
|
||||
<TextFormField
|
||||
placeholder={values.rerank_api_key || undefined}
|
||||
placeholder={
|
||||
values.rerank_api_key
|
||||
? "*".repeat(values.rerank_api_key.length)
|
||||
: undefined
|
||||
}
|
||||
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const value = e.target.value;
|
||||
setRerankingDetails({
|
||||
|
@ -5,11 +5,13 @@ export interface RerankingDetails {
|
||||
rerank_model_name: string | null;
|
||||
rerank_provider_type: RerankerProvider | null;
|
||||
rerank_api_key: string | null;
|
||||
rerank_api_url: string | null;
|
||||
num_rerank: number;
|
||||
}
|
||||
|
||||
export enum RerankerProvider {
|
||||
COHERE = "cohere",
|
||||
LITELLM = "litellm",
|
||||
}
|
||||
export interface AdvancedSearchConfiguration {
|
||||
model_name: string;
|
||||
@ -40,7 +42,7 @@ export interface SavedSearchSettings extends RerankingDetails {
|
||||
|
||||
export interface RerankingModel {
|
||||
rerank_provider_type: RerankerProvider | null;
|
||||
modelName: string;
|
||||
modelName?: string;
|
||||
displayName: string;
|
||||
description: string;
|
||||
link: string;
|
||||
@ -48,6 +50,13 @@ export interface RerankingModel {
|
||||
}
|
||||
|
||||
export const rerankingModels: RerankingModel[] = [
|
||||
{
|
||||
rerank_provider_type: RerankerProvider.LITELLM,
|
||||
cloud: true,
|
||||
displayName: "LiteLLM",
|
||||
description: "Host your own reranker or router with LiteLLM proxy",
|
||||
link: "https://docs.litellm.ai/docs/proxy",
|
||||
},
|
||||
{
|
||||
rerank_provider_type: null,
|
||||
cloud: false,
|
||||
|
@ -4,7 +4,7 @@ import * as Yup from "yup";
|
||||
import CredentialSubText from "@/components/credentials/CredentialFields";
|
||||
import { TrashIcon } from "@/components/icons/icons";
|
||||
import { FaPlus } from "react-icons/fa";
|
||||
import { AdvancedSearchConfiguration, RerankingDetails } from "../interfaces";
|
||||
import { AdvancedSearchConfiguration } from "../interfaces";
|
||||
import { BooleanFormField } from "@/components/admin/connectors/Field";
|
||||
import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput";
|
||||
|
||||
|
@ -10,7 +10,7 @@ import {
|
||||
CloudEmbeddingModel,
|
||||
EmbeddingProvider,
|
||||
HostedEmbeddingModel,
|
||||
} from "../../../../components/embedding/interfaces";
|
||||
} from "@/components/embedding/interfaces";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import useSWR, { mutate } from "swr";
|
||||
@ -18,7 +18,6 @@ import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import AdvancedEmbeddingFormPage from "./AdvancedEmbeddingFormPage";
|
||||
import {
|
||||
AdvancedSearchConfiguration,
|
||||
RerankerProvider,
|
||||
RerankingDetails,
|
||||
SavedSearchSettings,
|
||||
} from "../interfaces";
|
||||
@ -49,6 +48,7 @@ export default function EmbeddingForm() {
|
||||
num_rerank: 0,
|
||||
rerank_provider_type: null,
|
||||
rerank_model_name: "",
|
||||
rerank_api_url: null,
|
||||
});
|
||||
|
||||
const updateAdvancedEmbeddingDetails = (
|
||||
@ -124,6 +124,7 @@ export default function EmbeddingForm() {
|
||||
num_rerank: searchSettings.num_rerank,
|
||||
rerank_provider_type: searchSettings.rerank_provider_type,
|
||||
rerank_model_name: searchSettings.rerank_model_name,
|
||||
rerank_api_url: searchSettings.rerank_api_url,
|
||||
});
|
||||
}
|
||||
}, [searchSettings]);
|
||||
@ -134,12 +135,14 @@ export default function EmbeddingForm() {
|
||||
num_rerank: searchSettings.num_rerank,
|
||||
rerank_provider_type: searchSettings.rerank_provider_type,
|
||||
rerank_model_name: searchSettings.rerank_model_name,
|
||||
rerank_api_url: searchSettings.rerank_api_url,
|
||||
}
|
||||
: {
|
||||
rerank_api_key: "",
|
||||
num_rerank: 0,
|
||||
rerank_provider_type: null,
|
||||
rerank_model_name: "",
|
||||
rerank_api_url: null,
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
|
Loading…
x
Reference in New Issue
Block a user