Update auth for litellm proxy (#2316)

* update for auth

* validated embedding model names

* remove embedding provider

* remove logs

* add ability to delete search setting

* add abiility to delete models + more streamlined API endpoints

* remove upsert

* minor typing fix

* add connector utils
This commit is contained in:
pablodanswer 2024-09-04 13:59:07 -07:00 committed by GitHub
parent 630e2248bd
commit 34ba3181ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 282 additions and 91 deletions

View File

@ -1,3 +1,5 @@
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -13,6 +15,7 @@ from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.engine import get_sqlalchemy_engine from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.llm import fetch_embedding_provider from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexModelStatus from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings from danswer.db.models import SearchSettings
from danswer.indexing.models import IndexingSetting from danswer.indexing.models import IndexingSetting
@ -89,6 +92,30 @@ def get_current_db_embedding_provider(
return current_embedding_provider return current_embedding_provider
def delete_search_settings(db_session: Session, search_settings_id: int) -> None:
current_settings = get_current_search_settings(db_session)
if current_settings.id == search_settings_id:
raise ValueError("Cannot delete currently active search settings")
# First, delete associated index attempts
index_attempts_query = delete(IndexAttempt).where(
IndexAttempt.search_settings_id == search_settings_id
)
db_session.execute(index_attempts_query)
# Then, delete the search settings
search_settings_query = delete(SearchSettings).where(
and_(
SearchSettings.id == search_settings_id,
SearchSettings.status != IndexModelStatus.PRESENT,
)
)
db_session.execute(search_settings_query)
db_session.commit()
def get_current_search_settings(db_session: Session) -> SearchSettings: def get_current_search_settings(db_session: Session) -> SearchSettings:
query = ( query = (
select(SearchSettings) select(SearchSettings)

View File

@ -95,6 +95,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
class EmbeddingModelDetail(BaseModel): class EmbeddingModelDetail(BaseModel):
id: int | None = None
model_name: str model_name: str
normalize: bool normalize: bool
query_prefix: str | None query_prefix: str | None
@ -112,6 +113,7 @@ class EmbeddingModelDetail(BaseModel):
search_settings: "SearchSettings", search_settings: "SearchSettings",
) -> "EmbeddingModelDetail": ) -> "EmbeddingModelDetail":
return cls( return cls(
id=search_settings.id,
model_name=search_settings.model_name, model_name=search_settings.model_name,
normalize=search_settings.normalize, normalize=search_settings.normalize,
query_prefix=search_settings.query_prefix, query_prefix=search_settings.query_prefix,

View File

@ -42,10 +42,10 @@ def test_embedding_configuration(
api_key=test_llm_request.api_key, api_key=test_llm_request.api_key,
api_url=test_llm_request.api_url, api_url=test_llm_request.api_url,
provider_type=test_llm_request.provider_type, provider_type=test_llm_request.provider_type,
model_name=test_llm_request.model_name,
normalize=False, normalize=False,
query_prefix=None, query_prefix=None,
passage_prefix=None, passage_prefix=None,
model_name=None,
) )
test_model.encode(["Testing Embedding"], text_type=EmbedTextType.QUERY) test_model.encode(["Testing Embedding"], text_type=EmbedTextType.QUERY)

View File

@ -8,10 +8,15 @@ if TYPE_CHECKING:
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
class SearchSettingsDeleteRequest(BaseModel):
search_settings_id: int
class TestEmbeddingRequest(BaseModel): class TestEmbeddingRequest(BaseModel):
provider_type: EmbeddingProvider provider_type: EmbeddingProvider
api_key: str | None = None api_key: str | None = None
api_url: str | None = None api_url: str | None = None
model_name: str | None = None
class CloudEmbeddingProvider(BaseModel): class CloudEmbeddingProvider(BaseModel):

View File

@ -14,6 +14,7 @@ from danswer.db.index_attempt import expire_index_attempts
from danswer.db.models import IndexModelStatus from danswer.db.models import IndexModelStatus
from danswer.db.models import User from danswer.db.models import User
from danswer.db.search_settings import create_search_settings from danswer.db.search_settings import create_search_settings
from danswer.db.search_settings import delete_search_settings
from danswer.db.search_settings import get_current_search_settings from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_embedding_provider_from_provider_type from danswer.db.search_settings import get_embedding_provider_from_provider_type
from danswer.db.search_settings import get_secondary_search_settings from danswer.db.search_settings import get_secondary_search_settings
@ -23,6 +24,7 @@ from danswer.document_index.factory import get_default_document_index
from danswer.natural_language_processing.search_nlp_models import clean_model_name from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.search.models import SavedSearchSettings from danswer.search.models import SavedSearchSettings
from danswer.search.models import SearchSettingsCreationRequest from danswer.search.models import SearchSettingsCreationRequest
from danswer.server.manage.embedding.models import SearchSettingsDeleteRequest
from danswer.server.manage.models import FullModelVersionResponse from danswer.server.manage.models import FullModelVersionResponse
from danswer.server.models import IdReturn from danswer.server.models import IdReturn
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
@ -97,6 +99,7 @@ def set_new_search_settings(
primary_index_name=search_settings.index_name, primary_index_name=search_settings.index_name,
secondary_index_name=new_search_settings.index_name, secondary_index_name=new_search_settings.index_name,
) )
document_index.ensure_indices_exist( document_index.ensure_indices_exist(
index_embedding_dim=search_settings.model_dim, index_embedding_dim=search_settings.model_dim,
secondary_index_embedding_dim=new_search_settings.model_dim, secondary_index_embedding_dim=new_search_settings.model_dim,
@ -132,6 +135,21 @@ def cancel_new_embedding(
) )
@router.delete("/delete-search-settings")
def delete_search_settings_endpoint(
deletion_request: SearchSettingsDeleteRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
try:
delete_search_settings(
db_session=db_session,
search_settings_id=deletion_request.search_settings_id,
)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@router.get("/get-current-search-settings") @router.get("/get-current-search-settings")
def get_current_search_settings_endpoint( def get_current_search_settings_endpoint(
_: User | None = Depends(current_user), _: User | None = Depends(current_user),

View File

@ -237,15 +237,18 @@ def get_local_reranking_model(
def embed_with_litellm_proxy( def embed_with_litellm_proxy(
texts: list[str], api_url: str, model: str texts: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[Embedding]: ) -> list[Embedding]:
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
with httpx.Client() as client: with httpx.Client() as client:
response = client.post( response = client.post(
api_url, api_url,
json={ json={
"model": model, "model": model_name,
"input": texts, "input": texts,
}, },
headers=headers,
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
@ -280,7 +283,12 @@ def embed_text(
logger.error("API URL not provided for LiteLLM proxy") logger.error("API URL not provided for LiteLLM proxy")
raise ValueError("API URL is required for LiteLLM proxy embedding.") raise ValueError("API URL is required for LiteLLM proxy embedding.")
try: try:
return embed_with_litellm_proxy(texts, api_url, model_name or "") return embed_with_litellm_proxy(
texts=texts,
api_url=api_url,
model_name=model_name or "",
api_key=api_key,
)
except Exception as e: except Exception as e:
logger.exception(f"Error during LiteLLM proxy embedding: {str(e)}") logger.exception(f"Error during LiteLLM proxy embedding: {str(e)}")
raise raise

View File

@ -58,6 +58,7 @@ LOG_LEVEL = os.environ.get("LOG_LEVEL", "notice")
# Fields which should only be set on new search setting # Fields which should only be set on new search setting
PRESERVED_SEARCH_FIELDS = [ PRESERVED_SEARCH_FIELDS = [
"id",
"provider_type", "provider_type",
"api_key", "api_key",
"model_name", "model_name",

View File

@ -91,7 +91,7 @@ export default function AddConnector({
>({ >({
name: "", name: "",
groups: [], groups: [],
is_public: false, is_public: true,
...configuration.values.reduce( ...configuration.values.reduce(
(acc, field) => { (acc, field) => {
if (field.type === "list") { if (field.type === "list") {

View File

@ -153,9 +153,66 @@ export function ChangeCredentialsModal({
title={`Modify your ${provider.provider_type} ${isProxy ? "URL" : "key"}`} title={`Modify your ${provider.provider_type} ${isProxy ? "URL" : "key"}`}
onOutsideClick={onCancel} onOutsideClick={onCancel}
> >
<>
{isProxy && (
<div className="mb-4"> <div className="mb-4">
<Subtitle className="font-bold text-lg"> <Subtitle className="font-bold text-lg">
Want to swap out your {isProxy ? "URL" : "key"}? Want to swap out your URL?
</Subtitle>
<a
href={provider.apiLink}
target="_blank"
rel="noopener noreferrer"
className="underline cursor-pointer mt-2 mb-4"
>
Visit API
</a>
<div className="flex flex-col mt-4 gap-y-2">
<input
className={`
border
border-border
rounded
w-full
py-2
px-3
bg-background-emphasis
`}
value={apiKeyOrUrl}
onChange={(e: any) => setApiKeyOrUrl(e.target.value)}
placeholder="Paste your API URL here"
/>
</div>
{testError && (
<Callout title="Error" color="red" className="mt-4">
{testError}
</Callout>
)}
<div className="flex mt-4 justify-between">
<Button
color="blue"
onClick={() => handleSubmit()}
disabled={!apiKeyOrUrl}
>
Swap URL
</Button>
</div>
{deletionError && (
<Callout title="Error" color="red" className="mt-4">
{deletionError}
</Callout>
)}
<Divider />
</div>
)}
<div className="mb-4">
<Subtitle className="font-bold text-lg">
Want to swap out your key?
</Subtitle> </Subtitle>
<a <a
href={provider.apiLink} href={provider.apiLink}
@ -193,7 +250,7 @@ export function ChangeCredentialsModal({
`} `}
value={apiKeyOrUrl} value={apiKeyOrUrl}
onChange={(e: any) => setApiKeyOrUrl(e.target.value)} onChange={(e: any) => setApiKeyOrUrl(e.target.value)}
placeholder={`Paste your ${isProxy ? "API URL" : "API key"} here`} placeholder="Paste your API key here"
/> />
</> </>
)} )}
@ -211,13 +268,13 @@ export function ChangeCredentialsModal({
onClick={() => handleSubmit()} onClick={() => handleSubmit()}
disabled={!apiKeyOrUrl} disabled={!apiKeyOrUrl}
> >
Swap {isProxy ? "URL" : "Key"} Swap Key
</Button> </Button>
</div> </div>
<Divider /> <Divider />
<Subtitle className="mt-4 font-bold text-lg mb-2"> <Subtitle className="mt-4 font-bold text-lg mb-2">
You can also delete your {isProxy ? "URL" : "key"}. You can also delete your configuration.
</Subtitle> </Subtitle>
<Text className="mb-2"> <Text className="mb-2">
This is only possible if you have already switched to a different This is only possible if you have already switched to a different
@ -225,7 +282,7 @@ export function ChangeCredentialsModal({
</Text> </Text>
<Button onClick={handleDelete} color="red"> <Button onClick={handleDelete} color="red">
Delete {isProxy ? "URL" : "key"} Delete Configuration
</Button> </Button>
{deletionError && ( {deletionError && (
<Callout title="Error" color="red" className="mt-4"> <Callout title="Error" color="red" className="mt-4">
@ -233,6 +290,7 @@ export function ChangeCredentialsModal({
</Callout> </Callout>
)} )}
</div> </div>
</>
</Modal> </Modal>
); );
} }

View File

@ -36,6 +36,7 @@ export function ProviderCreationModal({
? Object.entries(existingProvider.custom_config) ? Object.entries(existingProvider.custom_config)
: [], : [],
model_id: 0, model_id: 0,
model_name: null,
}; };
const validationSchema = Yup.object({ const validationSchema = Yup.object({
@ -45,6 +46,9 @@ export function ProviderCreationModal({
: useFileUpload : useFileUpload
? Yup.string() ? Yup.string()
: Yup.string().required("API Key is required"), : Yup.string().required("API Key is required"),
model_name: isProxy
? Yup.string().required("Model name is required")
: Yup.string().nullable(),
api_url: isProxy api_url: isProxy
? Yup.string().required("API URL is required") ? Yup.string().required("API URL is required")
: Yup.string(), : Yup.string(),
@ -96,6 +100,7 @@ export function ProviderCreationModal({
provider_type: values.provider_type.toLowerCase().split(" ")[0], provider_type: values.provider_type.toLowerCase().split(" ")[0],
api_key: values.api_key, api_key: values.api_key,
api_url: values.api_url, api_url: values.api_url,
model_name: values.model_name,
}), }),
} }
); );
@ -182,15 +187,25 @@ export function ProviderCreationModal({
</a> </a>
</Text> </Text>
<div className="flex w-full flex-col gap-y-2"> <div className="flex w-full flex-col gap-y-6">
{isProxy ? ( {isProxy && (
<>
<TextFormField <TextFormField
name="api_url" name="api_url"
label="API URL" label="API URL"
placeholder="API URL" placeholder="API URL"
type="text" type="text"
/> />
) : useFileUpload ? ( <TextFormField
name="model_name"
label="Model Name (for testing)"
placeholder="Model Name"
type="text"
/>
</>
)}
{useFileUpload ? (
<> <>
<Label>Upload JSON File</Label> <Label>Upload JSON File</Label>
<input <input
@ -205,7 +220,7 @@ export function ProviderCreationModal({
) : ( ) : (
<TextFormField <TextFormField
name="api_key" name="api_key"
label="API Key" label={`API Key ${isProxy && "(for non-local deployments)"}`}
placeholder="API Key" placeholder="API Key"
type="password" type="password"
/> />

View File

@ -12,10 +12,13 @@ import {
LITELLM_CLOUD_PROVIDER, LITELLM_CLOUD_PROVIDER,
} from "../../../../components/embedding/interfaces"; } from "../../../../components/embedding/interfaces";
import { EmbeddingDetails } from "../EmbeddingModelSelectionForm"; import { EmbeddingDetails } from "../EmbeddingModelSelectionForm";
import { FiExternalLink, FiInfo } from "react-icons/fi"; import { FiExternalLink, FiInfo, FiTrash } from "react-icons/fi";
import { HoverPopup } from "@/components/HoverPopup"; import { HoverPopup } from "@/components/HoverPopup";
import { Dispatch, SetStateAction, useEffect, useState } from "react"; import { Dispatch, SetStateAction, useEffect, useState } from "react";
import { LiteLLMModelForm } from "@/components/embedding/LiteLLMModelForm"; import { LiteLLMModelForm } from "@/components/embedding/LiteLLMModelForm";
import { deleteSearchSettings } from "./utils";
import { usePopup } from "@/components/admin/connectors/Popup";
import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal";
export default function CloudEmbeddingPage({ export default function CloudEmbeddingPage({
currentModel, currentModel,
@ -181,7 +184,7 @@ export default function CloudEmbeddingPage({
onClick={() => setShowTentativeProvider(LITELLM_CLOUD_PROVIDER)} onClick={() => setShowTentativeProvider(LITELLM_CLOUD_PROVIDER)}
className="mb-2 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 text-sm cursor-pointer" className="mb-2 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 text-sm cursor-pointer"
> >
Provide API URL Set API Configuration
</button> </button>
) : ( ) : (
<button <button
@ -190,7 +193,7 @@ export default function CloudEmbeddingPage({
} }
className="mb-2 hover:underline text-sm cursor-pointer" className="mb-2 hover:underline text-sm cursor-pointer"
> >
Modify API URL Modify API Configuration
</button> </button>
)} )}
@ -283,11 +286,33 @@ export function CloudModelCard({
React.SetStateAction<CloudEmbeddingProvider | null> React.SetStateAction<CloudEmbeddingProvider | null>
>; >;
}) { }) {
const { popup, setPopup } = usePopup();
const [showDeleteModel, setShowDeleteModel] = useState(false);
const enabled = const enabled =
model.model_name === currentModel.model_name && model.model_name === currentModel.model_name &&
model.provider_type?.toLowerCase() == model.provider_type?.toLowerCase() ==
currentModel.provider_type?.toLowerCase(); currentModel.provider_type?.toLowerCase();
const deleteModel = async () => {
if (!model.id) {
setPopup({ message: "Model cannot be deleted", type: "error" });
return;
}
const response = await deleteSearchSettings(model.id);
if (response.ok) {
setPopup({ message: "Model deleted successfully", type: "success" });
setShowDeleteModel(false);
} else {
setPopup({
message:
"Failed to delete model. Ensure you are not attempting to delete a curently active model.",
type: "error",
});
}
};
return ( return (
<div <div
className={`p-4 w-96 border rounded-lg transition-all duration-200 ${ className={`p-4 w-96 border rounded-lg transition-all duration-200 ${
@ -296,8 +321,28 @@ export function CloudModelCard({
: "border-gray-300 hover:border-blue-300 hover:shadow-sm" : "border-gray-300 hover:border-blue-300 hover:shadow-sm"
} ${!provider.configured && "opacity-80 hover:opacity-100"}`} } ${!provider.configured && "opacity-80 hover:opacity-100"}`}
> >
{popup}
{showDeleteModel && (
<DeleteEntityModal
entityName={model.model_name}
entityType="embedding model configuration"
onSubmit={() => deleteModel()}
onClose={() => setShowDeleteModel(false)}
/>
)}
<div className="flex items-center justify-between mb-3"> <div className="flex items-center justify-between mb-3">
<h3 className="font-bold text-lg">{model.model_name}</h3> <h3 className="font-bold text-lg">{model.model_name}</h3>
<div className="flex gap-x-2">
{model.provider_type == EmbeddingProvider.LITELLM.toLowerCase() && (
<button
onClickCapture={() => setShowDeleteModel(true)}
onClick={(e) => e.stopPropagation()}
className="text-blue-500 hover:text-blue-700 transition-colors duration-200"
>
<FiTrash size={18} />
</button>
)}
<a <a
href={provider.website} href={provider.website}
target="_blank" target="_blank"
@ -308,6 +353,7 @@ export function CloudModelCard({
<FiExternalLink size={18} /> <FiExternalLink size={18} />
</a> </a>
</div> </div>
</div>
<p className="text-sm text-gray-600 mb-2">{model.description}</p> <p className="text-sm text-gray-600 mb-2">{model.description}</p>
{model?.provider_type?.toLowerCase() != {model?.provider_type?.toLowerCase() !=
EmbeddingProvider.LITELLM.toLowerCase() && ( EmbeddingProvider.LITELLM.toLowerCase() && (

View File

@ -0,0 +1,10 @@
export const deleteSearchSettings = async (search_settings_id: number) => {
const response = await fetch(`/api/search-settings/delete-search-settings`, {
method: "DELETE",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ search_settings_id }),
});
return response;
};

View File

@ -39,6 +39,7 @@ export interface CloudEmbeddingProvider {
// Embedding Models // Embedding Models
export interface EmbeddingModelDescriptor { export interface EmbeddingModelDescriptor {
id?: number;
model_name: string; model_name: string;
model_dim: number; model_dim: number;
normalize: boolean; normalize: boolean;