mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-27 20:38:32 +02:00
Fix change model popup
This commit is contained in:
@@ -107,3 +107,8 @@ class SlackBotConfig(BaseModel):
|
||||
|
||||
class ModelVersionResponse(BaseModel):
|
||||
model_name: str | None # None only applicable to secondary index
|
||||
|
||||
|
||||
class FullModelVersionResponse(BaseModel):
|
||||
current_model_name: str
|
||||
secondary_model_name: str | None
|
||||
|
@@ -5,6 +5,7 @@ from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.embedding_model import create_embedding_model
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||
@@ -15,6 +16,7 @@ from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.models import User
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.models import EmbeddingModelDetail
|
||||
from danswer.server.manage.models import FullModelVersionResponse
|
||||
from danswer.server.manage.models import ModelVersionResponse
|
||||
from danswer.server.models import IdReturn
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -100,7 +102,7 @@ def cancel_new_embedding(
|
||||
|
||||
@router.get("/get-current-embedding-model")
|
||||
def get_current_embedding_model(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ModelVersionResponse:
|
||||
current_model = get_current_db_embedding_model(db_session)
|
||||
@@ -109,7 +111,7 @@ def get_current_embedding_model(
|
||||
|
||||
@router.get("/get-secondary-embedding-model")
|
||||
def get_secondary_embedding_model(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
_: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ModelVersionResponse:
|
||||
next_model = get_secondary_db_embedding_model(db_session)
|
||||
@@ -117,3 +119,16 @@ def get_secondary_embedding_model(
|
||||
return ModelVersionResponse(
|
||||
model_name=next_model.model_name if next_model else None
|
||||
)
|
||||
|
||||
|
||||
@router.get("/get-embedding-models")
|
||||
def get_embedding_models(
|
||||
_: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FullModelVersionResponse:
|
||||
current_model = get_current_db_embedding_model(db_session)
|
||||
next_model = get_secondary_db_embedding_model(db_session)
|
||||
return FullModelVersionResponse(
|
||||
current_model_name=current_model.model_name,
|
||||
secondary_model_name=next_model.model_name if next_model else None,
|
||||
)
|
||||
|
@@ -2,6 +2,11 @@ export interface EmbeddingModelResponse {
|
||||
model_name: string | null;
|
||||
}
|
||||
|
||||
export interface FullEmbeddingModelResponse {
|
||||
current_model_name: string;
|
||||
secondary_model_name: string | null;
|
||||
}
|
||||
|
||||
export interface EmbeddingModelDescriptor {
|
||||
model_name: string;
|
||||
model_dim: number;
|
||||
|
@@ -17,7 +17,7 @@ import { DOCUMENT_SIDEBAR_WIDTH_COOKIE_NAME } from "@/components/resizable/conta
|
||||
import { personaComparator } from "../admin/personas/lib";
|
||||
import { ChatLayout } from "./ChatPage";
|
||||
import {
|
||||
EmbeddingModelResponse,
|
||||
FullEmbeddingModelResponse,
|
||||
checkModelNameIsValid,
|
||||
} from "../admin/models/embedding/embeddingModels";
|
||||
import { SwitchModelModal } from "@/components/SwitchModelModal";
|
||||
@@ -37,7 +37,7 @@ export default async function Page({
|
||||
fetchSS("/persona?include_default=true"),
|
||||
fetchSS("/chat/get-user-chat-sessions"),
|
||||
fetchSS("/query/valid-tags"),
|
||||
fetchSS("/secondary-index/get-current-embedding-model"),
|
||||
fetchSS("/secondary-index/get-embedding-models"),
|
||||
];
|
||||
|
||||
// catch cases where the backend is completely unreachable here
|
||||
@@ -47,7 +47,7 @@ export default async function Page({
|
||||
| User
|
||||
| Response
|
||||
| AuthTypeMetadata
|
||||
| EmbeddingModelResponse
|
||||
| FullEmbeddingModelResponse
|
||||
| null
|
||||
)[] = [null, null, null, null, null, null, null, null];
|
||||
try {
|
||||
@@ -124,10 +124,14 @@ export default async function Page({
|
||||
console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
|
||||
}
|
||||
|
||||
const embeddingModelName =
|
||||
const embeddingModelVersionInfo =
|
||||
embeddingModelResponse && embeddingModelResponse.ok
|
||||
? ((await embeddingModelResponse.json()).model_name as string)
|
||||
? ((await embeddingModelResponse.json()) as FullEmbeddingModelResponse)
|
||||
: null;
|
||||
const currentEmbeddingModelName =
|
||||
embeddingModelVersionInfo?.current_model_name;
|
||||
const nextEmbeddingModelName =
|
||||
embeddingModelVersionInfo?.secondary_model_name;
|
||||
|
||||
const defaultPersonaIdRaw = searchParams["personaId"];
|
||||
const defaultPersonaId = defaultPersonaIdRaw
|
||||
@@ -147,10 +151,12 @@ export default async function Page({
|
||||
<ApiKeyModal />
|
||||
|
||||
{connectors.length === 0 ? (
|
||||
<WelcomeModal embeddingModelName={embeddingModelName} />
|
||||
<WelcomeModal embeddingModelName={currentEmbeddingModelName} />
|
||||
) : (
|
||||
!checkModelNameIsValid(embeddingModelName) && (
|
||||
<SwitchModelModal embeddingModelName={embeddingModelName} />
|
||||
embeddingModelVersionInfo &&
|
||||
!checkModelNameIsValid(currentEmbeddingModelName) &&
|
||||
!nextEmbeddingModelName && (
|
||||
<SwitchModelModal embeddingModelName={currentEmbeddingModelName} />
|
||||
)
|
||||
)}
|
||||
|
||||
|
@@ -17,7 +17,10 @@ import { WelcomeModal } from "@/components/WelcomeModal";
|
||||
import { unstable_noStore as noStore } from "next/cache";
|
||||
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
||||
import { personaComparator } from "../admin/personas/lib";
|
||||
import { checkModelNameIsValid } from "../admin/models/embedding/embeddingModels";
|
||||
import {
|
||||
FullEmbeddingModelResponse,
|
||||
checkModelNameIsValid,
|
||||
} from "../admin/models/embedding/embeddingModels";
|
||||
import { SwitchModelModal } from "@/components/SwitchModelModal";
|
||||
|
||||
export default async function Home() {
|
||||
@@ -33,21 +36,19 @@ export default async function Home() {
|
||||
fetchSS("/manage/document-set"),
|
||||
fetchSS("/persona"),
|
||||
fetchSS("/query/valid-tags"),
|
||||
fetchSS("/secondary-index/get-current-embedding-model"),
|
||||
fetchSS("/secondary-index/get-embedding-models"),
|
||||
];
|
||||
|
||||
// catch cases where the backend is completely unreachable here
|
||||
// without try / catch, will just raise an exception and the page
|
||||
// will not render
|
||||
let results: (User | Response | AuthTypeMetadata | null)[] = [
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
];
|
||||
let results: (
|
||||
| User
|
||||
| Response
|
||||
| AuthTypeMetadata
|
||||
| FullEmbeddingModelResponse
|
||||
| null
|
||||
)[] = [null, null, null, null, null, null, null];
|
||||
try {
|
||||
results = await Promise.all(tasks);
|
||||
} catch (e) {
|
||||
@@ -104,10 +105,15 @@ export default async function Home() {
|
||||
console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
|
||||
}
|
||||
|
||||
const embeddingModelName =
|
||||
const embeddingModelVersionInfo =
|
||||
embeddingModelResponse && embeddingModelResponse.ok
|
||||
? ((await embeddingModelResponse.json()).model_name as string)
|
||||
? ((await embeddingModelResponse.json()) as FullEmbeddingModelResponse)
|
||||
: null;
|
||||
const currentEmbeddingModelName =
|
||||
embeddingModelVersionInfo?.current_model_name;
|
||||
const nextEmbeddingModelName =
|
||||
embeddingModelVersionInfo?.secondary_model_name;
|
||||
console.log(embeddingModelVersionInfo);
|
||||
|
||||
// needs to be done in a non-client side component due to nextjs
|
||||
const storedSearchType = cookies().get("searchType")?.value as
|
||||
@@ -129,10 +135,12 @@ export default async function Home() {
|
||||
<InstantSSRAutoRefresh />
|
||||
|
||||
{connectors.length === 0 ? (
|
||||
<WelcomeModal embeddingModelName={embeddingModelName} />
|
||||
<WelcomeModal embeddingModelName={currentEmbeddingModelName} />
|
||||
) : (
|
||||
!checkModelNameIsValid(embeddingModelName) && (
|
||||
<SwitchModelModal embeddingModelName={embeddingModelName} />
|
||||
embeddingModelVersionInfo &&
|
||||
!checkModelNameIsValid(currentEmbeddingModelName) &&
|
||||
!nextEmbeddingModelName && (
|
||||
<SwitchModelModal embeddingModelName={currentEmbeddingModelName} />
|
||||
)
|
||||
)}
|
||||
|
||||
|
@@ -3,13 +3,11 @@
|
||||
import { Button, Text } from "@tremor/react";
|
||||
import { Modal } from "./Modal";
|
||||
import Link from "next/link";
|
||||
import { FiCheckCircle } from "react-icons/fi";
|
||||
import { checkModelNameIsValid } from "@/app/admin/models/embedding/embeddingModels";
|
||||
|
||||
export function SwitchModelModal({
|
||||
embeddingModelName,
|
||||
}: {
|
||||
embeddingModelName: null | string;
|
||||
embeddingModelName: undefined | null | string;
|
||||
}) {
|
||||
return (
|
||||
<Modal className="max-w-4xl">
|
||||
|
@@ -9,7 +9,7 @@ import { checkModelNameIsValid } from "@/app/admin/models/embedding/embeddingMod
|
||||
export function WelcomeModal({
|
||||
embeddingModelName,
|
||||
}: {
|
||||
embeddingModelName: null | string;
|
||||
embeddingModelName: undefined | null | string;
|
||||
}) {
|
||||
const validModelSelected = checkModelNameIsValid(embeddingModelName);
|
||||
|
||||
|
Reference in New Issue
Block a user