Fix change model popup

This commit is contained in:
Weves
2024-01-30 10:08:46 -08:00
committed by Chris Weaver
parent b076c3d1ea
commit 8b9e6a91a4
7 changed files with 67 additions and 30 deletions

View File

@@ -107,3 +107,8 @@ class SlackBotConfig(BaseModel):
class ModelVersionResponse(BaseModel):
model_name: str | None # None only applicable to secondary index
class FullModelVersionResponse(BaseModel):
current_model_name: str
secondary_model_name: str | None

View File

@@ -5,6 +5,7 @@ from fastapi import status
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.embedding_model import create_embedding_model
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
@@ -15,6 +16,7 @@ from danswer.db.models import IndexModelStatus
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import EmbeddingModelDetail
from danswer.server.manage.models import FullModelVersionResponse
from danswer.server.manage.models import ModelVersionResponse
from danswer.server.models import IdReturn
from danswer.utils.logger import setup_logger
@@ -100,7 +102,7 @@ def cancel_new_embedding(
@router.get("/get-current-embedding-model")
def get_current_embedding_model(
_: User | None = Depends(current_admin_user),
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ModelVersionResponse:
current_model = get_current_db_embedding_model(db_session)
@@ -109,7 +111,7 @@ def get_current_embedding_model(
@router.get("/get-secondary-embedding-model")
def get_secondary_embedding_model(
_: User | None = Depends(current_admin_user),
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ModelVersionResponse:
next_model = get_secondary_db_embedding_model(db_session)
@@ -117,3 +119,16 @@ def get_secondary_embedding_model(
return ModelVersionResponse(
model_name=next_model.model_name if next_model else None
)
@router.get("/get-embedding-models")
def get_embedding_models(
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> FullModelVersionResponse:
current_model = get_current_db_embedding_model(db_session)
next_model = get_secondary_db_embedding_model(db_session)
return FullModelVersionResponse(
current_model_name=current_model.model_name,
secondary_model_name=next_model.model_name if next_model else None,
)

View File

@@ -2,6 +2,11 @@ export interface EmbeddingModelResponse {
model_name: string | null;
}
export interface FullEmbeddingModelResponse {
current_model_name: string;
secondary_model_name: string | null;
}
export interface EmbeddingModelDescriptor {
model_name: string;
model_dim: number;

View File

@@ -17,7 +17,7 @@ import { DOCUMENT_SIDEBAR_WIDTH_COOKIE_NAME } from "@/components/resizable/conta
import { personaComparator } from "../admin/personas/lib";
import { ChatLayout } from "./ChatPage";
import {
EmbeddingModelResponse,
FullEmbeddingModelResponse,
checkModelNameIsValid,
} from "../admin/models/embedding/embeddingModels";
import { SwitchModelModal } from "@/components/SwitchModelModal";
@@ -37,7 +37,7 @@ export default async function Page({
fetchSS("/persona?include_default=true"),
fetchSS("/chat/get-user-chat-sessions"),
fetchSS("/query/valid-tags"),
fetchSS("/secondary-index/get-current-embedding-model"),
fetchSS("/secondary-index/get-embedding-models"),
];
// catch cases where the backend is completely unreachable here
@@ -47,7 +47,7 @@ export default async function Page({
| User
| Response
| AuthTypeMetadata
| EmbeddingModelResponse
| FullEmbeddingModelResponse
| null
)[] = [null, null, null, null, null, null, null, null];
try {
@@ -124,10 +124,14 @@ export default async function Page({
console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
}
const embeddingModelName =
const embeddingModelVersionInfo =
embeddingModelResponse && embeddingModelResponse.ok
? ((await embeddingModelResponse.json()).model_name as string)
? ((await embeddingModelResponse.json()) as FullEmbeddingModelResponse)
: null;
const currentEmbeddingModelName =
embeddingModelVersionInfo?.current_model_name;
const nextEmbeddingModelName =
embeddingModelVersionInfo?.secondary_model_name;
const defaultPersonaIdRaw = searchParams["personaId"];
const defaultPersonaId = defaultPersonaIdRaw
@@ -147,10 +151,12 @@ export default async function Page({
<ApiKeyModal />
{connectors.length === 0 ? (
<WelcomeModal embeddingModelName={embeddingModelName} />
<WelcomeModal embeddingModelName={currentEmbeddingModelName} />
) : (
!checkModelNameIsValid(embeddingModelName) && (
<SwitchModelModal embeddingModelName={embeddingModelName} />
embeddingModelVersionInfo &&
!checkModelNameIsValid(currentEmbeddingModelName) &&
!nextEmbeddingModelName && (
<SwitchModelModal embeddingModelName={currentEmbeddingModelName} />
)
)}

View File

@@ -17,7 +17,10 @@ import { WelcomeModal } from "@/components/WelcomeModal";
import { unstable_noStore as noStore } from "next/cache";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
import { personaComparator } from "../admin/personas/lib";
import { checkModelNameIsValid } from "../admin/models/embedding/embeddingModels";
import {
FullEmbeddingModelResponse,
checkModelNameIsValid,
} from "../admin/models/embedding/embeddingModels";
import { SwitchModelModal } from "@/components/SwitchModelModal";
export default async function Home() {
@@ -33,21 +36,19 @@ export default async function Home() {
fetchSS("/manage/document-set"),
fetchSS("/persona"),
fetchSS("/query/valid-tags"),
fetchSS("/secondary-index/get-current-embedding-model"),
fetchSS("/secondary-index/get-embedding-models"),
];
// catch cases where the backend is completely unreachable here
// without try / catch, will just raise an exception and the page
// will not render
let results: (User | Response | AuthTypeMetadata | null)[] = [
null,
null,
null,
null,
null,
null,
null,
];
let results: (
| User
| Response
| AuthTypeMetadata
| FullEmbeddingModelResponse
| null
)[] = [null, null, null, null, null, null, null];
try {
results = await Promise.all(tasks);
} catch (e) {
@@ -104,10 +105,15 @@ export default async function Home() {
console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
}
const embeddingModelName =
const embeddingModelVersionInfo =
embeddingModelResponse && embeddingModelResponse.ok
? ((await embeddingModelResponse.json()).model_name as string)
? ((await embeddingModelResponse.json()) as FullEmbeddingModelResponse)
: null;
const currentEmbeddingModelName =
embeddingModelVersionInfo?.current_model_name;
const nextEmbeddingModelName =
embeddingModelVersionInfo?.secondary_model_name;
console.log(embeddingModelVersionInfo);
// needs to be done in a non-client side component due to nextjs
const storedSearchType = cookies().get("searchType")?.value as
@@ -129,10 +135,12 @@ export default async function Home() {
<InstantSSRAutoRefresh />
{connectors.length === 0 ? (
<WelcomeModal embeddingModelName={embeddingModelName} />
<WelcomeModal embeddingModelName={currentEmbeddingModelName} />
) : (
!checkModelNameIsValid(embeddingModelName) && (
<SwitchModelModal embeddingModelName={embeddingModelName} />
embeddingModelVersionInfo &&
!checkModelNameIsValid(currentEmbeddingModelName) &&
!nextEmbeddingModelName && (
<SwitchModelModal embeddingModelName={currentEmbeddingModelName} />
)
)}

View File

@@ -3,13 +3,11 @@
import { Button, Text } from "@tremor/react";
import { Modal } from "./Modal";
import Link from "next/link";
import { FiCheckCircle } from "react-icons/fi";
import { checkModelNameIsValid } from "@/app/admin/models/embedding/embeddingModels";
export function SwitchModelModal({
embeddingModelName,
}: {
embeddingModelName: null | string;
embeddingModelName: undefined | null | string;
}) {
return (
<Modal className="max-w-4xl">

View File

@@ -9,7 +9,7 @@ import { checkModelNameIsValid } from "@/app/admin/models/embedding/embeddingMod
export function WelcomeModal({
embeddingModelName,
}: {
embeddingModelName: null | string;
embeddingModelName: undefined | null | string;
}) {
const validModelSelected = checkModelNameIsValid(embeddingModelName);