From 8b9e6a91a40bef03183dd8c3619734195e2bed86 Mon Sep 17 00:00:00 2001 From: Weves Date: Tue, 30 Jan 2024 10:08:46 -0800 Subject: [PATCH] Fix change model popup --- backend/danswer/server/manage/models.py | 5 +++ .../danswer/server/manage/secondary_index.py | 19 ++++++++- .../admin/models/embedding/embeddingModels.ts | 5 +++ web/src/app/chat/page.tsx | 22 ++++++---- web/src/app/search/page.tsx | 40 +++++++++++-------- web/src/components/SwitchModelModal.tsx | 4 +- web/src/components/WelcomeModal.tsx | 2 +- 7 files changed, 67 insertions(+), 30 deletions(-) diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py index 01cff7c97ab9..ee4470eab14c 100644 --- a/backend/danswer/server/manage/models.py +++ b/backend/danswer/server/manage/models.py @@ -107,3 +107,8 @@ class SlackBotConfig(BaseModel): class ModelVersionResponse(BaseModel): model_name: str | None # None only applicable to secondary index + + +class FullModelVersionResponse(BaseModel): + current_model_name: str + secondary_model_name: str | None diff --git a/backend/danswer/server/manage/secondary_index.py b/backend/danswer/server/manage/secondary_index.py index d1c5ffdb2780..2013cfc5bb8d 100644 --- a/backend/danswer/server/manage/secondary_index.py +++ b/backend/danswer/server/manage/secondary_index.py @@ -5,6 +5,7 @@ from fastapi import status from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user +from danswer.auth.users import current_user from danswer.db.embedding_model import create_embedding_model from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.embedding_model import get_secondary_db_embedding_model @@ -15,6 +16,7 @@ from danswer.db.models import IndexModelStatus from danswer.db.models import User from danswer.document_index.factory import get_default_document_index from danswer.indexing.models import EmbeddingModelDetail +from danswer.server.manage.models import FullModelVersionResponse from danswer.server.manage.models import ModelVersionResponse from danswer.server.models import IdReturn from danswer.utils.logger import setup_logger @@ -100,7 +102,7 @@ def cancel_new_embedding( @router.get("/get-current-embedding-model") def get_current_embedding_model( - _: User | None = Depends(current_admin_user), + _: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> ModelVersionResponse: current_model = get_current_db_embedding_model(db_session) @@ -109,7 +111,7 @@ def get_current_embedding_model( @router.get("/get-secondary-embedding-model") def get_secondary_embedding_model( - _: User | None = Depends(current_admin_user), + _: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> ModelVersionResponse: next_model = get_secondary_db_embedding_model(db_session) @@ -117,3 +119,16 @@ def get_secondary_embedding_model( return ModelVersionResponse( model_name=next_model.model_name if next_model else None ) + + +@router.get("/get-embedding-models") +def get_embedding_models( + _: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> FullModelVersionResponse: + current_model = get_current_db_embedding_model(db_session) + next_model = get_secondary_db_embedding_model(db_session) + return FullModelVersionResponse( + current_model_name=current_model.model_name, + secondary_model_name=next_model.model_name if next_model else None, + ) diff --git a/web/src/app/admin/models/embedding/embeddingModels.ts b/web/src/app/admin/models/embedding/embeddingModels.ts index 4bacdab99181..64ccfff9581a 100644 --- a/web/src/app/admin/models/embedding/embeddingModels.ts +++ b/web/src/app/admin/models/embedding/embeddingModels.ts @@ -2,6 +2,11 @@ export interface EmbeddingModelResponse { model_name: string | null; } +export interface FullEmbeddingModelResponse { + current_model_name: string; + secondary_model_name: string | null; +} + export interface EmbeddingModelDescriptor { model_name: string; model_dim: number; diff --git a/web/src/app/chat/page.tsx b/web/src/app/chat/page.tsx index 169b55232b0e..eb8ff8918c53 100644 --- a/web/src/app/chat/page.tsx +++ b/web/src/app/chat/page.tsx @@ -17,7 +17,7 @@ import { DOCUMENT_SIDEBAR_WIDTH_COOKIE_NAME } from "@/components/resizable/conta import { personaComparator } from "../admin/personas/lib"; import { ChatLayout } from "./ChatPage"; import { - EmbeddingModelResponse, + FullEmbeddingModelResponse, checkModelNameIsValid, } from "../admin/models/embedding/embeddingModels"; import { SwitchModelModal } from "@/components/SwitchModelModal"; @@ -37,7 +37,7 @@ export default async function Page({ fetchSS("/persona?include_default=true"), fetchSS("/chat/get-user-chat-sessions"), fetchSS("/query/valid-tags"), - fetchSS("/secondary-index/get-current-embedding-model"), + fetchSS("/secondary-index/get-embedding-models"), ]; // catch cases where the backend is completely unreachable here @@ -47,7 +47,7 @@ export default async function Page({ | User | Response | AuthTypeMetadata - | EmbeddingModelResponse + | FullEmbeddingModelResponse | null )[] = [null, null, null, null, null, null, null, null]; try { @@ -124,10 +124,14 @@ export default async function Page({ console.log(`Failed to fetch tags - ${tagsResponse?.status}`); } - const embeddingModelName = + const embeddingModelVersionInfo = embeddingModelResponse && embeddingModelResponse.ok - ? ((await embeddingModelResponse.json()).model_name as string) + ? ((await embeddingModelResponse.json()) as FullEmbeddingModelResponse) : null; + const currentEmbeddingModelName = + embeddingModelVersionInfo?.current_model_name; + const nextEmbeddingModelName = + embeddingModelVersionInfo?.secondary_model_name; const defaultPersonaIdRaw = searchParams["personaId"]; const defaultPersonaId = defaultPersonaIdRaw @@ -147,10 +151,12 @@ export default async function Page({ {connectors.length === 0 ? ( - + ) : ( - !checkModelNameIsValid(embeddingModelName) && ( - + embeddingModelVersionInfo && + !checkModelNameIsValid(currentEmbeddingModelName) && + !nextEmbeddingModelName && ( + ) )} diff --git a/web/src/app/search/page.tsx b/web/src/app/search/page.tsx index e39a43dc137d..0903b9ff51fc 100644 --- a/web/src/app/search/page.tsx +++ b/web/src/app/search/page.tsx @@ -17,7 +17,10 @@ import { WelcomeModal } from "@/components/WelcomeModal"; import { unstable_noStore as noStore } from "next/cache"; import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh"; import { personaComparator } from "../admin/personas/lib"; -import { checkModelNameIsValid } from "../admin/models/embedding/embeddingModels"; +import { + FullEmbeddingModelResponse, + checkModelNameIsValid, +} from "../admin/models/embedding/embeddingModels"; import { SwitchModelModal } from "@/components/SwitchModelModal"; export default async function Home() { @@ -33,21 +36,19 @@ export default async function Home() { fetchSS("/manage/document-set"), fetchSS("/persona"), fetchSS("/query/valid-tags"), - fetchSS("/secondary-index/get-current-embedding-model"), + fetchSS("/secondary-index/get-embedding-models"), ]; // catch cases where the backend is completely unreachable here // without try / catch, will just raise an exception and the page // will not render - let results: (User | Response | AuthTypeMetadata | null)[] = [ - null, - null, - null, - null, - null, - null, - null, - ]; + let results: ( + | User + | Response + | AuthTypeMetadata + | FullEmbeddingModelResponse + | null + )[] = [null, null, null, null, null, null, null]; try { results = await Promise.all(tasks); } catch (e) { @@ -104,10 +105,15 @@ export default async function Home() { console.log(`Failed to fetch tags - ${tagsResponse?.status}`); } - const embeddingModelName = + const embeddingModelVersionInfo = embeddingModelResponse && embeddingModelResponse.ok - ? ((await embeddingModelResponse.json()).model_name as string) + ? ((await embeddingModelResponse.json()) as FullEmbeddingModelResponse) : null; + const currentEmbeddingModelName = + embeddingModelVersionInfo?.current_model_name; + const nextEmbeddingModelName = + embeddingModelVersionInfo?.secondary_model_name; + console.log(embeddingModelVersionInfo); // needs to be done in a non-client side component due to nextjs const storedSearchType = cookies().get("searchType")?.value as @@ -129,10 +135,12 @@ export default async function Home() { {connectors.length === 0 ? ( - + ) : ( - !checkModelNameIsValid(embeddingModelName) && ( - + embeddingModelVersionInfo && + !checkModelNameIsValid(currentEmbeddingModelName) && + !nextEmbeddingModelName && ( + ) )} diff --git a/web/src/components/SwitchModelModal.tsx b/web/src/components/SwitchModelModal.tsx index bdb1ec22a83c..d3174275d6f6 100644 --- a/web/src/components/SwitchModelModal.tsx +++ b/web/src/components/SwitchModelModal.tsx @@ -3,13 +3,11 @@ import { Button, Text } from "@tremor/react"; import { Modal } from "./Modal"; import Link from "next/link"; -import { FiCheckCircle } from "react-icons/fi"; -import { checkModelNameIsValid } from "@/app/admin/models/embedding/embeddingModels"; export function SwitchModelModal({ embeddingModelName, }: { - embeddingModelName: null | string; + embeddingModelName: undefined | null | string; }) { return ( diff --git a/web/src/components/WelcomeModal.tsx b/web/src/components/WelcomeModal.tsx index 06eb172e5b9d..8b0fd2b0b146 100644 --- a/web/src/components/WelcomeModal.tsx +++ b/web/src/components/WelcomeModal.tsx @@ -9,7 +9,7 @@ import { checkModelNameIsValid } from "@/app/admin/models/embedding/embeddingMod export function WelcomeModal({ embeddingModelName, }: { - embeddingModelName: null | string; + embeddingModelName: undefined | null | string; }) { const validModelSelected = checkModelNameIsValid(embeddingModelName);