update interfaces to standardize

This commit is contained in:
pablodanswer
2024-09-19 17:41:57 -07:00
parent 3cc99cf79a
commit 9abde19e44
6 changed files with 88 additions and 28 deletions

View File

@@ -71,8 +71,6 @@ def update_llm_provider(
LLMProviderModel.name == llm_provider_update.name LLMProviderModel.name == llm_provider_update.name
) )
) )
# if llm_provider_update.api_key_set:
if not existing_llm_provider: if not existing_llm_provider:
raise ValueError( raise ValueError(
f"LLM Provider with name {llm_provider_update.name} does not exist" f"LLM Provider with name {llm_provider_update.name} does not exist"
@@ -94,6 +92,23 @@ def create_llm_provider(
) )
def get_llm_provider(
llm_provider_name: str, db_session: Session, user: User | None = None
) -> FullLLMProviderSnapshot:
if not user or not user.is_admin:
raise ValueError("User does not have access to this LLM Provider")
return FullLLMProviderSnapshot.from_full_llm_provider(
FullLLMProvider.from_model(
db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.name == llm_provider_name
)
)
)
)
def upsert_llm_provider( def upsert_llm_provider(
llm_provider: LLMProviderUpsertRequest, llm_provider: LLMProviderUpsertRequest,
db_session: Session, db_session: Session,

View File

@@ -9,9 +9,12 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user from danswer.auth.users import current_user
from danswer.db.engine import get_session from danswer.db.engine import get_session
from danswer.db.llm import create_llm_provider
from danswer.db.llm import fetch_existing_llm_providers from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.llm import get_llm_provider
from danswer.db.llm import remove_llm_provider from danswer.db.llm import remove_llm_provider
from danswer.db.llm import update_default_provider from danswer.db.llm import update_default_provider
from danswer.db.llm import update_llm_provider
from danswer.db.llm import upsert_llm_provider from danswer.db.llm import upsert_llm_provider
from danswer.db.models import User from danswer.db.models import User
from danswer.llm.factory import get_default_llms from danswer.llm.factory import get_default_llms
@@ -19,8 +22,8 @@ from danswer.llm.factory import get_llm
from danswer.llm.llm_provider_options import fetch_available_well_known_llms from danswer.llm.llm_provider_options import fetch_available_well_known_llms
from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from danswer.llm.utils import test_llm from danswer.llm.utils import test_llm
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import FullLLMProviderSnapshot from danswer.server.manage.llm.models import FullLLMProviderSnapshot
from danswer.server.manage.llm.models import LLMProviderCreationRequest
from danswer.server.manage.llm.models import LLMProviderDescriptor from danswer.server.manage.llm.models import LLMProviderDescriptor
from danswer.server.manage.llm.models import LLMProviderUpsertRequest from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.server.manage.llm.models import TestLLMRequest from danswer.server.manage.llm.models import TestLLMRequest
@@ -38,15 +41,21 @@ basic_router = APIRouter(prefix="/llm")
def fetch_llm_options( def fetch_llm_options(
_: User | None = Depends(current_admin_user), _: User | None = Depends(current_admin_user),
) -> list[WellKnownLLMProviderDescriptor]: ) -> list[WellKnownLLMProviderDescriptor]:
print("FETCHING")
return fetch_available_well_known_llms() return fetch_available_well_known_llms()
@admin_router.post("/test") @admin_router.post("/test")
def test_llm_configuration( def test_llm_configuration(
test_llm_request: TestLLMRequest, test_llm_request: TestLLMRequest,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_admin_user), _: User | None = Depends(current_admin_user),
) -> None: ) -> None:
if test_llm_request.existing_api_key and not test_llm_request.api_key:
llm_provider = get_llm_provider(
test_llm_request.provider.name, db_session=db_session
)
test_llm_request.api_key = llm_provider.api_key
llm = get_llm( llm = get_llm(
provider=test_llm_request.provider, provider=test_llm_request.provider,
model=test_llm_request.default_model_name, model=test_llm_request.default_model_name,
@@ -111,13 +120,47 @@ def test_default_provider(
def list_llm_providers( def list_llm_providers(
_: User | None = Depends(current_admin_user), _: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session), db_session: Session = Depends(get_session),
) -> list[FullLLMProvider]: ) -> list[FullLLMProviderSnapshot]:
print(
[
FullLLMProviderSnapshot.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session)
]
)
return [ return [
FullLLMProvider.from_model(llm_provider_model) FullLLMProviderSnapshot.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session) for llm_provider_model in fetch_existing_llm_providers(db_session)
] ]
@admin_router.patch("/provider/{provider_id}")
def patch_existing_llm_provider(
llm_provider: LLMProviderUpsertRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProviderSnapshot:
return FullLLMProviderSnapshot.from_full_llm_provider(
update_llm_provider(llm_provider=llm_provider, db_session=db_session)
)
@admin_router.post("/provider")
def create_new_llm_provider(
llm_provider: LLMProviderCreationRequest,
is_creation: bool = Query(
True,
description="True if updating an existing provider, False if creating a new one",
),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProviderSnapshot:
return FullLLMProviderSnapshot.from_full_llm_provider(
create_llm_provider(
llm_provider=llm_provider, db_session=db_session, is_creation=is_creation
)
)
@admin_router.put("/provider") @admin_router.put("/provider")
def put_llm_provider( def put_llm_provider(
llm_provider: LLMProviderUpsertRequest, llm_provider: LLMProviderUpsertRequest,
@@ -129,13 +172,17 @@ def put_llm_provider(
db_session: Session = Depends(get_session), db_session: Session = Depends(get_session),
) -> FullLLMProviderSnapshot: ) -> FullLLMProviderSnapshot:
try: try:
return FullLLMProviderSnapshot.from_full_llm_provider( print("hitting htis function")
value = FullLLMProviderSnapshot.from_full_llm_provider(
upsert_llm_provider( upsert_llm_provider(
llm_provider=llm_provider, llm_provider=llm_provider,
db_session=db_session, db_session=db_session,
is_creation=is_creation, is_creation=is_creation,
) )
) )
print(value)
return value
except ValueError as e: except ValueError as e:
logger.exception("Failed to upsert LLM Provider") logger.exception("Failed to upsert LLM Provider")
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))

View File

@@ -11,6 +11,7 @@ if TYPE_CHECKING:
class TestLLMRequest(BaseModel): class TestLLMRequest(BaseModel):
# provider level # provider level
exisitng_api_key: bool = False
provider: str provider: str
api_key: str | None = None api_key: str | None = None
api_base: str | None = None api_base: str | None = None

View File

@@ -7,24 +7,17 @@ import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import { import {
SelectorFormField, SelectorFormField,
TextFormField, TextFormField,
BooleanFormField,
MultiSelectField, MultiSelectField,
} from "@/components/admin/connectors/Field"; } from "@/components/admin/connectors/Field";
import { useState } from "react"; import { useState } from "react";
import { Bubble } from "@/components/Bubble";
import { GroupsIcon } from "@/components/icons/icons";
import { useSWRConfig } from "swr"; import { useSWRConfig } from "swr";
import { import { defaultModelsByProvider, getDisplayNameForModel } from "@/lib/hooks";
defaultModelsByProvider,
getDisplayNameForModel,
useUserGroups,
} from "@/lib/hooks";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces"; import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup"; import { PopupSpec } from "@/components/admin/connectors/Popup";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import * as Yup from "yup"; import * as Yup from "yup";
import isEqual from "lodash/isEqual"; import isEqual from "lodash/isEqual";
import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector"; import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
import { defaultPasswordMask } from "@/lib/llm/utils";
export function LLMProviderUpdateForm({ export function LLMProviderUpdateForm({
llmProviderDescriptor, llmProviderDescriptor,
@@ -33,6 +26,7 @@ export function LLMProviderUpdateForm({
shouldMarkAsDefault, shouldMarkAsDefault,
setPopup, setPopup,
hideAdvanced, hideAdvanced,
llmProviderFlow,
}: { }: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor; llmProviderDescriptor: WellKnownLLMProviderDescriptor;
onClose: () => void; onClose: () => void;
@@ -40,14 +34,10 @@ export function LLMProviderUpdateForm({
shouldMarkAsDefault?: boolean; shouldMarkAsDefault?: boolean;
hideAdvanced?: boolean; hideAdvanced?: boolean;
setPopup?: (popup: PopupSpec) => void; setPopup?: (popup: PopupSpec) => void;
llmProviderFlow: "create" | "update";
}) { }) {
const { mutate } = useSWRConfig(); const { mutate } = useSWRConfig();
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
// EE only
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const [isTesting, setIsTesting] = useState(false); const [isTesting, setIsTesting] = useState(false);
const [testError, setTestError] = useState<string>(""); const [testError, setTestError] = useState<string>("");
@@ -56,7 +46,7 @@ export function LLMProviderUpdateForm({
// Define the initial values based on the provider's requirements // Define the initial values based on the provider's requirements
const initialValues = { const initialValues = {
name: existingLlmProvider?.name || (hideAdvanced ? "Default" : ""), name: existingLlmProvider?.name || (hideAdvanced ? "Default" : ""),
api_key: existingLlmProvider?.api_key ?? "", api_key: null,
api_base: existingLlmProvider?.api_base ?? "", api_base: existingLlmProvider?.api_base ?? "",
api_version: existingLlmProvider?.api_version ?? "", api_version: existingLlmProvider?.api_version ?? "",
default_model_name: default_model_name:
@@ -86,9 +76,10 @@ export function LLMProviderUpdateForm({
// Setup validation schema if required // Setup validation schema if required
const validationSchema = Yup.object({ const validationSchema = Yup.object({
name: Yup.string().required("Display Name is required"), name: Yup.string().required("Display Name is required"),
api_key: llmProviderDescriptor.api_key_required api_key:
? Yup.string().required("API Key is required") llmProviderDescriptor.api_key_required && llmProviderFlow == "create"
: Yup.string(), ? Yup.string().required("API Key is required")
: Yup.string().nullable(),
api_base: llmProviderDescriptor.api_base_required api_base: llmProviderDescriptor.api_base_required
? Yup.string().required("API Base is required") ? Yup.string().required("API Base is required")
: Yup.string(), : Yup.string(),
@@ -120,6 +111,10 @@ export function LLMProviderUpdateForm({
display_model_names: Yup.array().of(Yup.string()), display_model_names: Yup.array().of(Yup.string()),
}); });
const apiKeyDefault = existingLlmProvider?.api_key_set
? defaultPasswordMask
: "API key";
return ( return (
<Formik <Formik
initialValues={initialValues} initialValues={initialValues}
@@ -151,7 +146,7 @@ export function LLMProviderUpdateForm({
} }
const response = await fetch(LLM_PROVIDERS_ADMIN_URL, { const response = await fetch(LLM_PROVIDERS_ADMIN_URL, {
method: "PUT", method: llmProviderFlow == "create" ? "POST" : "PUT",
headers: { headers: {
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
@@ -237,8 +232,8 @@ export function LLMProviderUpdateForm({
small={hideAdvanced} small={hideAdvanced}
name="api_key" name="api_key"
label="API Key" label="API Key"
placeholder="API Key"
type="password" type="password"
placeholder={formikProps.values.api_key ?? apiKeyDefault}
/> />
)} )}

View File

@@ -34,7 +34,7 @@ export interface WellKnownLLMProviderDescriptor {
export interface LLMProvider { export interface LLMProvider {
name: string; name: string;
provider: string; provider: string;
api_key: string | null; api_key_set: boolean;
api_base: string | null; api_base: string | null;
api_version: string | null; api_version: string | null;
custom_config: { [key: string]: string } | null; custom_config: { [key: string]: string } | null;

View File

@@ -102,3 +102,5 @@ export const destructureValue = (value: string): LlmOverride => {
modelName, modelName,
}; };
}; };
export const defaultPasswordMask = "**************************";