diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py
index 01cff7c97ab9..ee4470eab14c 100644
--- a/backend/danswer/server/manage/models.py
+++ b/backend/danswer/server/manage/models.py
@@ -107,3 +107,8 @@ class SlackBotConfig(BaseModel):
class ModelVersionResponse(BaseModel):
model_name: str | None # None only applicable to secondary index
+
+
+class FullModelVersionResponse(BaseModel):
+ current_model_name: str
+ secondary_model_name: str | None
diff --git a/backend/danswer/server/manage/secondary_index.py b/backend/danswer/server/manage/secondary_index.py
index d1c5ffdb2780..2013cfc5bb8d 100644
--- a/backend/danswer/server/manage/secondary_index.py
+++ b/backend/danswer/server/manage/secondary_index.py
@@ -5,6 +5,7 @@ from fastapi import status
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
+from danswer.auth.users import current_user
from danswer.db.embedding_model import create_embedding_model
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
@@ -15,6 +16,7 @@ from danswer.db.models import IndexModelStatus
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import EmbeddingModelDetail
+from danswer.server.manage.models import FullModelVersionResponse
from danswer.server.manage.models import ModelVersionResponse
from danswer.server.models import IdReturn
from danswer.utils.logger import setup_logger
@@ -100,7 +102,7 @@ def cancel_new_embedding(
@router.get("/get-current-embedding-model")
def get_current_embedding_model(
- _: User | None = Depends(current_admin_user),
+ _: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ModelVersionResponse:
current_model = get_current_db_embedding_model(db_session)
@@ -109,7 +111,7 @@ def get_current_embedding_model(
@router.get("/get-secondary-embedding-model")
def get_secondary_embedding_model(
- _: User | None = Depends(current_admin_user),
+ _: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ModelVersionResponse:
next_model = get_secondary_db_embedding_model(db_session)
@@ -117,3 +119,16 @@ def get_secondary_embedding_model(
return ModelVersionResponse(
model_name=next_model.model_name if next_model else None
)
+
+
+@router.get("/get-embedding-models")
+def get_embedding_models(
+ _: User | None = Depends(current_user),
+ db_session: Session = Depends(get_session),
+) -> FullModelVersionResponse:
+ current_model = get_current_db_embedding_model(db_session)
+ next_model = get_secondary_db_embedding_model(db_session)
+ return FullModelVersionResponse(
+ current_model_name=current_model.model_name,
+ secondary_model_name=next_model.model_name if next_model else None,
+ )
diff --git a/web/src/app/admin/models/embedding/embeddingModels.ts b/web/src/app/admin/models/embedding/embeddingModels.ts
index 4bacdab99181..64ccfff9581a 100644
--- a/web/src/app/admin/models/embedding/embeddingModels.ts
+++ b/web/src/app/admin/models/embedding/embeddingModels.ts
@@ -2,6 +2,11 @@ export interface EmbeddingModelResponse {
model_name: string | null;
}
+export interface FullEmbeddingModelResponse {
+ current_model_name: string;
+ secondary_model_name: string | null;
+}
+
export interface EmbeddingModelDescriptor {
model_name: string;
model_dim: number;
diff --git a/web/src/app/chat/page.tsx b/web/src/app/chat/page.tsx
index 169b55232b0e..eb8ff8918c53 100644
--- a/web/src/app/chat/page.tsx
+++ b/web/src/app/chat/page.tsx
@@ -17,7 +17,7 @@ import { DOCUMENT_SIDEBAR_WIDTH_COOKIE_NAME } from "@/components/resizable/conta
import { personaComparator } from "../admin/personas/lib";
import { ChatLayout } from "./ChatPage";
import {
- EmbeddingModelResponse,
+ FullEmbeddingModelResponse,
checkModelNameIsValid,
} from "../admin/models/embedding/embeddingModels";
import { SwitchModelModal } from "@/components/SwitchModelModal";
@@ -37,7 +37,7 @@ export default async function Page({
fetchSS("/persona?include_default=true"),
fetchSS("/chat/get-user-chat-sessions"),
fetchSS("/query/valid-tags"),
- fetchSS("/secondary-index/get-current-embedding-model"),
+ fetchSS("/secondary-index/get-embedding-models"),
];
// catch cases where the backend is completely unreachable here
@@ -47,7 +47,7 @@ export default async function Page({
| User
| Response
| AuthTypeMetadata
- | EmbeddingModelResponse
+ | FullEmbeddingModelResponse
| null
)[] = [null, null, null, null, null, null, null, null];
try {
@@ -124,10 +124,14 @@ export default async function Page({
console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
}
- const embeddingModelName =
+ const embeddingModelVersionInfo =
embeddingModelResponse && embeddingModelResponse.ok
- ? ((await embeddingModelResponse.json()).model_name as string)
+ ? ((await embeddingModelResponse.json()) as FullEmbeddingModelResponse)
: null;
+ const currentEmbeddingModelName =
+ embeddingModelVersionInfo?.current_model_name;
+ const nextEmbeddingModelName =
+ embeddingModelVersionInfo?.secondary_model_name;
const defaultPersonaIdRaw = searchParams["personaId"];
const defaultPersonaId = defaultPersonaIdRaw
@@ -147,10 +151,12 @@ export default async function Page({
{connectors.length === 0 ? (
-
+
) : (
- !checkModelNameIsValid(embeddingModelName) && (
-
+ embeddingModelVersionInfo &&
+ !checkModelNameIsValid(currentEmbeddingModelName) &&
+ !nextEmbeddingModelName && (
+
)
)}
diff --git a/web/src/app/search/page.tsx b/web/src/app/search/page.tsx
index e39a43dc137d..0903b9ff51fc 100644
--- a/web/src/app/search/page.tsx
+++ b/web/src/app/search/page.tsx
@@ -17,7 +17,10 @@ import { WelcomeModal } from "@/components/WelcomeModal";
import { unstable_noStore as noStore } from "next/cache";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
import { personaComparator } from "../admin/personas/lib";
-import { checkModelNameIsValid } from "../admin/models/embedding/embeddingModels";
+import {
+ FullEmbeddingModelResponse,
+ checkModelNameIsValid,
+} from "../admin/models/embedding/embeddingModels";
import { SwitchModelModal } from "@/components/SwitchModelModal";
export default async function Home() {
@@ -33,21 +36,19 @@ export default async function Home() {
fetchSS("/manage/document-set"),
fetchSS("/persona"),
fetchSS("/query/valid-tags"),
- fetchSS("/secondary-index/get-current-embedding-model"),
+ fetchSS("/secondary-index/get-embedding-models"),
];
// catch cases where the backend is completely unreachable here
// without try / catch, will just raise an exception and the page
// will not render
- let results: (User | Response | AuthTypeMetadata | null)[] = [
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- ];
+ let results: (
+ | User
+ | Response
+ | AuthTypeMetadata
+ | FullEmbeddingModelResponse
+ | null
+ )[] = [null, null, null, null, null, null, null];
try {
results = await Promise.all(tasks);
} catch (e) {
@@ -104,10 +105,15 @@ export default async function Home() {
console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
}
- const embeddingModelName =
+ const embeddingModelVersionInfo =
embeddingModelResponse && embeddingModelResponse.ok
- ? ((await embeddingModelResponse.json()).model_name as string)
+ ? ((await embeddingModelResponse.json()) as FullEmbeddingModelResponse)
: null;
+ const currentEmbeddingModelName =
+ embeddingModelVersionInfo?.current_model_name;
+ const nextEmbeddingModelName =
+ embeddingModelVersionInfo?.secondary_model_name;
+ console.log(embeddingModelVersionInfo);
// needs to be done in a non-client side component due to nextjs
const storedSearchType = cookies().get("searchType")?.value as
@@ -129,10 +135,12 @@ export default async function Home() {
{connectors.length === 0 ? (
-
+
) : (
- !checkModelNameIsValid(embeddingModelName) && (
-
+ embeddingModelVersionInfo &&
+ !checkModelNameIsValid(currentEmbeddingModelName) &&
+ !nextEmbeddingModelName && (
+
)
)}
diff --git a/web/src/components/SwitchModelModal.tsx b/web/src/components/SwitchModelModal.tsx
index bdb1ec22a83c..d3174275d6f6 100644
--- a/web/src/components/SwitchModelModal.tsx
+++ b/web/src/components/SwitchModelModal.tsx
@@ -3,13 +3,11 @@
import { Button, Text } from "@tremor/react";
import { Modal } from "./Modal";
import Link from "next/link";
-import { FiCheckCircle } from "react-icons/fi";
-import { checkModelNameIsValid } from "@/app/admin/models/embedding/embeddingModels";
export function SwitchModelModal({
embeddingModelName,
}: {
- embeddingModelName: null | string;
+ embeddingModelName: undefined | null | string;
}) {
return (
diff --git a/web/src/components/WelcomeModal.tsx b/web/src/components/WelcomeModal.tsx
index 06eb172e5b9d..8b0fd2b0b146 100644
--- a/web/src/components/WelcomeModal.tsx
+++ b/web/src/components/WelcomeModal.tsx
@@ -9,7 +9,7 @@ import { checkModelNameIsValid } from "@/app/admin/models/embedding/embeddingMod
export function WelcomeModal({
embeddingModelName,
}: {
- embeddingModelName: null | string;
+ embeddingModelName: undefined | null | string;
}) {
const validModelSelected = checkModelNameIsValid(embeddingModelName);