Fix change model popup

This commit is contained in:
Weves
2024-01-30 10:08:46 -08:00
committed by Chris Weaver
parent b076c3d1ea
commit 8b9e6a91a4
7 changed files with 67 additions and 30 deletions

View File

@@ -107,3 +107,8 @@ class SlackBotConfig(BaseModel):
class ModelVersionResponse(BaseModel): class ModelVersionResponse(BaseModel):
model_name: str | None # None only applicable to secondary index model_name: str | None # None only applicable to secondary index
class FullModelVersionResponse(BaseModel):
current_model_name: str
secondary_model_name: str | None

View File

@@ -5,6 +5,7 @@ from fastapi import status
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.embedding_model import create_embedding_model from danswer.db.embedding_model import create_embedding_model
from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model from danswer.db.embedding_model import get_secondary_db_embedding_model
@@ -15,6 +16,7 @@ from danswer.db.models import IndexModelStatus
from danswer.db.models import User from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import EmbeddingModelDetail from danswer.indexing.models import EmbeddingModelDetail
from danswer.server.manage.models import FullModelVersionResponse
from danswer.server.manage.models import ModelVersionResponse from danswer.server.manage.models import ModelVersionResponse
from danswer.server.models import IdReturn from danswer.server.models import IdReturn
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
@@ -100,7 +102,7 @@ def cancel_new_embedding(
@router.get("/get-current-embedding-model") @router.get("/get-current-embedding-model")
def get_current_embedding_model( def get_current_embedding_model(
_: User | None = Depends(current_admin_user), _: User | None = Depends(current_user),
db_session: Session = Depends(get_session), db_session: Session = Depends(get_session),
) -> ModelVersionResponse: ) -> ModelVersionResponse:
current_model = get_current_db_embedding_model(db_session) current_model = get_current_db_embedding_model(db_session)
@@ -109,7 +111,7 @@ def get_current_embedding_model(
@router.get("/get-secondary-embedding-model") @router.get("/get-secondary-embedding-model")
def get_secondary_embedding_model( def get_secondary_embedding_model(
_: User | None = Depends(current_admin_user), _: User | None = Depends(current_user),
db_session: Session = Depends(get_session), db_session: Session = Depends(get_session),
) -> ModelVersionResponse: ) -> ModelVersionResponse:
next_model = get_secondary_db_embedding_model(db_session) next_model = get_secondary_db_embedding_model(db_session)
@@ -117,3 +119,16 @@ def get_secondary_embedding_model(
return ModelVersionResponse( return ModelVersionResponse(
model_name=next_model.model_name if next_model else None model_name=next_model.model_name if next_model else None
) )
@router.get("/get-embedding-models")
def get_embedding_models(
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> FullModelVersionResponse:
current_model = get_current_db_embedding_model(db_session)
next_model = get_secondary_db_embedding_model(db_session)
return FullModelVersionResponse(
current_model_name=current_model.model_name,
secondary_model_name=next_model.model_name if next_model else None,
)

View File

@@ -2,6 +2,11 @@ export interface EmbeddingModelResponse {
model_name: string | null; model_name: string | null;
} }
export interface FullEmbeddingModelResponse {
current_model_name: string;
secondary_model_name: string | null;
}
export interface EmbeddingModelDescriptor { export interface EmbeddingModelDescriptor {
model_name: string; model_name: string;
model_dim: number; model_dim: number;

View File

@@ -17,7 +17,7 @@ import { DOCUMENT_SIDEBAR_WIDTH_COOKIE_NAME } from "@/components/resizable/conta
import { personaComparator } from "../admin/personas/lib"; import { personaComparator } from "../admin/personas/lib";
import { ChatLayout } from "./ChatPage"; import { ChatLayout } from "./ChatPage";
import { import {
EmbeddingModelResponse, FullEmbeddingModelResponse,
checkModelNameIsValid, checkModelNameIsValid,
} from "../admin/models/embedding/embeddingModels"; } from "../admin/models/embedding/embeddingModels";
import { SwitchModelModal } from "@/components/SwitchModelModal"; import { SwitchModelModal } from "@/components/SwitchModelModal";
@@ -37,7 +37,7 @@ export default async function Page({
fetchSS("/persona?include_default=true"), fetchSS("/persona?include_default=true"),
fetchSS("/chat/get-user-chat-sessions"), fetchSS("/chat/get-user-chat-sessions"),
fetchSS("/query/valid-tags"), fetchSS("/query/valid-tags"),
fetchSS("/secondary-index/get-current-embedding-model"), fetchSS("/secondary-index/get-embedding-models"),
]; ];
// catch cases where the backend is completely unreachable here // catch cases where the backend is completely unreachable here
@@ -47,7 +47,7 @@ export default async function Page({
| User | User
| Response | Response
| AuthTypeMetadata | AuthTypeMetadata
| EmbeddingModelResponse | FullEmbeddingModelResponse
| null | null
)[] = [null, null, null, null, null, null, null, null]; )[] = [null, null, null, null, null, null, null, null];
try { try {
@@ -124,10 +124,14 @@ export default async function Page({
console.log(`Failed to fetch tags - ${tagsResponse?.status}`); console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
} }
const embeddingModelName = const embeddingModelVersionInfo =
embeddingModelResponse && embeddingModelResponse.ok embeddingModelResponse && embeddingModelResponse.ok
? ((await embeddingModelResponse.json()).model_name as string) ? ((await embeddingModelResponse.json()) as FullEmbeddingModelResponse)
: null; : null;
const currentEmbeddingModelName =
embeddingModelVersionInfo?.current_model_name;
const nextEmbeddingModelName =
embeddingModelVersionInfo?.secondary_model_name;
const defaultPersonaIdRaw = searchParams["personaId"]; const defaultPersonaIdRaw = searchParams["personaId"];
const defaultPersonaId = defaultPersonaIdRaw const defaultPersonaId = defaultPersonaIdRaw
@@ -147,10 +151,12 @@ export default async function Page({
<ApiKeyModal /> <ApiKeyModal />
{connectors.length === 0 ? ( {connectors.length === 0 ? (
<WelcomeModal embeddingModelName={embeddingModelName} /> <WelcomeModal embeddingModelName={currentEmbeddingModelName} />
) : ( ) : (
!checkModelNameIsValid(embeddingModelName) && ( embeddingModelVersionInfo &&
<SwitchModelModal embeddingModelName={embeddingModelName} /> !checkModelNameIsValid(currentEmbeddingModelName) &&
!nextEmbeddingModelName && (
<SwitchModelModal embeddingModelName={currentEmbeddingModelName} />
) )
)} )}

View File

@@ -17,7 +17,10 @@ import { WelcomeModal } from "@/components/WelcomeModal";
import { unstable_noStore as noStore } from "next/cache"; import { unstable_noStore as noStore } from "next/cache";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh"; import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
import { personaComparator } from "../admin/personas/lib"; import { personaComparator } from "../admin/personas/lib";
import { checkModelNameIsValid } from "../admin/models/embedding/embeddingModels"; import {
FullEmbeddingModelResponse,
checkModelNameIsValid,
} from "../admin/models/embedding/embeddingModels";
import { SwitchModelModal } from "@/components/SwitchModelModal"; import { SwitchModelModal } from "@/components/SwitchModelModal";
export default async function Home() { export default async function Home() {
@@ -33,21 +36,19 @@ export default async function Home() {
fetchSS("/manage/document-set"), fetchSS("/manage/document-set"),
fetchSS("/persona"), fetchSS("/persona"),
fetchSS("/query/valid-tags"), fetchSS("/query/valid-tags"),
fetchSS("/secondary-index/get-current-embedding-model"), fetchSS("/secondary-index/get-embedding-models"),
]; ];
// catch cases where the backend is completely unreachable here // catch cases where the backend is completely unreachable here
// without try / catch, will just raise an exception and the page // without try / catch, will just raise an exception and the page
// will not render // will not render
let results: (User | Response | AuthTypeMetadata | null)[] = [ let results: (
null, | User
null, | Response
null, | AuthTypeMetadata
null, | FullEmbeddingModelResponse
null, | null
null, )[] = [null, null, null, null, null, null, null];
null,
];
try { try {
results = await Promise.all(tasks); results = await Promise.all(tasks);
} catch (e) { } catch (e) {
@@ -104,10 +105,15 @@ export default async function Home() {
console.log(`Failed to fetch tags - ${tagsResponse?.status}`); console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
} }
const embeddingModelName = const embeddingModelVersionInfo =
embeddingModelResponse && embeddingModelResponse.ok embeddingModelResponse && embeddingModelResponse.ok
? ((await embeddingModelResponse.json()).model_name as string) ? ((await embeddingModelResponse.json()) as FullEmbeddingModelResponse)
: null; : null;
const currentEmbeddingModelName =
embeddingModelVersionInfo?.current_model_name;
const nextEmbeddingModelName =
embeddingModelVersionInfo?.secondary_model_name;
console.log(embeddingModelVersionInfo);
// needs to be done in a non-client side component due to nextjs // needs to be done in a non-client side component due to nextjs
const storedSearchType = cookies().get("searchType")?.value as const storedSearchType = cookies().get("searchType")?.value as
@@ -129,10 +135,12 @@ export default async function Home() {
<InstantSSRAutoRefresh /> <InstantSSRAutoRefresh />
{connectors.length === 0 ? ( {connectors.length === 0 ? (
<WelcomeModal embeddingModelName={embeddingModelName} /> <WelcomeModal embeddingModelName={currentEmbeddingModelName} />
) : ( ) : (
!checkModelNameIsValid(embeddingModelName) && ( embeddingModelVersionInfo &&
<SwitchModelModal embeddingModelName={embeddingModelName} /> !checkModelNameIsValid(currentEmbeddingModelName) &&
!nextEmbeddingModelName && (
<SwitchModelModal embeddingModelName={currentEmbeddingModelName} />
) )
)} )}

View File

@@ -3,13 +3,11 @@
import { Button, Text } from "@tremor/react"; import { Button, Text } from "@tremor/react";
import { Modal } from "./Modal"; import { Modal } from "./Modal";
import Link from "next/link"; import Link from "next/link";
import { FiCheckCircle } from "react-icons/fi";
import { checkModelNameIsValid } from "@/app/admin/models/embedding/embeddingModels";
export function SwitchModelModal({ export function SwitchModelModal({
embeddingModelName, embeddingModelName,
}: { }: {
embeddingModelName: null | string; embeddingModelName: undefined | null | string;
}) { }) {
return ( return (
<Modal className="max-w-4xl"> <Modal className="max-w-4xl">

View File

@@ -9,7 +9,7 @@ import { checkModelNameIsValid } from "@/app/admin/models/embedding/embeddingMod
export function WelcomeModal({ export function WelcomeModal({
embeddingModelName, embeddingModelName,
}: { }: {
embeddingModelName: null | string; embeddingModelName: undefined | null | string;
}) { }) {
const validModelSelected = checkModelNameIsValid(embeddingModelName); const validModelSelected = checkModelNameIsValid(embeddingModelName);