mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-18 11:34:12 +02:00
update interfaces to standardize
This commit is contained in:
@@ -71,8 +71,6 @@ def update_llm_provider(
|
|||||||
LLMProviderModel.name == llm_provider_update.name
|
LLMProviderModel.name == llm_provider_update.name
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# if llm_provider_update.api_key_set:
|
|
||||||
|
|
||||||
if not existing_llm_provider:
|
if not existing_llm_provider:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"LLM Provider with name {llm_provider_update.name} does not exist"
|
f"LLM Provider with name {llm_provider_update.name} does not exist"
|
||||||
@@ -94,6 +92,23 @@ def create_llm_provider(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm_provider(
|
||||||
|
llm_provider_name: str, db_session: Session, user: User | None = None
|
||||||
|
) -> FullLLMProviderSnapshot:
|
||||||
|
if not user or not user.is_admin:
|
||||||
|
raise ValueError("User does not have access to this LLM Provider")
|
||||||
|
|
||||||
|
return FullLLMProviderSnapshot.from_full_llm_provider(
|
||||||
|
FullLLMProvider.from_model(
|
||||||
|
db_session.scalar(
|
||||||
|
select(LLMProviderModel).where(
|
||||||
|
LLMProviderModel.name == llm_provider_name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def upsert_llm_provider(
|
def upsert_llm_provider(
|
||||||
llm_provider: LLMProviderUpsertRequest,
|
llm_provider: LLMProviderUpsertRequest,
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
|
@@ -9,9 +9,12 @@ from sqlalchemy.orm import Session
|
|||||||
from danswer.auth.users import current_admin_user
|
from danswer.auth.users import current_admin_user
|
||||||
from danswer.auth.users import current_user
|
from danswer.auth.users import current_user
|
||||||
from danswer.db.engine import get_session
|
from danswer.db.engine import get_session
|
||||||
|
from danswer.db.llm import create_llm_provider
|
||||||
from danswer.db.llm import fetch_existing_llm_providers
|
from danswer.db.llm import fetch_existing_llm_providers
|
||||||
|
from danswer.db.llm import get_llm_provider
|
||||||
from danswer.db.llm import remove_llm_provider
|
from danswer.db.llm import remove_llm_provider
|
||||||
from danswer.db.llm import update_default_provider
|
from danswer.db.llm import update_default_provider
|
||||||
|
from danswer.db.llm import update_llm_provider
|
||||||
from danswer.db.llm import upsert_llm_provider
|
from danswer.db.llm import upsert_llm_provider
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
from danswer.llm.factory import get_default_llms
|
from danswer.llm.factory import get_default_llms
|
||||||
@@ -19,8 +22,8 @@ from danswer.llm.factory import get_llm
|
|||||||
from danswer.llm.llm_provider_options import fetch_available_well_known_llms
|
from danswer.llm.llm_provider_options import fetch_available_well_known_llms
|
||||||
from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor
|
from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor
|
||||||
from danswer.llm.utils import test_llm
|
from danswer.llm.utils import test_llm
|
||||||
from danswer.server.manage.llm.models import FullLLMProvider
|
|
||||||
from danswer.server.manage.llm.models import FullLLMProviderSnapshot
|
from danswer.server.manage.llm.models import FullLLMProviderSnapshot
|
||||||
|
from danswer.server.manage.llm.models import LLMProviderCreationRequest
|
||||||
from danswer.server.manage.llm.models import LLMProviderDescriptor
|
from danswer.server.manage.llm.models import LLMProviderDescriptor
|
||||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||||
from danswer.server.manage.llm.models import TestLLMRequest
|
from danswer.server.manage.llm.models import TestLLMRequest
|
||||||
@@ -38,15 +41,21 @@ basic_router = APIRouter(prefix="/llm")
|
|||||||
def fetch_llm_options(
|
def fetch_llm_options(
|
||||||
_: User | None = Depends(current_admin_user),
|
_: User | None = Depends(current_admin_user),
|
||||||
) -> list[WellKnownLLMProviderDescriptor]:
|
) -> list[WellKnownLLMProviderDescriptor]:
|
||||||
print("FETCHING")
|
|
||||||
return fetch_available_well_known_llms()
|
return fetch_available_well_known_llms()
|
||||||
|
|
||||||
|
|
||||||
@admin_router.post("/test")
|
@admin_router.post("/test")
|
||||||
def test_llm_configuration(
|
def test_llm_configuration(
|
||||||
test_llm_request: TestLLMRequest,
|
test_llm_request: TestLLMRequest,
|
||||||
|
db_session: Session = Depends(get_session),
|
||||||
_: User | None = Depends(current_admin_user),
|
_: User | None = Depends(current_admin_user),
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if test_llm_request.existing_api_key and not test_llm_request.api_key:
|
||||||
|
llm_provider = get_llm_provider(
|
||||||
|
test_llm_request.provider.name, db_session=db_session
|
||||||
|
)
|
||||||
|
test_llm_request.api_key = llm_provider.api_key
|
||||||
|
|
||||||
llm = get_llm(
|
llm = get_llm(
|
||||||
provider=test_llm_request.provider,
|
provider=test_llm_request.provider,
|
||||||
model=test_llm_request.default_model_name,
|
model=test_llm_request.default_model_name,
|
||||||
@@ -111,13 +120,47 @@ def test_default_provider(
|
|||||||
def list_llm_providers(
|
def list_llm_providers(
|
||||||
_: User | None = Depends(current_admin_user),
|
_: User | None = Depends(current_admin_user),
|
||||||
db_session: Session = Depends(get_session),
|
db_session: Session = Depends(get_session),
|
||||||
) -> list[FullLLMProvider]:
|
) -> list[FullLLMProviderSnapshot]:
|
||||||
|
print(
|
||||||
|
[
|
||||||
|
FullLLMProviderSnapshot.from_model(llm_provider_model)
|
||||||
|
for llm_provider_model in fetch_existing_llm_providers(db_session)
|
||||||
|
]
|
||||||
|
)
|
||||||
return [
|
return [
|
||||||
FullLLMProvider.from_model(llm_provider_model)
|
FullLLMProviderSnapshot.from_model(llm_provider_model)
|
||||||
for llm_provider_model in fetch_existing_llm_providers(db_session)
|
for llm_provider_model in fetch_existing_llm_providers(db_session)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.patch("/provider/{provider_id}")
|
||||||
|
def patch_existing_llm_provider(
|
||||||
|
llm_provider: LLMProviderUpsertRequest,
|
||||||
|
_: User | None = Depends(current_admin_user),
|
||||||
|
db_session: Session = Depends(get_session),
|
||||||
|
) -> FullLLMProviderSnapshot:
|
||||||
|
return FullLLMProviderSnapshot.from_full_llm_provider(
|
||||||
|
update_llm_provider(llm_provider=llm_provider, db_session=db_session)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@admin_router.post("/provider")
|
||||||
|
def create_new_llm_provider(
|
||||||
|
llm_provider: LLMProviderCreationRequest,
|
||||||
|
is_creation: bool = Query(
|
||||||
|
True,
|
||||||
|
description="True if updating an existing provider, False if creating a new one",
|
||||||
|
),
|
||||||
|
_: User | None = Depends(current_admin_user),
|
||||||
|
db_session: Session = Depends(get_session),
|
||||||
|
) -> FullLLMProviderSnapshot:
|
||||||
|
return FullLLMProviderSnapshot.from_full_llm_provider(
|
||||||
|
create_llm_provider(
|
||||||
|
llm_provider=llm_provider, db_session=db_session, is_creation=is_creation
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@admin_router.put("/provider")
|
@admin_router.put("/provider")
|
||||||
def put_llm_provider(
|
def put_llm_provider(
|
||||||
llm_provider: LLMProviderUpsertRequest,
|
llm_provider: LLMProviderUpsertRequest,
|
||||||
@@ -129,13 +172,17 @@ def put_llm_provider(
|
|||||||
db_session: Session = Depends(get_session),
|
db_session: Session = Depends(get_session),
|
||||||
) -> FullLLMProviderSnapshot:
|
) -> FullLLMProviderSnapshot:
|
||||||
try:
|
try:
|
||||||
return FullLLMProviderSnapshot.from_full_llm_provider(
|
print("hitting htis function")
|
||||||
|
|
||||||
|
value = FullLLMProviderSnapshot.from_full_llm_provider(
|
||||||
upsert_llm_provider(
|
upsert_llm_provider(
|
||||||
llm_provider=llm_provider,
|
llm_provider=llm_provider,
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
is_creation=is_creation,
|
is_creation=is_creation,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
print(value)
|
||||||
|
return value
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.exception("Failed to upsert LLM Provider")
|
logger.exception("Failed to upsert LLM Provider")
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
class TestLLMRequest(BaseModel):
|
class TestLLMRequest(BaseModel):
|
||||||
# provider level
|
# provider level
|
||||||
|
exisitng_api_key: bool = False
|
||||||
provider: str
|
provider: str
|
||||||
api_key: str | None = None
|
api_key: str | None = None
|
||||||
api_base: str | None = None
|
api_base: str | None = None
|
||||||
|
@@ -7,24 +7,17 @@ import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
|
|||||||
import {
|
import {
|
||||||
SelectorFormField,
|
SelectorFormField,
|
||||||
TextFormField,
|
TextFormField,
|
||||||
BooleanFormField,
|
|
||||||
MultiSelectField,
|
MultiSelectField,
|
||||||
} from "@/components/admin/connectors/Field";
|
} from "@/components/admin/connectors/Field";
|
||||||
import { useState } from "react";
|
import { useState } from "react";
|
||||||
import { Bubble } from "@/components/Bubble";
|
|
||||||
import { GroupsIcon } from "@/components/icons/icons";
|
|
||||||
import { useSWRConfig } from "swr";
|
import { useSWRConfig } from "swr";
|
||||||
import {
|
import { defaultModelsByProvider, getDisplayNameForModel } from "@/lib/hooks";
|
||||||
defaultModelsByProvider,
|
|
||||||
getDisplayNameForModel,
|
|
||||||
useUserGroups,
|
|
||||||
} from "@/lib/hooks";
|
|
||||||
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
|
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
|
||||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
|
||||||
import * as Yup from "yup";
|
import * as Yup from "yup";
|
||||||
import isEqual from "lodash/isEqual";
|
import isEqual from "lodash/isEqual";
|
||||||
import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
|
import { IsPublicGroupSelector } from "@/components/IsPublicGroupSelector";
|
||||||
|
import { defaultPasswordMask } from "@/lib/llm/utils";
|
||||||
|
|
||||||
export function LLMProviderUpdateForm({
|
export function LLMProviderUpdateForm({
|
||||||
llmProviderDescriptor,
|
llmProviderDescriptor,
|
||||||
@@ -33,6 +26,7 @@ export function LLMProviderUpdateForm({
|
|||||||
shouldMarkAsDefault,
|
shouldMarkAsDefault,
|
||||||
setPopup,
|
setPopup,
|
||||||
hideAdvanced,
|
hideAdvanced,
|
||||||
|
llmProviderFlow,
|
||||||
}: {
|
}: {
|
||||||
llmProviderDescriptor: WellKnownLLMProviderDescriptor;
|
llmProviderDescriptor: WellKnownLLMProviderDescriptor;
|
||||||
onClose: () => void;
|
onClose: () => void;
|
||||||
@@ -40,14 +34,10 @@ export function LLMProviderUpdateForm({
|
|||||||
shouldMarkAsDefault?: boolean;
|
shouldMarkAsDefault?: boolean;
|
||||||
hideAdvanced?: boolean;
|
hideAdvanced?: boolean;
|
||||||
setPopup?: (popup: PopupSpec) => void;
|
setPopup?: (popup: PopupSpec) => void;
|
||||||
|
llmProviderFlow: "create" | "update";
|
||||||
}) {
|
}) {
|
||||||
const { mutate } = useSWRConfig();
|
const { mutate } = useSWRConfig();
|
||||||
|
|
||||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
|
||||||
|
|
||||||
// EE only
|
|
||||||
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
|
|
||||||
|
|
||||||
const [isTesting, setIsTesting] = useState(false);
|
const [isTesting, setIsTesting] = useState(false);
|
||||||
const [testError, setTestError] = useState<string>("");
|
const [testError, setTestError] = useState<string>("");
|
||||||
|
|
||||||
@@ -56,7 +46,7 @@ export function LLMProviderUpdateForm({
|
|||||||
// Define the initial values based on the provider's requirements
|
// Define the initial values based on the provider's requirements
|
||||||
const initialValues = {
|
const initialValues = {
|
||||||
name: existingLlmProvider?.name || (hideAdvanced ? "Default" : ""),
|
name: existingLlmProvider?.name || (hideAdvanced ? "Default" : ""),
|
||||||
api_key: existingLlmProvider?.api_key ?? "",
|
api_key: null,
|
||||||
api_base: existingLlmProvider?.api_base ?? "",
|
api_base: existingLlmProvider?.api_base ?? "",
|
||||||
api_version: existingLlmProvider?.api_version ?? "",
|
api_version: existingLlmProvider?.api_version ?? "",
|
||||||
default_model_name:
|
default_model_name:
|
||||||
@@ -86,9 +76,10 @@ export function LLMProviderUpdateForm({
|
|||||||
// Setup validation schema if required
|
// Setup validation schema if required
|
||||||
const validationSchema = Yup.object({
|
const validationSchema = Yup.object({
|
||||||
name: Yup.string().required("Display Name is required"),
|
name: Yup.string().required("Display Name is required"),
|
||||||
api_key: llmProviderDescriptor.api_key_required
|
api_key:
|
||||||
? Yup.string().required("API Key is required")
|
llmProviderDescriptor.api_key_required && llmProviderFlow == "create"
|
||||||
: Yup.string(),
|
? Yup.string().required("API Key is required")
|
||||||
|
: Yup.string().nullable(),
|
||||||
api_base: llmProviderDescriptor.api_base_required
|
api_base: llmProviderDescriptor.api_base_required
|
||||||
? Yup.string().required("API Base is required")
|
? Yup.string().required("API Base is required")
|
||||||
: Yup.string(),
|
: Yup.string(),
|
||||||
@@ -120,6 +111,10 @@ export function LLMProviderUpdateForm({
|
|||||||
display_model_names: Yup.array().of(Yup.string()),
|
display_model_names: Yup.array().of(Yup.string()),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const apiKeyDefault = existingLlmProvider?.api_key_set
|
||||||
|
? defaultPasswordMask
|
||||||
|
: "API key";
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Formik
|
<Formik
|
||||||
initialValues={initialValues}
|
initialValues={initialValues}
|
||||||
@@ -151,7 +146,7 @@ export function LLMProviderUpdateForm({
|
|||||||
}
|
}
|
||||||
|
|
||||||
const response = await fetch(LLM_PROVIDERS_ADMIN_URL, {
|
const response = await fetch(LLM_PROVIDERS_ADMIN_URL, {
|
||||||
method: "PUT",
|
method: llmProviderFlow == "create" ? "POST" : "PUT",
|
||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
},
|
},
|
||||||
@@ -237,8 +232,8 @@ export function LLMProviderUpdateForm({
|
|||||||
small={hideAdvanced}
|
small={hideAdvanced}
|
||||||
name="api_key"
|
name="api_key"
|
||||||
label="API Key"
|
label="API Key"
|
||||||
placeholder="API Key"
|
|
||||||
type="password"
|
type="password"
|
||||||
|
placeholder={formikProps.values.api_key ?? apiKeyDefault}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
@@ -34,7 +34,7 @@ export interface WellKnownLLMProviderDescriptor {
|
|||||||
export interface LLMProvider {
|
export interface LLMProvider {
|
||||||
name: string;
|
name: string;
|
||||||
provider: string;
|
provider: string;
|
||||||
api_key: string | null;
|
api_key_set: boolean;
|
||||||
api_base: string | null;
|
api_base: string | null;
|
||||||
api_version: string | null;
|
api_version: string | null;
|
||||||
custom_config: { [key: string]: string } | null;
|
custom_config: { [key: string]: string } | null;
|
||||||
|
@@ -102,3 +102,5 @@ export const destructureValue = (value: string): LlmOverride => {
|
|||||||
modelName,
|
modelName,
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const defaultPasswordMask = "**************************";
|
||||||
|
Reference in New Issue
Block a user