update interfaces to standardize

This commit is contained in:
pablodanswer 2024-09-19 17:41:57 -07:00
parent 3cc99cf79a
commit 9abde19e44
6 changed files with 88 additions and 28 deletions

View File

@ -71,8 +71,6 @@ def update_llm_provider(
LLMProviderModel.name == llm_provider_update.name
)
)
# if llm_provider_update.api_key_set:
if not existing_llm_provider:
raise ValueError(
f"LLM Provider with name {llm_provider_update.name} does not exist"
@ -94,6 +92,23 @@ def create_llm_provider(
)
def get_llm_provider(
llm_provider_name: str, db_session: Session, user: User | None = None
) -> FullLLMProviderSnapshot:
if not user or not user.is_admin:
raise ValueError("User does not have access to this LLM Provider")
return FullLLMProviderSnapshot.from_full_llm_provider(
FullLLMProvider.from_model(
db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.name == llm_provider_name
)
)
)
)
def upsert_llm_provider(
llm_provider: LLMProviderUpsertRequest,
db_session: Session,

View File

@ -9,9 +9,12 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.llm import create_llm_provider
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.llm import get_llm_provider
from danswer.db.llm import remove_llm_provider
from danswer.db.llm import update_default_provider
from danswer.db.llm import update_llm_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.models import User
from danswer.llm.factory import get_default_llms
@ -19,8 +22,8 @@ from danswer.llm.factory import get_llm
from danswer.llm.llm_provider_options import fetch_available_well_known_llms
from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from danswer.llm.utils import test_llm
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import FullLLMProviderSnapshot
from danswer.server.manage.llm.models import LLMProviderCreationRequest
from danswer.server.manage.llm.models import LLMProviderDescriptor
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.server.manage.llm.models import TestLLMRequest
@ -38,15 +41,21 @@ basic_router = APIRouter(prefix="/llm")
def fetch_llm_options(
_: User | None = Depends(current_admin_user),
) -> list[WellKnownLLMProviderDescriptor]:
print("FETCHING")
return fetch_available_well_known_llms()
@admin_router.post("/test")
def test_llm_configuration(
test_llm_request: TestLLMRequest,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
) -> None:
if test_llm_request.existing_api_key and not test_llm_request.api_key:
llm_provider = get_llm_provider(
test_llm_request.provider.name, db_session=db_session
)
test_llm_request.api_key = llm_provider.api_key
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.default_model_name,
@ -111,13 +120,47 @@ def test_default_provider(
def list_llm_providers(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[FullLLMProvider]:
) -> list[FullLLMProviderSnapshot]:
print(
[
FullLLMProviderSnapshot.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session)
]
)
return [
FullLLMProvider.from_model(llm_provider_model)
FullLLMProviderSnapshot.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session)
]
@admin_router.patch("/provider/{provider_id}")
def patch_existing_llm_provider(
llm_provider: LLMProviderUpsertRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProviderSnapshot:
return FullLLMProviderSnapshot.from_full_llm_provider(
update_llm_provider(llm_provider=llm_provider, db_session=db_session)
)
@admin_router.post("/provider")
def create_new_llm_provider(
llm_provider: LLMProviderCreationRequest,
is_creation: bool = Query(
True,
description="True if updating an existing provider, False if creating a new one",
),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProviderSnapshot:
return FullLLMProviderSnapshot.from_full_llm_provider(
create_llm_provider(
llm_provider=llm_provider, db_session=db_session, is_creation=is_creation
)
)
@admin_router.put("/provider")
def put_llm_provider(
llm_provider: LLMProviderUpsertRequest,
@ -129,13 +172,17 @@ def put_llm_provider(
db_session: Session = Depends(get_session),
) -> FullLLMProviderSnapshot:
try:
return FullLLMProviderSnapshot.from_full_llm_provider(
print("hitting htis function")
value = FullLLMProviderSnapshot.from_full_llm_provider(
upsert_llm_provider(
llm_provider=llm_provider,
db_session=db_session,
is_creation=is_creation,
)
)
print(value)
return value
except ValueError as e:
logger.exception("Failed to upsert LLM Provider")
raise HTTPException(status_code=400, detail=str(e))

View File

@ -11,6 +11,7 @@ if TYPE_CHECKING:
class TestLLMRequest(BaseModel):
# provider level
exisitng_api_key: bool = False
provider: str
api_key: str | None = None
api_base: str | None = None

View File

@ -7,24 +7,17 @@ import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import {
SelectorFormField,
TextFormField,
BooleanFormField,
MultiSelectField,
} from "@/components/admin/connectors/Field";
import { useState } from "react";
import { Bubble } from "@/components/Bubble";
import { GroupsIcon } from "@/components/icons/icons";
import { useSWRConfig } from "swr";
import {
defaultModelsByProvider,
getDisplayNameForModel,
useUserGroups,
} from "@/lib/hooks";
import { defaultModelsByProvider, getDisplayNameForModel } from "@/lib/hooks";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";
import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
import { defaultPasswordMask } from "@/lib/llm/utils";
export function LLMProviderUpdateForm({
llmProviderDescriptor,
@ -33,6 +26,7 @@ export function LLMProviderUpdateForm({
shouldMarkAsDefault,
setPopup,
hideAdvanced,
llmProviderFlow,
}: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor;
onClose: () => void;
@ -40,14 +34,10 @@ export function LLMProviderUpdateForm({
shouldMarkAsDefault?: boolean;
hideAdvanced?: boolean;
setPopup?: (popup: PopupSpec) => void;
llmProviderFlow: "create" | "update";
}) {
const { mutate } = useSWRConfig();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
// EE only
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const [isTesting, setIsTesting] = useState(false);
const [testError, setTestError] = useState<string>("");
@ -56,7 +46,7 @@ export function LLMProviderUpdateForm({
// Define the initial values based on the provider's requirements
const initialValues = {
name: existingLlmProvider?.name || (hideAdvanced ? "Default" : ""),
api_key: existingLlmProvider?.api_key ?? "",
api_key: null,
api_base: existingLlmProvider?.api_base ?? "",
api_version: existingLlmProvider?.api_version ?? "",
default_model_name:
@ -86,9 +76,10 @@ export function LLMProviderUpdateForm({
// Setup validation schema if required
const validationSchema = Yup.object({
name: Yup.string().required("Display Name is required"),
api_key: llmProviderDescriptor.api_key_required
? Yup.string().required("API Key is required")
: Yup.string(),
api_key:
llmProviderDescriptor.api_key_required && llmProviderFlow == "create"
? Yup.string().required("API Key is required")
: Yup.string().nullable(),
api_base: llmProviderDescriptor.api_base_required
? Yup.string().required("API Base is required")
: Yup.string(),
@ -120,6 +111,10 @@ export function LLMProviderUpdateForm({
display_model_names: Yup.array().of(Yup.string()),
});
const apiKeyDefault = existingLlmProvider?.api_key_set
? defaultPasswordMask
: "API key";
return (
<Formik
initialValues={initialValues}
@ -151,7 +146,7 @@ export function LLMProviderUpdateForm({
}
const response = await fetch(LLM_PROVIDERS_ADMIN_URL, {
method: "PUT",
method: llmProviderFlow == "create" ? "POST" : "PUT",
headers: {
"Content-Type": "application/json",
},
@ -237,8 +232,8 @@ export function LLMProviderUpdateForm({
small={hideAdvanced}
name="api_key"
label="API Key"
placeholder="API Key"
type="password"
placeholder={formikProps.values.api_key ?? apiKeyDefault}
/>
)}

View File

@ -34,7 +34,7 @@ export interface WellKnownLLMProviderDescriptor {
export interface LLMProvider {
name: string;
provider: string;
api_key: string | null;
api_key_set: boolean;
api_base: string | null;
api_version: string | null;
custom_config: { [key: string]: string } | null;

View File

@ -102,3 +102,5 @@ export const destructureValue = (value: string): LlmOverride => {
modelName,
};
};
export const defaultPasswordMask = "**************************";