diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py
index b77ddee85..a7a20fca3 100755
--- a/backend/danswer/background/update.py
+++ b/backend/danswer/background/update.py
@@ -29,7 +29,9 @@ from danswer.db.embedding_model import update_embedding_model_status
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
-from danswer.db.index_attempt import count_unique_cc_pairs_with_index_attempts
+from danswer.db.index_attempt import (
+ count_unique_cc_pairs_with_successful_index_attempts,
+)
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
@@ -365,9 +367,9 @@ def kickoff_indexing_jobs(
def check_index_swap(db_session: Session) -> None:
- """Get count of cc-pairs and count of index_attempts for the new model grouped by
- connector + credential, if it's the same, then assume new index is done building.
- This does not take into consideration if the attempt failed or not"""
+ """Get count of cc-pairs and count of successful index_attempts for the
+ new model grouped by connector + credential, if it's the same, then assume
+ new index is done building. If so, swap the indices and expire the old one."""
# Default CC-pair created for Ingestion API unused here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = len(all_cc_pairs) - 1
@@ -376,7 +378,7 @@ def check_index_swap(db_session: Session) -> None:
if not embedding_model:
return
- unique_cc_indexings = count_unique_cc_pairs_with_index_attempts(
+ unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts(
embedding_model_id=embedding_model.id, db_session=db_session
)
diff --git a/backend/danswer/db/index_attempt.py b/backend/danswer/db/index_attempt.py
index ce913098e..4580140a5 100644
--- a/backend/danswer/db/index_attempt.py
+++ b/backend/danswer/db/index_attempt.py
@@ -291,7 +291,7 @@ def cancel_indexing_attempts_past_model(
db_session.commit()
-def count_unique_cc_pairs_with_index_attempts(
+def count_unique_cc_pairs_with_successful_index_attempts(
embedding_model_id: int | None,
db_session: Session,
) -> int:
@@ -299,12 +299,7 @@ def count_unique_cc_pairs_with_index_attempts(
db_session.query(IndexAttempt.connector_id, IndexAttempt.credential_id)
.filter(
IndexAttempt.embedding_model_id == embedding_model_id,
- # Should not be able to hang since indexing jobs expire after a limit
- # It will then be marked failed, and the next cycle it will be in a completed state
- or_(
- IndexAttempt.status == IndexingStatus.SUCCESS,
- IndexAttempt.status == IndexingStatus.FAILED,
- ),
+ IndexAttempt.status == IndexingStatus.SUCCESS,
)
.distinct()
.count()
diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py
index c875c88bd..68f9e3886 100644
--- a/backend/danswer/indexing/models.py
+++ b/backend/danswer/indexing/models.py
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from dataclasses import fields
from datetime import datetime
+from typing import TYPE_CHECKING
from pydantic import BaseModel
@@ -9,6 +10,9 @@ from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.utils.logger import setup_logger
+if TYPE_CHECKING:
+ from danswer.db.models import EmbeddingModel
+
logger = setup_logger()
@@ -130,3 +134,13 @@ class EmbeddingModelDetail(BaseModel):
normalize: bool
query_prefix: str | None
passage_prefix: str | None
+
+ @classmethod
+ def from_model(cls, embedding_model: "EmbeddingModel") -> "EmbeddingModelDetail":
+ return cls(
+ model_name=embedding_model.model_name,
+ model_dim=embedding_model.model_dim,
+ normalize=embedding_model.normalize,
+ query_prefix=embedding_model.query_prefix,
+ passage_prefix=embedding_model.passage_prefix,
+ )
diff --git a/backend/danswer/server/manage/models.py b/backend/danswer/server/manage/models.py
index a2ea4c7ab..8857ffc55 100644
--- a/backend/danswer/server/manage/models.py
+++ b/backend/danswer/server/manage/models.py
@@ -11,6 +11,7 @@ from danswer.db.models import AllowedAnswerFilters
from danswer.db.models import ChannelConfig
from danswer.db.models import SlackBotConfig as SlackBotConfigModel
from danswer.db.models import SlackBotResponseType
+from danswer.indexing.models import EmbeddingModelDetail
from danswer.server.features.persona.models import PersonaSnapshot
@@ -125,10 +126,6 @@ class SlackBotConfig(BaseModel):
)
-class ModelVersionResponse(BaseModel):
- model_name: str | None # None only applicable to secondary index
-
-
class FullModelVersionResponse(BaseModel):
- current_model_name: str
- secondary_model_name: str | None
+ current_model: EmbeddingModelDetail
+ secondary_model: EmbeddingModelDetail | None
diff --git a/backend/danswer/server/manage/secondary_index.py b/backend/danswer/server/manage/secondary_index.py
index c4c51c0e3..6f5adf752 100644
--- a/backend/danswer/server/manage/secondary_index.py
+++ b/backend/danswer/server/manage/secondary_index.py
@@ -20,7 +20,6 @@ from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import EmbeddingModelDetail
from danswer.server.manage.models import FullModelVersionResponse
-from danswer.server.manage.models import ModelVersionResponse
from danswer.server.models import IdReturn
from danswer.utils.logger import setup_logger
@@ -115,21 +114,21 @@ def cancel_new_embedding(
def get_current_embedding_model(
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
-) -> ModelVersionResponse:
+) -> EmbeddingModelDetail:
current_model = get_current_db_embedding_model(db_session)
- return ModelVersionResponse(model_name=current_model.model_name)
+ return EmbeddingModelDetail.from_model(current_model)
@router.get("/get-secondary-embedding-model")
def get_secondary_embedding_model(
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
-) -> ModelVersionResponse:
+) -> EmbeddingModelDetail | None:
next_model = get_secondary_db_embedding_model(db_session)
+ if not next_model:
+ return None
- return ModelVersionResponse(
- model_name=next_model.model_name if next_model else None
- )
+ return EmbeddingModelDetail.from_model(next_model)
@router.get("/get-embedding-models")
@@ -140,6 +139,8 @@ def get_embedding_models(
current_model = get_current_db_embedding_model(db_session)
next_model = get_secondary_db_embedding_model(db_session)
return FullModelVersionResponse(
- current_model_name=current_model.model_name,
- secondary_model_name=next_model.model_name if next_model else None,
+ current_model=EmbeddingModelDetail.from_model(current_model),
+ secondary_model=EmbeddingModelDetail.from_model(next_model)
+ if next_model
+ else None,
)
diff --git a/web/src/app/admin/models/embedding/CustomModelForm.tsx b/web/src/app/admin/models/embedding/CustomModelForm.tsx
new file mode 100644
index 000000000..23676bc61
--- /dev/null
+++ b/web/src/app/admin/models/embedding/CustomModelForm.tsx
@@ -0,0 +1,116 @@
+import {
+ BooleanFormField,
+ TextFormField,
+} from "@/components/admin/connectors/Field";
+import { Button, Divider, Text } from "@tremor/react";
+import { Form, Formik } from "formik";
+
+import * as Yup from "yup";
+import { EmbeddingModelDescriptor } from "./embeddingModels";
+
+export function CustomModelForm({
+ onSubmit,
+}: {
+ onSubmit: (model: EmbeddingModelDescriptor) => void;
+}) {
+ return (
+
+
{
+ onSubmit({ ...values, model_dim: parseInt(values.model_dim) });
+ }}
+ >
+ {({ isSubmitting, setFieldValue }) => (
+
+ )}
+
+
+ );
+}
diff --git a/web/src/app/admin/models/embedding/ModelSelectionConfirmation.tsx b/web/src/app/admin/models/embedding/ModelSelectionConfirmation.tsx
index 949c5d46d..7572ac2ce 100644
--- a/web/src/app/admin/models/embedding/ModelSelectionConfirmation.tsx
+++ b/web/src/app/admin/models/embedding/ModelSelectionConfirmation.tsx
@@ -1,18 +1,21 @@
import { Modal } from "@/components/Modal";
-import { Button, Text } from "@tremor/react";
+import { Button, Text, Callout } from "@tremor/react";
+import { EmbeddingModelDescriptor } from "./embeddingModels";
export function ModelSelectionConfirmaion({
selectedModel,
+ isCustom,
onConfirm,
}: {
- selectedModel: string;
+ selectedModel: EmbeddingModelDescriptor;
+ isCustom: boolean;
onConfirm: () => void;
}) {
return (
- You have selected: {selectedModel} . Are you sure you want to
- update to this new embedding model?
+ You have selected: {selectedModel.model_name} . Are you sure you
+ want to update to this new embedding model?
We will re-index all your documents in the background so you will be
@@ -25,6 +28,18 @@ export function ModelSelectionConfirmaion({
normal. If you are self-hosting, we recommend that you allocate at least
16GB of RAM to Danswer during this process.
+
+ {isCustom && (
+
+ We've detected that this is a custom-specified embedding model.
+ Since we have to download the model files before verifying the
+ configuration's correctness, we won't be able to let you
+ know if the configuration is valid until after we start
+ re-indexing your documents. If there is an issue, it will show up on
+ this page as an indexing error on this page after clicking Confirm.
+
+ )}
+
Confirm
@@ -36,10 +51,12 @@ export function ModelSelectionConfirmaion({
export function ModelSelectionConfirmaionModal({
selectedModel,
+ isCustom,
onConfirm,
onCancel,
}: {
- selectedModel: string;
+ selectedModel: EmbeddingModelDescriptor;
+ isCustom: boolean;
onConfirm: () => void;
onCancel: () => void;
}) {
@@ -48,6 +65,7 @@ export function ModelSelectionConfirmaionModal({
diff --git a/web/src/app/admin/models/embedding/ModelSelector.tsx b/web/src/app/admin/models/embedding/ModelSelector.tsx
index 4ac80785e..bc1c8b165 100644
--- a/web/src/app/admin/models/embedding/ModelSelector.tsx
+++ b/web/src/app/admin/models/embedding/ModelSelector.tsx
@@ -1,14 +1,18 @@
import { DefaultDropdown, StringOrNumberOption } from "@/components/Dropdown";
-import { Title, Text } from "@tremor/react";
-import { FullEmbeddingModelDescriptor } from "./embeddingModels";
+import { Title, Text, Divider, Card } from "@tremor/react";
+import {
+ EmbeddingModelDescriptor,
+ FullEmbeddingModelDescriptor,
+} from "./embeddingModels";
import { FiStar } from "react-icons/fi";
+import { CustomModelForm } from "./CustomModelForm";
export function ModelOption({
model,
onSelect,
}: {
model: FullEmbeddingModelDescriptor;
- onSelect?: (modelName: string) => void;
+ onSelect?: (model: EmbeddingModelDescriptor) => void;
}) {
return (
}
{model.model_name}
-
{model.description}
+
+ {model.description
+ ? model.description
+ : "Custom model—no description is available."}
+
{model.link && (
onSelect(model.model_name)}
+ onClick={() => onSelect(model)}
>
Select Model
@@ -61,17 +69,19 @@ export function ModelSelector({
setSelectedModel,
}: {
modelOptions: FullEmbeddingModelDescriptor[];
- setSelectedModel: (modelName: string) => void;
+ setSelectedModel: (model: EmbeddingModelDescriptor) => void;
}) {
return (
-
- {modelOptions.map((modelOption) => (
-
- ))}
+
+
+ {modelOptions.map((modelOption) => (
+
+ ))}
+
);
}
diff --git a/web/src/app/admin/models/embedding/ReindexingProgressTable.tsx b/web/src/app/admin/models/embedding/ReindexingProgressTable.tsx
index 3b366c192..b1f91d24b 100644
--- a/web/src/app/admin/models/embedding/ReindexingProgressTable.tsx
+++ b/web/src/app/admin/models/embedding/ReindexingProgressTable.tsx
@@ -1,14 +1,14 @@
import { PageSelector } from "@/components/PageSelector";
-import { CCPairStatus, IndexAttemptStatus } from "@/components/Status";
-import { ConnectorIndexingStatus, ValidStatuses } from "@/lib/types";
+import { IndexAttemptStatus } from "@/components/Status";
+import { ConnectorIndexingStatus } from "@/lib/types";
import {
- Button,
Table,
TableBody,
TableCell,
TableHead,
TableHeaderCell,
TableRow,
+ Text,
} from "@tremor/react";
import Link from "next/link";
import { useState } from "react";
@@ -30,6 +30,7 @@ export function ReindexingProgressTable({
Connector Name
Status
Docs Re-Indexed
+
Error Message
@@ -58,6 +59,13 @@ export function ReindexingProgressTable({
{reindexingProgress?.latest_index_attempt
?.total_docs_indexed || "-"}
+
+
+
+ {reindexingProgress.error_msg || "-"}
+
+
+
);
})}
diff --git a/web/src/app/admin/models/embedding/embeddingModels.ts b/web/src/app/admin/models/embedding/embeddingModels.ts
index 64ccfff95..7c5d09180 100644
--- a/web/src/app/admin/models/embedding/embeddingModels.ts
+++ b/web/src/app/admin/models/embedding/embeddingModels.ts
@@ -76,3 +76,12 @@ export function checkModelNameIsValid(modelName: string | undefined | null) {
}
return true;
}
+
+export function fillOutEmeddingModelDescriptor(
+ embeddingModel: EmbeddingModelDescriptor | FullEmbeddingModelDescriptor
+): FullEmbeddingModelDescriptor {
+ return {
+ ...embeddingModel,
+ description: "",
+ };
+}
diff --git a/web/src/app/admin/models/embedding/page.tsx b/web/src/app/admin/models/embedding/page.tsx
index 5f4cd1c93..0612fe2c6 100644
--- a/web/src/app/admin/models/embedding/page.tsx
+++ b/web/src/app/admin/models/embedding/page.tsx
@@ -6,7 +6,7 @@ import { KeyIcon, TrashIcon } from "@/components/icons/icons";
import { ApiKeyForm } from "@/components/openai/ApiKeyForm";
import { GEN_AI_API_KEY_URL } from "@/components/openai/constants";
import { errorHandlingFetcher, fetcher } from "@/lib/fetcher";
-import { Button, Divider, Text, Title } from "@tremor/react";
+import { Button, Card, Divider, Text, Title } from "@tremor/react";
import { FiCpu, FiPackage } from "react-icons/fi";
import useSWR, { mutate } from "swr";
import { ModelOption, ModelSelector } from "./ModelSelector";
@@ -16,17 +16,18 @@ import { ReindexingProgressTable } from "./ReindexingProgressTable";
import { Modal } from "@/components/Modal";
import {
AVAILABLE_MODELS,
- EmbeddingModelResponse,
+ EmbeddingModelDescriptor,
INVALID_OLD_MODEL,
+ fillOutEmeddingModelDescriptor,
} from "./embeddingModels";
import { ErrorCallout } from "@/components/ErrorCallout";
import { Connector, ConnectorIndexingStatus } from "@/lib/types";
import Link from "next/link";
+import { CustomModelForm } from "./CustomModelForm";
function Main() {
- const [tentativeNewEmbeddingModel, setTentativeNewEmbeddingModel] = useState<
- string | null
- >(null);
+ const [tentativeNewEmbeddingModel, setTentativeNewEmbeddingModel] =
+ useState(null);
const [isCancelling, setIsCancelling] = useState(false);
const [showAddConnectorPopup, setShowAddConnectorPopup] =
useState(false);
@@ -35,16 +36,16 @@ function Main() {
data: currentEmeddingModel,
isLoading: isLoadingCurrentModel,
error: currentEmeddingModelError,
- } = useSWR(
+ } = useSWR(
"/api/secondary-index/get-current-embedding-model",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
const {
- data: futureEmeddingModel,
+ data: futureEmbeddingModel,
isLoading: isLoadingFutureModel,
error: futureEmeddingModelError,
- } = useSWR(
+ } = useSWR(
"/api/secondary-index/get-secondary-embedding-model",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
@@ -63,24 +64,20 @@ function Main() {
{ refreshInterval: 5000 } // 5 seconds
);
- const onSelect = async (modelName: string) => {
+ const onSelect = async (model: EmbeddingModelDescriptor) => {
if (currentEmeddingModel?.model_name === INVALID_OLD_MODEL) {
- await onConfirm(modelName);
+ await onConfirm(model);
} else {
- setTentativeNewEmbeddingModel(modelName);
+ setTentativeNewEmbeddingModel(model);
}
};
- const onConfirm = async (modelName: string) => {
- const modelDescriptor = AVAILABLE_MODELS.find(
- (model) => model.model_name === modelName
- );
-
+ const onConfirm = async (model: EmbeddingModelDescriptor) => {
const response = await fetch(
"/api/secondary-index/set-new-embedding-model",
{
method: "POST",
- body: JSON.stringify(modelDescriptor),
+ body: JSON.stringify(model),
headers: {
"Content-Type": "application/json",
},
@@ -120,26 +117,33 @@ function Main() {
if (
currentEmeddingModelError ||
!currentEmeddingModel ||
- futureEmeddingModelError ||
- !futureEmeddingModel
+ futureEmeddingModelError
) {
return ;
}
const currentModelName = currentEmeddingModel.model_name;
- const currentModel = AVAILABLE_MODELS.find(
- (model) => model.model_name === currentModelName
- );
+ const currentModel =
+ AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) ||
+ fillOutEmeddingModelDescriptor(currentEmeddingModel);
- const newModelSelection = AVAILABLE_MODELS.find(
- (model) => model.model_name === futureEmeddingModel.model_name
- );
+ const newModelSelection = futureEmbeddingModel
+ ? AVAILABLE_MODELS.find(
+ (model) => model.model_name === futureEmbeddingModel.model_name
+ ) || fillOutEmeddingModelDescriptor(futureEmbeddingModel)
+ : null;
return (
{tentativeNewEmbeddingModel && (
+ model.model_name === tentativeNewEmbeddingModel.model_name
+ ) === undefined
+ }
onConfirm={() => onConfirm(tentativeNewEmbeddingModel)}
onCancel={() => setTentativeNewEmbeddingModel(null)}
/>
@@ -243,12 +247,49 @@ function Main() {
>
)}
+
+ Below are a curated selection of quality models that we recommend
+ you choose from.
+
+
modelOption.model_name !== currentModelName
)}
setSelectedModel={onSelect}
/>
+
+
+ Alternatively, (if you know what you're doing) you can
+ specify a{" "}
+
+ SentenceTransformers
+
+ -compatible model of your choice below. The rough list of
+ supported models can be found{" "}
+
+ here
+
+ .
+
+ NOTE: not all models listed will work with Danswer, since
+ some have unique interfaces or special requirements. If in doubt,
+ reach out to the Danswer team.
+
+
+
+
+
+
+
) : (
connectors &&
@@ -272,10 +313,10 @@ function Main() {
The table below shows the re-indexing progress of all existing
- connectors. Once all connectors have been re-indexed, the new
- model will be used for all search queries. Until then, we will
- use the old model so that no downtime is necessary during this
- transition.
+ connectors. Once all connectors have been re-indexed
+ successfully, the new model will be used for all search
+ queries. Until then, we will use the old model so that no
+ downtime is necessary during this transition.
{isLoadingOngoingReIndexingStatus ? (