diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index b77ddee85..a7a20fca3 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -29,7 +29,9 @@ from danswer.db.embedding_model import update_embedding_model_status from danswer.db.engine import get_db_current_time from danswer.db.engine import get_sqlalchemy_engine from danswer.db.index_attempt import cancel_indexing_attempts_past_model -from danswer.db.index_attempt import count_unique_cc_pairs_with_index_attempts +from danswer.db.index_attempt import ( + count_unique_cc_pairs_with_successful_index_attempts, +) from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import get_inprogress_index_attempts @@ -365,9 +367,9 @@ def kickoff_indexing_jobs( def check_index_swap(db_session: Session) -> None: - """Get count of cc-pairs and count of index_attempts for the new model grouped by - connector + credential, if it's the same, then assume new index is done building. - This does not take into consideration if the attempt failed or not""" + """Get count of cc-pairs and count of successful index_attempts for the + new model grouped by connector + credential, if it's the same, then assume + new index is done building. 
If so, swap the indices and expire the old one.""" # Default CC-pair created for Ingestion API unused here all_cc_pairs = get_connector_credential_pairs(db_session) cc_pair_count = len(all_cc_pairs) - 1 @@ -376,7 +378,7 @@ def check_index_swap(db_session: Session) -> None: if not embedding_model: return - unique_cc_indexings = count_unique_cc_pairs_with_index_attempts( + unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts( embedding_model_id=embedding_model.id, db_session=db_session ) diff --git a/backend/danswer/db/index_attempt.py b/backend/danswer/db/index_attempt.py index ce913098e..4580140a5 100644 --- a/backend/danswer/db/index_attempt.py +++ b/backend/danswer/db/index_attempt.py @@ -291,7 +291,7 @@ def cancel_indexing_attempts_past_model( db_session.commit() -def count_unique_cc_pairs_with_index_attempts( +def count_unique_cc_pairs_with_successful_index_attempts( embedding_model_id: int | None, db_session: Session, ) -> int: @@ -299,12 +299,7 @@ def count_unique_cc_pairs_with_index_attempts( db_session.query(IndexAttempt.connector_id, IndexAttempt.credential_id) .filter( IndexAttempt.embedding_model_id == embedding_model_id, - # Should not be able to hang since indexing jobs expire after a limit - # It will then be marked failed, and the next cycle it will be in a completed state - or_( - IndexAttempt.status == IndexingStatus.SUCCESS, - IndexAttempt.status == IndexingStatus.FAILED, - ), + IndexAttempt.status == IndexingStatus.SUCCESS, ) .distinct() .count() diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py index c875c88bd..68f9e3886 100644 --- a/backend/danswer/indexing/models.py +++ b/backend/danswer/indexing/models.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from dataclasses import fields from datetime import datetime +from typing import TYPE_CHECKING from pydantic import BaseModel @@ -9,6 +10,9 @@ from danswer.configs.constants import DocumentSource from danswer.connectors.models import 
Document from danswer.utils.logger import setup_logger +if TYPE_CHECKING: + from danswer.db.models import EmbeddingModel + logger = setup_logger() @@ -130,3 +134,13 @@ class EmbeddingModelDetail(BaseModel): normalize: bool query_prefix: str | None passage_prefix: str | None + + @classmethod + def from_model(cls, embedding_model: "EmbeddingModel") -> "EmbeddingModelDetail": + return cls( + model_name=embedding_model.model_name, + model_dim=embedding_model.model_dim, + normalize=embedding_model.normalize, + query_prefix=embedding_model.query_prefix, + passage_prefix=embedding_model.passage_prefix, + ) diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py index a2ea4c7ab..8857ffc55 100644 --- a/backend/danswer/server/manage/models.py +++ b/backend/danswer/server/manage/models.py @@ -11,6 +11,7 @@ from danswer.db.models import AllowedAnswerFilters from danswer.db.models import ChannelConfig from danswer.db.models import SlackBotConfig as SlackBotConfigModel from danswer.db.models import SlackBotResponseType +from danswer.indexing.models import EmbeddingModelDetail from danswer.server.features.persona.models import PersonaSnapshot @@ -125,10 +126,6 @@ class SlackBotConfig(BaseModel): ) -class ModelVersionResponse(BaseModel): - model_name: str | None # None only applicable to secondary index - - class FullModelVersionResponse(BaseModel): - current_model_name: str - secondary_model_name: str | None + current_model: EmbeddingModelDetail + secondary_model: EmbeddingModelDetail | None diff --git a/backend/danswer/server/manage/secondary_index.py b/backend/danswer/server/manage/secondary_index.py index c4c51c0e3..6f5adf752 100644 --- a/backend/danswer/server/manage/secondary_index.py +++ b/backend/danswer/server/manage/secondary_index.py @@ -20,7 +20,6 @@ from danswer.db.models import User from danswer.document_index.factory import get_default_document_index from danswer.indexing.models import EmbeddingModelDetail from 
danswer.server.manage.models import FullModelVersionResponse -from danswer.server.manage.models import ModelVersionResponse from danswer.server.models import IdReturn from danswer.utils.logger import setup_logger @@ -115,21 +114,21 @@ def cancel_new_embedding( def get_current_embedding_model( _: User | None = Depends(current_user), db_session: Session = Depends(get_session), -) -> ModelVersionResponse: +) -> EmbeddingModelDetail: current_model = get_current_db_embedding_model(db_session) - return ModelVersionResponse(model_name=current_model.model_name) + return EmbeddingModelDetail.from_model(current_model) @router.get("/get-secondary-embedding-model") def get_secondary_embedding_model( _: User | None = Depends(current_user), db_session: Session = Depends(get_session), -) -> ModelVersionResponse: +) -> EmbeddingModelDetail | None: next_model = get_secondary_db_embedding_model(db_session) + if not next_model: + return None - return ModelVersionResponse( - model_name=next_model.model_name if next_model else None - ) + return EmbeddingModelDetail.from_model(next_model) @router.get("/get-embedding-models") @@ -140,6 +139,8 @@ def get_embedding_models( current_model = get_current_db_embedding_model(db_session) next_model = get_secondary_db_embedding_model(db_session) return FullModelVersionResponse( - current_model_name=current_model.model_name, - secondary_model_name=next_model.model_name if next_model else None, + current_model=EmbeddingModelDetail.from_model(current_model), + secondary_model=EmbeddingModelDetail.from_model(next_model) + if next_model + else None, ) diff --git a/web/src/app/admin/models/embedding/CustomModelForm.tsx b/web/src/app/admin/models/embedding/CustomModelForm.tsx new file mode 100644 index 000000000..23676bc61 --- /dev/null +++ b/web/src/app/admin/models/embedding/CustomModelForm.tsx @@ -0,0 +1,116 @@ +import { + BooleanFormField, + TextFormField, +} from "@/components/admin/connectors/Field"; +import { Button, Divider, Text } from 
"@tremor/react"; +import { Form, Formik } from "formik"; + +import * as Yup from "yup"; +import { EmbeddingModelDescriptor } from "./embeddingModels"; + +export function CustomModelForm({ + onSubmit, +}: { + onSubmit: (model: EmbeddingModelDescriptor) => void; +}) { + return ( +
+ { + onSubmit({ ...values, model_dim: parseInt(values.model_dim) }); + }} + > + {({ isSubmitting, setFieldValue }) => ( +
+ + + { + const value = e.target.value; + // Allow only integer values + if (value === "" || /^[0-9]+$/.test(value)) { + setFieldValue("model_dim", value); + } + }} + /> + + + The prefix specified by the model creators which should be + prepended to queries before passing them to the model. + Many models do not have this, in which case this should be + left empty. + + } + placeholder="E.g. 'query: '" + autoCompleteDisabled={true} + /> + + + The prefix specified by the model creators which should be + prepended to passages before passing them to the model. + Many models do not have this, in which case this should be + left empty. + + } + placeholder="E.g. 'passage: '" + autoCompleteDisabled={true} + /> + + + +
+ +
+ + )} +
+
+ ); +} diff --git a/web/src/app/admin/models/embedding/ModelSelectionConfirmation.tsx b/web/src/app/admin/models/embedding/ModelSelectionConfirmation.tsx index 949c5d46d..7572ac2ce 100644 --- a/web/src/app/admin/models/embedding/ModelSelectionConfirmation.tsx +++ b/web/src/app/admin/models/embedding/ModelSelectionConfirmation.tsx @@ -1,18 +1,21 @@ import { Modal } from "@/components/Modal"; -import { Button, Text } from "@tremor/react"; +import { Button, Text, Callout } from "@tremor/react"; +import { EmbeddingModelDescriptor } from "./embeddingModels"; export function ModelSelectionConfirmaion({ selectedModel, + isCustom, onConfirm, }: { - selectedModel: string; + selectedModel: EmbeddingModelDescriptor; + isCustom: boolean; onConfirm: () => void; }) { return (
- You have selected: {selectedModel}. Are you sure you want to - update to this new embedding model? + You have selected: {selectedModel.model_name}. Are you sure you + want to update to this new embedding model? We will re-index all your documents in the background so you will be @@ -25,6 +28,18 @@ export function ModelSelectionConfirmaion({ normal. If you are self-hosting, we recommend that you allocate at least 16GB of RAM to Danswer during this process. + + {isCustom && ( + + We&apos;ve detected that this is a custom-specified embedding model. + Since we have to download the model files before verifying the + configuration&apos;s correctness, we won&apos;t be able to let you + know if the configuration is valid until after we start + re-indexing your documents. If there is an issue, it will show up on + this page as an indexing error after clicking Confirm. + + )} + 
@@ -61,17 +69,19 @@ export function ModelSelector({ setSelectedModel, }: { modelOptions: FullEmbeddingModelDescriptor[]; - setSelectedModel: (modelName: string) => void; + setSelectedModel: (model: EmbeddingModelDescriptor) => void; }) { return ( -
- {modelOptions.map((modelOption) => ( - - ))} +
+
+ {modelOptions.map((modelOption) => ( + + ))} +
); } diff --git a/web/src/app/admin/models/embedding/ReindexingProgressTable.tsx b/web/src/app/admin/models/embedding/ReindexingProgressTable.tsx index 3b366c192..b1f91d24b 100644 --- a/web/src/app/admin/models/embedding/ReindexingProgressTable.tsx +++ b/web/src/app/admin/models/embedding/ReindexingProgressTable.tsx @@ -1,14 +1,14 @@ import { PageSelector } from "@/components/PageSelector"; -import { CCPairStatus, IndexAttemptStatus } from "@/components/Status"; -import { ConnectorIndexingStatus, ValidStatuses } from "@/lib/types"; +import { IndexAttemptStatus } from "@/components/Status"; +import { ConnectorIndexingStatus } from "@/lib/types"; import { - Button, Table, TableBody, TableCell, TableHead, TableHeaderCell, TableRow, + Text, } from "@tremor/react"; import Link from "next/link"; import { useState } from "react"; @@ -30,6 +30,7 @@ export function ReindexingProgressTable({ Connector Name Status Docs Re-Indexed + Error Message @@ -58,6 +59,13 @@ export function ReindexingProgressTable({ {reindexingProgress?.latest_index_attempt ?.total_docs_indexed || "-"} + +
+ + {reindexingProgress.error_msg || "-"} + +
+
); })} diff --git a/web/src/app/admin/models/embedding/embeddingModels.ts b/web/src/app/admin/models/embedding/embeddingModels.ts index 64ccfff95..7c5d09180 100644 --- a/web/src/app/admin/models/embedding/embeddingModels.ts +++ b/web/src/app/admin/models/embedding/embeddingModels.ts @@ -76,3 +76,12 @@ export function checkModelNameIsValid(modelName: string | undefined | null) { } return true; } + +export function fillOutEmeddingModelDescriptor( + embeddingModel: EmbeddingModelDescriptor | FullEmbeddingModelDescriptor +): FullEmbeddingModelDescriptor { + return { + ...embeddingModel, + description: "", + }; +} diff --git a/web/src/app/admin/models/embedding/page.tsx b/web/src/app/admin/models/embedding/page.tsx index 5f4cd1c93..0612fe2c6 100644 --- a/web/src/app/admin/models/embedding/page.tsx +++ b/web/src/app/admin/models/embedding/page.tsx @@ -6,7 +6,7 @@ import { KeyIcon, TrashIcon } from "@/components/icons/icons"; import { ApiKeyForm } from "@/components/openai/ApiKeyForm"; import { GEN_AI_API_KEY_URL } from "@/components/openai/constants"; import { errorHandlingFetcher, fetcher } from "@/lib/fetcher"; -import { Button, Divider, Text, Title } from "@tremor/react"; +import { Button, Card, Divider, Text, Title } from "@tremor/react"; import { FiCpu, FiPackage } from "react-icons/fi"; import useSWR, { mutate } from "swr"; import { ModelOption, ModelSelector } from "./ModelSelector"; @@ -16,17 +16,18 @@ import { ReindexingProgressTable } from "./ReindexingProgressTable"; import { Modal } from "@/components/Modal"; import { AVAILABLE_MODELS, - EmbeddingModelResponse, + EmbeddingModelDescriptor, INVALID_OLD_MODEL, + fillOutEmeddingModelDescriptor, } from "./embeddingModels"; import { ErrorCallout } from "@/components/ErrorCallout"; import { Connector, ConnectorIndexingStatus } from "@/lib/types"; import Link from "next/link"; +import { CustomModelForm } from "./CustomModelForm"; function Main() { - const [tentativeNewEmbeddingModel, setTentativeNewEmbeddingModel] = 
useState< - string | null - >(null); + const [tentativeNewEmbeddingModel, setTentativeNewEmbeddingModel] = + useState(null); const [isCancelling, setIsCancelling] = useState(false); const [showAddConnectorPopup, setShowAddConnectorPopup] = useState(false); @@ -35,16 +36,16 @@ function Main() { data: currentEmeddingModel, isLoading: isLoadingCurrentModel, error: currentEmeddingModelError, - } = useSWR( + } = useSWR( "/api/secondary-index/get-current-embedding-model", errorHandlingFetcher, { refreshInterval: 5000 } // 5 seconds ); const { - data: futureEmeddingModel, + data: futureEmbeddingModel, isLoading: isLoadingFutureModel, error: futureEmeddingModelError, - } = useSWR( + } = useSWR( "/api/secondary-index/get-secondary-embedding-model", errorHandlingFetcher, { refreshInterval: 5000 } // 5 seconds @@ -63,24 +64,20 @@ function Main() { { refreshInterval: 5000 } // 5 seconds ); - const onSelect = async (modelName: string) => { + const onSelect = async (model: EmbeddingModelDescriptor) => { if (currentEmeddingModel?.model_name === INVALID_OLD_MODEL) { - await onConfirm(modelName); + await onConfirm(model); } else { - setTentativeNewEmbeddingModel(modelName); + setTentativeNewEmbeddingModel(model); } }; - const onConfirm = async (modelName: string) => { - const modelDescriptor = AVAILABLE_MODELS.find( - (model) => model.model_name === modelName - ); - + const onConfirm = async (model: EmbeddingModelDescriptor) => { const response = await fetch( "/api/secondary-index/set-new-embedding-model", { method: "POST", - body: JSON.stringify(modelDescriptor), + body: JSON.stringify(model), headers: { "Content-Type": "application/json", }, @@ -120,26 +117,33 @@ function Main() { if ( currentEmeddingModelError || !currentEmeddingModel || - futureEmeddingModelError || - !futureEmeddingModel + futureEmeddingModelError ) { return ; } const currentModelName = currentEmeddingModel.model_name; - const currentModel = AVAILABLE_MODELS.find( - (model) => model.model_name === 
currentModelName - ); + const currentModel = + AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) || + fillOutEmeddingModelDescriptor(currentEmeddingModel); - const newModelSelection = AVAILABLE_MODELS.find( - (model) => model.model_name === futureEmeddingModel.model_name - ); + const newModelSelection = futureEmbeddingModel + ? AVAILABLE_MODELS.find( + (model) => model.model_name === futureEmbeddingModel.model_name + ) || fillOutEmeddingModelDescriptor(futureEmbeddingModel) + : null; return (
{tentativeNewEmbeddingModel && ( + model.model_name === tentativeNewEmbeddingModel.model_name + ) === undefined + } onConfirm={() => onConfirm(tentativeNewEmbeddingModel)} onCancel={() => setTentativeNewEmbeddingModel(null)} /> @@ -243,12 +247,49 @@ function Main() { )} + + Below is a curated selection of quality models that we recommend + you choose from. + + modelOption.model_name !== currentModelName )} setSelectedModel={onSelect} /> + + + Alternatively (if you know what you&apos;re doing), you can + specify a{" "} + + SentenceTransformers + + -compatible model of your choice below. The rough list of + supported models can be found{" "} + + here + + . + 
+ NOTE: not all models listed will work with Danswer, since + some have unique interfaces or special requirements. If in doubt, + reach out to the Danswer team. + + +
+ + + +
) : ( connectors && @@ -272,10 +313,10 @@ function Main() { The table below shows the re-indexing progress of all existing - connectors. Once all connectors have been re-indexed, the new - model will be used for all search queries. Until then, we will - use the old model so that no downtime is necessary during this - transition. + connectors. Once all connectors have been re-indexed + successfully, the new model will be used for all search + queries. Until then, we will use the old model so that no + downtime is necessary during this transition. {isLoadingOngoingReIndexingStatus ? (