Add Litellm Rerank proxy (#2346)

* add ability to set reranking litellm proxy

* add fully functional rerank litellm cards

* minor formatting enforcement

* remove logs
This commit is contained in:
pablodanswer 2024-09-09 08:57:01 -07:00 committed by GitHub
parent f04ecbf87a
commit 3a9b964d5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 231 additions and 26 deletions

View File

@ -0,0 +1,26 @@
"""add support for litellm proxy in reranking
Revision ID: ba98eba0f66a
Revises: bceb1e139447
Create Date: 2024-09-06 10:36:04.507332
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ba98eba0f66a"
down_revision = "bceb1e139447"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the ``rerank_api_url`` string column to ``search_settings``."""
    # Declared nullable, so the column is added without supplying a value.
    rerank_api_url_column = sa.Column("rerank_api_url", sa.String(), nullable=True)
    op.add_column("search_settings", rerank_api_url_column)
def downgrade() -> None:
    """Drop the ``rerank_api_url`` column from ``search_settings``."""
    op.drop_column(table_name="search_settings", column_name="rerank_api_url")

View File

@ -576,6 +576,8 @@ class SearchSettings(Base):
Enum(RerankerProvider, native_enum=False), nullable=True
)
rerank_api_key: Mapped[str | None] = mapped_column(String, nullable=True)
rerank_api_url: Mapped[str | None] = mapped_column(String, nullable=True)
num_rerank: Mapped[int] = mapped_column(Integer, default=NUM_POSTPROCESSED_RESULTS)
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(

View File

@ -392,8 +392,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.notice(
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
)
if search_settings.rerank_model_name and not search_settings.provider_type:
if (
search_settings.rerank_model_name
and not search_settings.provider_type
and not search_settings.rerank_provider_type
):
warm_up_cross_encoder(search_settings.rerank_model_name)
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")

View File

@ -242,6 +242,7 @@ class RerankingModel:
model_name: str,
provider_type: RerankerProvider | None,
api_key: str | None,
api_url: str | None,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
@ -250,6 +251,7 @@ class RerankingModel:
self.model_name = model_name
self.provider_type = provider_type
self.api_key = api_key
self.api_url = api_url
def predict(self, query: str, passages: list[str]) -> list[float]:
rerank_request = RerankRequest(
@ -258,6 +260,7 @@ class RerankingModel:
model_name=self.model_name,
provider_type=self.provider_type,
api_key=self.api_key,
api_url=self.api_url,
)
response = requests.post(
@ -400,6 +403,7 @@ def warm_up_cross_encoder(
reranking_model = RerankingModel(
model_name=rerank_model_name,
provider_type=None,
api_url=None,
api_key=None,
)

View File

@ -26,6 +26,7 @@ MAX_METRICS_CONTENT = (
class RerankingDetails(BaseModel):
# If model is None (or num_rerank is 0), then reranking is turned off
rerank_model_name: str | None
rerank_api_url: str | None
rerank_provider_type: RerankerProvider | None
rerank_api_key: str | None = None
@ -42,6 +43,7 @@ class RerankingDetails(BaseModel):
rerank_provider_type=search_settings.rerank_provider_type,
rerank_api_key=search_settings.rerank_api_key,
num_rerank=search_settings.num_rerank,
rerank_api_url=search_settings.rerank_api_url,
)
@ -81,7 +83,7 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
num_rerank=search_settings.num_rerank,
# Multilingual Expansion
multilingual_expansion=search_settings.multilingual_expansion,
api_url=search_settings.api_url,
rerank_api_url=search_settings.rerank_api_url,
)

View File

@ -100,6 +100,7 @@ def semantic_reranking(
model_name=rerank_settings.rerank_model_name,
provider_type=rerank_settings.rerank_provider_type,
api_key=rerank_settings.rerank_api_key,
api_url=rerank_settings.rerank_api_url,
)
passages = [

View File

@ -362,6 +362,28 @@ def cohere_rerank(
return [result.relevance_score for result in sorted_results]
def litellm_rerank(
    query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
    """Score ``docs`` against ``query`` through an HTTP rerank endpoint.

    POSTs the model name, query and documents as a JSON body to ``api_url``
    and reads ``relevance_score`` from every entry of the response's
    ``results`` list.

    Args:
        query: The search query to rank the documents against.
        docs: The candidate passages to score.
        api_url: Endpoint that receives the POST request.
        model_name: Sent as the ``model`` field of the request body.
        api_key: When truthy, sent as a ``Bearer`` token in the
            ``Authorization`` header; otherwise no auth header is sent.

    Returns:
        The relevance scores, ordered by each result's ``index`` field
        (presumably the document's position in ``docs`` -- confirm against
        the proxy's response format).

    Raises:
        Whatever ``response.raise_for_status()`` raises for an error status,
        and ``KeyError`` if the response JSON lacks the expected fields.
    """
    if api_key:
        headers = {"Authorization": f"Bearer {api_key}"}
    else:
        headers = {}

    payload = {
        "model": model_name,
        "query": query,
        "documents": docs,
    }

    with httpx.Client() as client:
        response = client.post(api_url, json=payload, headers=headers)
        response.raise_for_status()
        # Results are re-sorted by "index" before the scores are extracted.
        ranked = sorted(response.json()["results"], key=lambda entry: entry["index"])

    return [entry["relevance_score"] for entry in ranked]
@router.post("/bi-encoder-embed")
async def process_embed_request(
embed_request: EmbedRequest,
@ -418,6 +440,20 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
model_name=rerank_request.model_name,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.LITELLM:
if rerank_request.api_url is None:
raise ValueError("API URL is required for LiteLLM reranking.")
sim_scores = litellm_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
api_url=rerank_request.api_url,
model_name=rerank_request.model_name,
api_key=rerank_request.api_key,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.COHERE:
if rerank_request.api_key is None:
raise RuntimeError("Cohere Rerank Requires an API Key")

View File

@ -11,6 +11,7 @@ class EmbeddingProvider(str, Enum):
class RerankerProvider(str, Enum):
# Provider types a rerank request can name; the rerank endpoint branches on these.
# Cohere: the rerank endpoint raises if no API key accompanies the request.
COHERE = "cohere"
# LiteLLM proxy: the rerank endpoint raises if no API URL accompanies the request.
LITELLM = "litellm"
class EmbedTextType(str, Enum):

View File

@ -43,6 +43,7 @@ class RerankRequest(BaseModel):
model_name: str
provider_type: RerankerProvider | None = None
api_key: str | None = None
api_url: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}

View File

@ -7,7 +7,11 @@ import {
rerankingModels,
} from "./interfaces";
import { FiExternalLink } from "react-icons/fi";
import { CohereIcon, MixedBreadIcon } from "@/components/icons/icons";
import {
CohereIcon,
LiteLLMIcon,
MixedBreadIcon,
} from "@/components/icons/icons";
import { Modal } from "@/components/Modal";
import { Button } from "@tremor/react";
import { TextFormField } from "@/components/admin/connectors/Field";
@ -35,6 +39,8 @@ const RerankingDetailsForm = forwardRef<
ref
) => {
const [isApiKeyModalOpen, setIsApiKeyModalOpen] = useState(false);
const [showLiteLLMConfigurationModal, setShowLiteLLMConfigurationModal] =
useState(false);
return (
<Formik
@ -48,13 +54,17 @@ const RerankingDetailsForm = forwardRef<
.optional(),
api_key: Yup.string().nullable(),
num_rerank: Yup.number().min(1, "Must be at least 1"),
rerank_api_url: Yup.string()
.url("Must be a valid URL")
.matches(/^https?:\/\//, "URL must start with http:// or https://")
.nullable(),
})}
onSubmit={async (_, { setSubmitting }) => {
setSubmitting(false);
}}
enableReinitialize={true}
>
{({ values, setFieldValue }) => {
{({ values, setFieldValue, resetForm }) => {
const resetRerankingValues = () => {
setRerankingDetails({
...values,
@ -131,14 +141,22 @@ const RerankingDetailsForm = forwardRef<
)
: rerankingModels.filter(
(modelCard) =>
modelCard.modelName ==
originalRerankingDetails.rerank_model_name
(modelCard.modelName ==
originalRerankingDetails.rerank_model_name &&
modelCard.rerank_provider_type ==
originalRerankingDetails.rerank_provider_type) ||
(modelCard.rerank_provider_type ==
RerankerProvider.LITELLM &&
originalRerankingDetails.rerank_provider_type ==
RerankerProvider.LITELLM)
)
).map((card) => {
const isSelected =
values.rerank_provider_type ===
card.rerank_provider_type &&
values.rerank_model_name === card.modelName;
(card.modelName == null ||
values.rerank_model_name === card.modelName);
return (
<div
key={`${card.rerank_provider_type}-${card.modelName}`}
@ -148,26 +166,39 @@ const RerankingDetailsForm = forwardRef<
: "border-gray-200 hover:border-blue-300 hover:shadow-sm"
}`}
onClick={() => {
if (card.rerank_provider_type) {
if (
card.rerank_provider_type == RerankerProvider.COHERE
) {
setIsApiKeyModalOpen(true);
} else if (
card.rerank_provider_type ==
RerankerProvider.LITELLM
) {
setShowLiteLLMConfigurationModal(true);
}
if (!isSelected) {
setRerankingDetails({
...values,
rerank_provider_type: card.rerank_provider_type!,
rerank_model_name: card.modelName || null,
rerank_api_key: null,
rerank_api_url: null,
});
setFieldValue(
"rerank_provider_type",
card.rerank_provider_type
);
setFieldValue("rerank_model_name", card.modelName);
}
setRerankingDetails({
...values,
rerank_provider_type: card.rerank_provider_type!,
rerank_model_name: card.modelName,
rerank_api_key: null,
});
setFieldValue(
"rerank_provider_type",
card.rerank_provider_type
);
setFieldValue("rerank_model_name", card.modelName);
}}
>
<div className="flex items-center justify-between mb-3">
<div className="flex items-center">
{card.rerank_provider_type ===
RerankerProvider.COHERE ? (
RerankerProvider.LITELLM ? (
<LiteLLMIcon size={24} className="mr-2" />
) : RerankerProvider.COHERE ? (
<CohereIcon size={24} className="mr-2" />
) : (
<MixedBreadIcon size={24} className="mr-2" />
@ -199,6 +230,88 @@ const RerankingDetailsForm = forwardRef<
})}
</div>
{showLiteLLMConfigurationModal && (
<Modal
onOutsideClick={() => {
resetForm();
setShowLiteLLMConfigurationModal(false);
}}
width="w-[800px]"
title="API Key Configuration"
>
<div className="w-full flex flex-col gap-y-4 px-4">
<TextFormField
subtext="Set the URL at which your LiteLLM Proxy is hosted"
placeholder={values.rerank_api_url || undefined}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
const value = e.target.value;
setRerankingDetails({
...values,
rerank_api_url: value,
});
setFieldValue("rerank_api_url", value);
}}
type="text"
label="LiteLLM Proxy URL"
name="rerank_api_url"
/>
<TextFormField
subtext="Set the key to access your LiteLLM Proxy"
placeholder={
values.rerank_api_key
? "*".repeat(values.rerank_api_key.length)
: undefined
}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
const value = e.target.value;
setRerankingDetails({
...values,
rerank_api_key: value,
});
setFieldValue("rerank_api_key", value);
}}
type="password"
label="LiteLLM Proxy Key"
name="rerank_api_key"
optional
/>
<TextFormField
subtext="Set the model name to use for LiteLLM Proxy"
placeholder={
values.rerank_model_name
? "*".repeat(values.rerank_model_name.length)
: undefined
}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
const value = e.target.value;
setRerankingDetails({
...values,
rerank_model_name: value,
});
setFieldValue("rerank_model_name", value);
}}
label="LiteLLM Model Name"
name="rerank_model_name"
optional
/>
<div className="flex w-full justify-end mt-4">
<Button
onClick={() => {
setShowLiteLLMConfigurationModal(false);
}}
color="blue"
size="xs"
>
Update
</Button>
</div>
</div>
</Modal>
)}
{isApiKeyModalOpen && (
<Modal
onOutsideClick={() => {
@ -218,7 +331,11 @@ const RerankingDetailsForm = forwardRef<
>
<div className="w-full px-4">
<TextFormField
placeholder={values.rerank_api_key || undefined}
placeholder={
values.rerank_api_key
? "*".repeat(values.rerank_api_key.length)
: undefined
}
onChange={(e: React.ChangeEvent<HTMLInputElement>) => {
const value = e.target.value;
setRerankingDetails({

View File

@ -5,11 +5,13 @@ export interface RerankingDetails {
rerank_model_name: string | null;
rerank_provider_type: RerankerProvider | null;
rerank_api_key: string | null;
rerank_api_url: string | null;
num_rerank: number;
}
export enum RerankerProvider {
// Reranking providers; the string values match the Python RerankerProvider
// enum shown elsewhere in this change.
COHERE = "cohere",
// LiteLLM proxy option, configured through a proxy URL, an optional key and a model name.
LITELLM = "litellm",
}
export interface AdvancedSearchConfiguration {
model_name: string;
@ -40,7 +42,7 @@ export interface SavedSearchSettings extends RerankingDetails {
export interface RerankingModel {
rerank_provider_type: RerankerProvider | null;
modelName: string;
modelName?: string;
displayName: string;
description: string;
link: string;
@ -48,6 +50,13 @@ export interface RerankingModel {
}
export const rerankingModels: RerankingModel[] = [
{
rerank_provider_type: RerankerProvider.LITELLM,
cloud: true,
displayName: "LiteLLM",
description: "Host your own reranker or router with LiteLLM proxy",
link: "https://docs.litellm.ai/docs/proxy",
},
{
rerank_provider_type: null,
cloud: false,

View File

@ -4,7 +4,7 @@ import * as Yup from "yup";
import CredentialSubText from "@/components/credentials/CredentialFields";
import { TrashIcon } from "@/components/icons/icons";
import { FaPlus } from "react-icons/fa";
import { AdvancedSearchConfiguration, RerankingDetails } from "../interfaces";
import { AdvancedSearchConfiguration } from "../interfaces";
import { BooleanFormField } from "@/components/admin/connectors/Field";
import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput";

View File

@ -10,7 +10,7 @@ import {
CloudEmbeddingModel,
EmbeddingProvider,
HostedEmbeddingModel,
} from "../../../../components/embedding/interfaces";
} from "@/components/embedding/interfaces";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { ErrorCallout } from "@/components/ErrorCallout";
import useSWR, { mutate } from "swr";
@ -18,7 +18,6 @@ import { ThreeDotsLoader } from "@/components/Loading";
import AdvancedEmbeddingFormPage from "./AdvancedEmbeddingFormPage";
import {
AdvancedSearchConfiguration,
RerankerProvider,
RerankingDetails,
SavedSearchSettings,
} from "../interfaces";
@ -49,6 +48,7 @@ export default function EmbeddingForm() {
num_rerank: 0,
rerank_provider_type: null,
rerank_model_name: "",
rerank_api_url: null,
});
const updateAdvancedEmbeddingDetails = (
@ -124,6 +124,7 @@ export default function EmbeddingForm() {
num_rerank: searchSettings.num_rerank,
rerank_provider_type: searchSettings.rerank_provider_type,
rerank_model_name: searchSettings.rerank_model_name,
rerank_api_url: searchSettings.rerank_api_url,
});
}
}, [searchSettings]);
@ -134,12 +135,14 @@ export default function EmbeddingForm() {
num_rerank: searchSettings.num_rerank,
rerank_provider_type: searchSettings.rerank_provider_type,
rerank_model_name: searchSettings.rerank_model_name,
rerank_api_url: searchSettings.rerank_api_url,
}
: {
rerank_api_key: "",
num_rerank: 0,
rerank_provider_type: null,
rerank_model_name: "",
rerank_api_url: null,
};
useEffect(() => {