From 9abde19e44e3132b64daa3f11e3ad06b034341ff Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 19 Sep 2024 17:41:57 -0700 Subject: [PATCH] update interfaces to standardize --- backend/danswer/db/llm.py | 19 ++++++- backend/danswer/server/manage/llm/api.py | 57 +++++++++++++++++-- backend/danswer/server/manage/llm/models.py | 1 + .../llm/LLMProviderUpdateForm.tsx | 35 +++++------- .../app/admin/configuration/llm/interfaces.ts | 2 +- web/src/lib/llm/utils.ts | 2 + 6 files changed, 88 insertions(+), 28 deletions(-) diff --git a/backend/danswer/db/llm.py b/backend/danswer/db/llm.py index 39478fe29..072aa32b8 100644 --- a/backend/danswer/db/llm.py +++ b/backend/danswer/db/llm.py @@ -71,8 +71,6 @@ def update_llm_provider( LLMProviderModel.name == llm_provider_update.name ) ) - # if llm_provider_update.api_key_set: - if not existing_llm_provider: raise ValueError( f"LLM Provider with name {llm_provider_update.name} does not exist" @@ -94,6 +92,23 @@ def create_llm_provider( ) +def get_llm_provider( + llm_provider_name: str, db_session: Session, user: User | None = None +) -> FullLLMProviderSnapshot: + if not user or not user.is_admin: + raise ValueError("User does not have access to this LLM Provider") + + return FullLLMProviderSnapshot.from_full_llm_provider( + FullLLMProvider.from_model( + db_session.scalar( + select(LLMProviderModel).where( + LLMProviderModel.name == llm_provider_name + ) + ) + ) + ) + + def upsert_llm_provider( llm_provider: LLMProviderUpsertRequest, db_session: Session, diff --git a/backend/danswer/server/manage/llm/api.py b/backend/danswer/server/manage/llm/api.py index 341f4c2fa..5588a1440 100644 --- a/backend/danswer/server/manage/llm/api.py +++ b/backend/danswer/server/manage/llm/api.py @@ -9,9 +9,12 @@ from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user from danswer.auth.users import current_user from danswer.db.engine import get_session +from danswer.db.llm import create_llm_provider from danswer.db.llm import fetch_existing_llm_providers +from danswer.db.llm import get_llm_provider from danswer.db.llm import remove_llm_provider from danswer.db.llm import update_default_provider +from danswer.db.llm import update_llm_provider from danswer.db.llm import upsert_llm_provider from danswer.db.models import User from danswer.llm.factory import get_default_llms @@ -19,8 +22,8 @@ from danswer.llm.factory import get_llm from danswer.llm.llm_provider_options import fetch_available_well_known_llms from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor from danswer.llm.utils import test_llm -from danswer.server.manage.llm.models import FullLLMProvider from danswer.server.manage.llm.models import FullLLMProviderSnapshot +from danswer.server.manage.llm.models import LLMProviderCreationRequest from danswer.server.manage.llm.models import LLMProviderDescriptor from danswer.server.manage.llm.models import LLMProviderUpsertRequest from danswer.server.manage.llm.models import TestLLMRequest @@ -38,15 +41,21 @@ basic_router = APIRouter(prefix="/llm") def fetch_llm_options( _: User | None = Depends(current_admin_user), ) -> list[WellKnownLLMProviderDescriptor]: - print("FETCHING") return fetch_available_well_known_llms() @admin_router.post("/test") def test_llm_configuration( test_llm_request: TestLLMRequest, + db_session: Session = Depends(get_session), _: User | None = Depends(current_admin_user), ) -> None: + if test_llm_request.existing_api_key and not test_llm_request.api_key: + llm_provider = get_llm_provider( + test_llm_request.provider.name, db_session=db_session + ) + test_llm_request.api_key = llm_provider.api_key + llm = get_llm( provider=test_llm_request.provider, model=test_llm_request.default_model_name, @@ -111,13 +120,47 @@ def test_default_provider( def list_llm_providers( _: User | None = Depends(current_admin_user), db_session: Session = Depends(get_session), -) -> list[FullLLMProvider]: +) -> list[FullLLMProviderSnapshot]: + print( + [ + FullLLMProviderSnapshot.from_model(llm_provider_model) + for llm_provider_model in fetch_existing_llm_providers(db_session) + ] + ) return [ - FullLLMProvider.from_model(llm_provider_model) + FullLLMProviderSnapshot.from_model(llm_provider_model) for llm_provider_model in fetch_existing_llm_providers(db_session) ] +@admin_router.patch("/provider/{provider_id}") +def patch_existing_llm_provider( + llm_provider: LLMProviderUpsertRequest, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> FullLLMProviderSnapshot: + return FullLLMProviderSnapshot.from_full_llm_provider( + update_llm_provider(llm_provider=llm_provider, db_session=db_session) + ) + + +@admin_router.post("/provider") +def create_new_llm_provider( + llm_provider: LLMProviderCreationRequest, + is_creation: bool = Query( + True, + description="True if updating an existing provider, False if creating a new one", + ), + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> FullLLMProviderSnapshot: + return FullLLMProviderSnapshot.from_full_llm_provider( + create_llm_provider( + llm_provider=llm_provider, db_session=db_session, is_creation=is_creation + ) + ) + + @admin_router.put("/provider") def put_llm_provider( llm_provider: LLMProviderUpsertRequest, @@ -129,13 +172,17 @@ def put_llm_provider( db_session: Session = Depends(get_session), ) -> FullLLMProviderSnapshot: try: - return FullLLMProviderSnapshot.from_full_llm_provider( + print("hitting htis function") + + value = FullLLMProviderSnapshot.from_full_llm_provider( upsert_llm_provider( llm_provider=llm_provider, db_session=db_session, is_creation=is_creation, ) ) + print(value) + return value except ValueError as e: logger.exception("Failed to upsert LLM Provider") raise HTTPException(status_code=400, detail=str(e)) diff --git a/backend/danswer/server/manage/llm/models.py b/backend/danswer/server/manage/llm/models.py index 52cbb3c7a..916f8ebbc 100644 --- a/backend/danswer/server/manage/llm/models.py +++ b/backend/danswer/server/manage/llm/models.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: class TestLLMRequest(BaseModel): # provider level + exisitng_api_key: bool = False provider: str api_key: str | None = None api_base: str | None = None diff --git a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx index f461ffbe8..6785961a3 100644 --- a/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx +++ b/web/src/app/admin/configuration/llm/LLMProviderUpdateForm.tsx @@ -7,24 +7,17 @@ import { LLM_PROVIDERS_ADMIN_URL } from "./constants"; import { SelectorFormField, TextFormField, - BooleanFormField, MultiSelectField, } from "@/components/admin/connectors/Field"; import { useState } from "react"; -import { Bubble } from "@/components/Bubble"; -import { GroupsIcon } from "@/components/icons/icons"; import { useSWRConfig } from "swr"; -import { - defaultModelsByProvider, - getDisplayNameForModel, - useUserGroups, -} from "@/lib/hooks"; +import { defaultModelsByProvider, getDisplayNameForModel } from "@/lib/hooks"; import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces"; import { PopupSpec } from "@/components/admin/connectors/Popup"; -import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled"; import * as Yup from "yup"; import isEqual from "lodash/isEqual"; import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector"; +import { defaultPasswordMask } from "@/lib/llm/utils"; export function LLMProviderUpdateForm({ llmProviderDescriptor, @@ -33,6 +26,7 @@ export function LLMProviderUpdateForm({ shouldMarkAsDefault, setPopup, hideAdvanced, + llmProviderFlow, }: { llmProviderDescriptor: WellKnownLLMProviderDescriptor; onClose: () => void; @@ -40,14 +34,10 @@ export function LLMProviderUpdateForm({ shouldMarkAsDefault?: boolean; hideAdvanced?: boolean; setPopup?: (popup: PopupSpec) => void; + llmProviderFlow: "create" | "update"; }) { const { mutate } = useSWRConfig(); - const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled(); - - // EE only - const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups(); - const [isTesting, setIsTesting] = useState(false); const [testError, setTestError] = useState(""); @@ -56,7 +46,7 @@ export function LLMProviderUpdateForm({ // Define the initial values based on the provider's requirements const initialValues = { name: existingLlmProvider?.name || (hideAdvanced ? "Default" : ""), - api_key: existingLlmProvider?.api_key ?? "", + api_key: null, api_base: existingLlmProvider?.api_base ?? "", api_version: existingLlmProvider?.api_version ?? "", default_model_name: @@ -86,9 +76,10 @@ export function LLMProviderUpdateForm({ // Setup validation schema if required const validationSchema = Yup.object({ name: Yup.string().required("Display Name is required"), - api_key: llmProviderDescriptor.api_key_required - ? Yup.string().required("API Key is required") - : Yup.string(), + api_key: + llmProviderDescriptor.api_key_required && llmProviderFlow == "create" + ? Yup.string().required("API Key is required") + : Yup.string().nullable(), api_base: llmProviderDescriptor.api_base_required ? Yup.string().required("API Base is required") : Yup.string(), @@ -120,6 +111,10 @@ export function LLMProviderUpdateForm({ display_model_names: Yup.array().of(Yup.string()), }); + const apiKeyDefault = existingLlmProvider?.api_key_set + ? defaultPasswordMask + : "API key"; + return ( )} diff --git a/web/src/app/admin/configuration/llm/interfaces.ts b/web/src/app/admin/configuration/llm/interfaces.ts index 33fa94d7f..71b98a3b9 100644 --- a/web/src/app/admin/configuration/llm/interfaces.ts +++ b/web/src/app/admin/configuration/llm/interfaces.ts @@ -34,7 +34,7 @@ export interface WellKnownLLMProviderDescriptor { export interface LLMProvider { name: string; provider: string; - api_key: string | null; + api_key_set: boolean; api_base: string | null; api_version: string | null; custom_config: { [key: string]: string } | null; diff --git a/web/src/lib/llm/utils.ts b/web/src/lib/llm/utils.ts index 92e75cf46..bf1d72bb5 100644 --- a/web/src/lib/llm/utils.ts +++ b/web/src/lib/llm/utils.ts @@ -102,3 +102,5 @@ export const destructureValue = (value: string): LlmOverride => { modelName, }; }; + +export const defaultPasswordMask = "**************************";