mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-10-10 13:15:18 +02:00
Fix change model popup
This commit is contained in:
@@ -107,3 +107,8 @@ class SlackBotConfig(BaseModel):
|
|||||||
|
|
||||||
class ModelVersionResponse(BaseModel):
|
class ModelVersionResponse(BaseModel):
|
||||||
model_name: str | None # None only applicable to secondary index
|
model_name: str | None # None only applicable to secondary index
|
||||||
|
|
||||||
|
|
||||||
|
class FullModelVersionResponse(BaseModel):
|
||||||
|
current_model_name: str
|
||||||
|
secondary_model_name: str | None
|
||||||
|
@@ -5,6 +5,7 @@ from fastapi import status
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from danswer.auth.users import current_admin_user
|
from danswer.auth.users import current_admin_user
|
||||||
|
from danswer.auth.users import current_user
|
||||||
from danswer.db.embedding_model import create_embedding_model
|
from danswer.db.embedding_model import create_embedding_model
|
||||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||||
@@ -15,6 +16,7 @@ from danswer.db.models import IndexModelStatus
|
|||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
from danswer.document_index.factory import get_default_document_index
|
from danswer.document_index.factory import get_default_document_index
|
||||||
from danswer.indexing.models import EmbeddingModelDetail
|
from danswer.indexing.models import EmbeddingModelDetail
|
||||||
|
from danswer.server.manage.models import FullModelVersionResponse
|
||||||
from danswer.server.manage.models import ModelVersionResponse
|
from danswer.server.manage.models import ModelVersionResponse
|
||||||
from danswer.server.models import IdReturn
|
from danswer.server.models import IdReturn
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
@@ -100,7 +102,7 @@ def cancel_new_embedding(
|
|||||||
|
|
||||||
@router.get("/get-current-embedding-model")
|
@router.get("/get-current-embedding-model")
|
||||||
def get_current_embedding_model(
|
def get_current_embedding_model(
|
||||||
_: User | None = Depends(current_admin_user),
|
_: User | None = Depends(current_user),
|
||||||
db_session: Session = Depends(get_session),
|
db_session: Session = Depends(get_session),
|
||||||
) -> ModelVersionResponse:
|
) -> ModelVersionResponse:
|
||||||
current_model = get_current_db_embedding_model(db_session)
|
current_model = get_current_db_embedding_model(db_session)
|
||||||
@@ -109,7 +111,7 @@ def get_current_embedding_model(
|
|||||||
|
|
||||||
@router.get("/get-secondary-embedding-model")
|
@router.get("/get-secondary-embedding-model")
|
||||||
def get_secondary_embedding_model(
|
def get_secondary_embedding_model(
|
||||||
_: User | None = Depends(current_admin_user),
|
_: User | None = Depends(current_user),
|
||||||
db_session: Session = Depends(get_session),
|
db_session: Session = Depends(get_session),
|
||||||
) -> ModelVersionResponse:
|
) -> ModelVersionResponse:
|
||||||
next_model = get_secondary_db_embedding_model(db_session)
|
next_model = get_secondary_db_embedding_model(db_session)
|
||||||
@@ -117,3 +119,16 @@ def get_secondary_embedding_model(
|
|||||||
return ModelVersionResponse(
|
return ModelVersionResponse(
|
||||||
model_name=next_model.model_name if next_model else None
|
model_name=next_model.model_name if next_model else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/get-embedding-models")
|
||||||
|
def get_embedding_models(
|
||||||
|
_: User | None = Depends(current_user),
|
||||||
|
db_session: Session = Depends(get_session),
|
||||||
|
) -> FullModelVersionResponse:
|
||||||
|
current_model = get_current_db_embedding_model(db_session)
|
||||||
|
next_model = get_secondary_db_embedding_model(db_session)
|
||||||
|
return FullModelVersionResponse(
|
||||||
|
current_model_name=current_model.model_name,
|
||||||
|
secondary_model_name=next_model.model_name if next_model else None,
|
||||||
|
)
|
||||||
|
@@ -2,6 +2,11 @@ export interface EmbeddingModelResponse {
|
|||||||
model_name: string | null;
|
model_name: string | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface FullEmbeddingModelResponse {
|
||||||
|
current_model_name: string;
|
||||||
|
secondary_model_name: string | null;
|
||||||
|
}
|
||||||
|
|
||||||
export interface EmbeddingModelDescriptor {
|
export interface EmbeddingModelDescriptor {
|
||||||
model_name: string;
|
model_name: string;
|
||||||
model_dim: number;
|
model_dim: number;
|
||||||
|
@@ -17,7 +17,7 @@ import { DOCUMENT_SIDEBAR_WIDTH_COOKIE_NAME } from "@/components/resizable/conta
|
|||||||
import { personaComparator } from "../admin/personas/lib";
|
import { personaComparator } from "../admin/personas/lib";
|
||||||
import { ChatLayout } from "./ChatPage";
|
import { ChatLayout } from "./ChatPage";
|
||||||
import {
|
import {
|
||||||
EmbeddingModelResponse,
|
FullEmbeddingModelResponse,
|
||||||
checkModelNameIsValid,
|
checkModelNameIsValid,
|
||||||
} from "../admin/models/embedding/embeddingModels";
|
} from "../admin/models/embedding/embeddingModels";
|
||||||
import { SwitchModelModal } from "@/components/SwitchModelModal";
|
import { SwitchModelModal } from "@/components/SwitchModelModal";
|
||||||
@@ -37,7 +37,7 @@ export default async function Page({
|
|||||||
fetchSS("/persona?include_default=true"),
|
fetchSS("/persona?include_default=true"),
|
||||||
fetchSS("/chat/get-user-chat-sessions"),
|
fetchSS("/chat/get-user-chat-sessions"),
|
||||||
fetchSS("/query/valid-tags"),
|
fetchSS("/query/valid-tags"),
|
||||||
fetchSS("/secondary-index/get-current-embedding-model"),
|
fetchSS("/secondary-index/get-embedding-models"),
|
||||||
];
|
];
|
||||||
|
|
||||||
// catch cases where the backend is completely unreachable here
|
// catch cases where the backend is completely unreachable here
|
||||||
@@ -47,7 +47,7 @@ export default async function Page({
|
|||||||
| User
|
| User
|
||||||
| Response
|
| Response
|
||||||
| AuthTypeMetadata
|
| AuthTypeMetadata
|
||||||
| EmbeddingModelResponse
|
| FullEmbeddingModelResponse
|
||||||
| null
|
| null
|
||||||
)[] = [null, null, null, null, null, null, null, null];
|
)[] = [null, null, null, null, null, null, null, null];
|
||||||
try {
|
try {
|
||||||
@@ -124,10 +124,14 @@ export default async function Page({
|
|||||||
console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
|
console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
const embeddingModelName =
|
const embeddingModelVersionInfo =
|
||||||
embeddingModelResponse && embeddingModelResponse.ok
|
embeddingModelResponse && embeddingModelResponse.ok
|
||||||
? ((await embeddingModelResponse.json()).model_name as string)
|
? ((await embeddingModelResponse.json()) as FullEmbeddingModelResponse)
|
||||||
: null;
|
: null;
|
||||||
|
const currentEmbeddingModelName =
|
||||||
|
embeddingModelVersionInfo?.current_model_name;
|
||||||
|
const nextEmbeddingModelName =
|
||||||
|
embeddingModelVersionInfo?.secondary_model_name;
|
||||||
|
|
||||||
const defaultPersonaIdRaw = searchParams["personaId"];
|
const defaultPersonaIdRaw = searchParams["personaId"];
|
||||||
const defaultPersonaId = defaultPersonaIdRaw
|
const defaultPersonaId = defaultPersonaIdRaw
|
||||||
@@ -147,10 +151,12 @@ export default async function Page({
|
|||||||
<ApiKeyModal />
|
<ApiKeyModal />
|
||||||
|
|
||||||
{connectors.length === 0 ? (
|
{connectors.length === 0 ? (
|
||||||
<WelcomeModal embeddingModelName={embeddingModelName} />
|
<WelcomeModal embeddingModelName={currentEmbeddingModelName} />
|
||||||
) : (
|
) : (
|
||||||
!checkModelNameIsValid(embeddingModelName) && (
|
embeddingModelVersionInfo &&
|
||||||
<SwitchModelModal embeddingModelName={embeddingModelName} />
|
!checkModelNameIsValid(currentEmbeddingModelName) &&
|
||||||
|
!nextEmbeddingModelName && (
|
||||||
|
<SwitchModelModal embeddingModelName={currentEmbeddingModelName} />
|
||||||
)
|
)
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
@@ -17,7 +17,10 @@ import { WelcomeModal } from "@/components/WelcomeModal";
|
|||||||
import { unstable_noStore as noStore } from "next/cache";
|
import { unstable_noStore as noStore } from "next/cache";
|
||||||
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
||||||
import { personaComparator } from "../admin/personas/lib";
|
import { personaComparator } from "../admin/personas/lib";
|
||||||
import { checkModelNameIsValid } from "../admin/models/embedding/embeddingModels";
|
import {
|
||||||
|
FullEmbeddingModelResponse,
|
||||||
|
checkModelNameIsValid,
|
||||||
|
} from "../admin/models/embedding/embeddingModels";
|
||||||
import { SwitchModelModal } from "@/components/SwitchModelModal";
|
import { SwitchModelModal } from "@/components/SwitchModelModal";
|
||||||
|
|
||||||
export default async function Home() {
|
export default async function Home() {
|
||||||
@@ -33,21 +36,19 @@ export default async function Home() {
|
|||||||
fetchSS("/manage/document-set"),
|
fetchSS("/manage/document-set"),
|
||||||
fetchSS("/persona"),
|
fetchSS("/persona"),
|
||||||
fetchSS("/query/valid-tags"),
|
fetchSS("/query/valid-tags"),
|
||||||
fetchSS("/secondary-index/get-current-embedding-model"),
|
fetchSS("/secondary-index/get-embedding-models"),
|
||||||
];
|
];
|
||||||
|
|
||||||
// catch cases where the backend is completely unreachable here
|
// catch cases where the backend is completely unreachable here
|
||||||
// without try / catch, will just raise an exception and the page
|
// without try / catch, will just raise an exception and the page
|
||||||
// will not render
|
// will not render
|
||||||
let results: (User | Response | AuthTypeMetadata | null)[] = [
|
let results: (
|
||||||
null,
|
| User
|
||||||
null,
|
| Response
|
||||||
null,
|
| AuthTypeMetadata
|
||||||
null,
|
| FullEmbeddingModelResponse
|
||||||
null,
|
| null
|
||||||
null,
|
)[] = [null, null, null, null, null, null, null];
|
||||||
null,
|
|
||||||
];
|
|
||||||
try {
|
try {
|
||||||
results = await Promise.all(tasks);
|
results = await Promise.all(tasks);
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
@@ -104,10 +105,15 @@ export default async function Home() {
|
|||||||
console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
|
console.log(`Failed to fetch tags - ${tagsResponse?.status}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
const embeddingModelName =
|
const embeddingModelVersionInfo =
|
||||||
embeddingModelResponse && embeddingModelResponse.ok
|
embeddingModelResponse && embeddingModelResponse.ok
|
||||||
? ((await embeddingModelResponse.json()).model_name as string)
|
? ((await embeddingModelResponse.json()) as FullEmbeddingModelResponse)
|
||||||
: null;
|
: null;
|
||||||
|
const currentEmbeddingModelName =
|
||||||
|
embeddingModelVersionInfo?.current_model_name;
|
||||||
|
const nextEmbeddingModelName =
|
||||||
|
embeddingModelVersionInfo?.secondary_model_name;
|
||||||
|
console.log(embeddingModelVersionInfo);
|
||||||
|
|
||||||
// needs to be done in a non-client side component due to nextjs
|
// needs to be done in a non-client side component due to nextjs
|
||||||
const storedSearchType = cookies().get("searchType")?.value as
|
const storedSearchType = cookies().get("searchType")?.value as
|
||||||
@@ -129,10 +135,12 @@ export default async function Home() {
|
|||||||
<InstantSSRAutoRefresh />
|
<InstantSSRAutoRefresh />
|
||||||
|
|
||||||
{connectors.length === 0 ? (
|
{connectors.length === 0 ? (
|
||||||
<WelcomeModal embeddingModelName={embeddingModelName} />
|
<WelcomeModal embeddingModelName={currentEmbeddingModelName} />
|
||||||
) : (
|
) : (
|
||||||
!checkModelNameIsValid(embeddingModelName) && (
|
embeddingModelVersionInfo &&
|
||||||
<SwitchModelModal embeddingModelName={embeddingModelName} />
|
!checkModelNameIsValid(currentEmbeddingModelName) &&
|
||||||
|
!nextEmbeddingModelName && (
|
||||||
|
<SwitchModelModal embeddingModelName={currentEmbeddingModelName} />
|
||||||
)
|
)
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
@@ -3,13 +3,11 @@
|
|||||||
import { Button, Text } from "@tremor/react";
|
import { Button, Text } from "@tremor/react";
|
||||||
import { Modal } from "./Modal";
|
import { Modal } from "./Modal";
|
||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
import { FiCheckCircle } from "react-icons/fi";
|
|
||||||
import { checkModelNameIsValid } from "@/app/admin/models/embedding/embeddingModels";
|
|
||||||
|
|
||||||
export function SwitchModelModal({
|
export function SwitchModelModal({
|
||||||
embeddingModelName,
|
embeddingModelName,
|
||||||
}: {
|
}: {
|
||||||
embeddingModelName: null | string;
|
embeddingModelName: undefined | null | string;
|
||||||
}) {
|
}) {
|
||||||
return (
|
return (
|
||||||
<Modal className="max-w-4xl">
|
<Modal className="max-w-4xl">
|
||||||
|
@@ -9,7 +9,7 @@ import { checkModelNameIsValid } from "@/app/admin/models/embedding/embeddingMod
|
|||||||
export function WelcomeModal({
|
export function WelcomeModal({
|
||||||
embeddingModelName,
|
embeddingModelName,
|
||||||
}: {
|
}: {
|
||||||
embeddingModelName: null | string;
|
embeddingModelName: undefined | null | string;
|
||||||
}) {
|
}) {
|
||||||
const validModelSelected = checkModelNameIsValid(embeddingModelName);
|
const validModelSelected = checkModelNameIsValid(embeddingModelName);
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user