Add ability to specify custom embedding models

This commit is contained in:
Weves 2024-03-26 23:12:04 -07:00 committed by Chris Weaver
parent fbff5b5784
commit 5a967322fd
11 changed files with 289 additions and 78 deletions

View File

@ -29,7 +29,9 @@ from danswer.db.embedding_model import update_embedding_model_status
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import count_unique_cc_pairs_with_index_attempts
from danswer.db.index_attempt import (
count_unique_cc_pairs_with_successful_index_attempts,
)
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
@ -365,9 +367,9 @@ def kickoff_indexing_jobs(
def check_index_swap(db_session: Session) -> None:
"""Get count of cc-pairs and count of index_attempts for the new model grouped by
connector + credential, if it's the same, then assume new index is done building.
This does not take into consideration if the attempt failed or not"""
"""Get count of cc-pairs and count of successful index_attempts for the
new model grouped by connector + credential, if it's the same, then assume
new index is done building. If so, swap the indices and expire the old one."""
# Default CC-pair created for Ingestion API unused here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = len(all_cc_pairs) - 1
@ -376,7 +378,7 @@ def check_index_swap(db_session: Session) -> None:
if not embedding_model:
return
unique_cc_indexings = count_unique_cc_pairs_with_index_attempts(
unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts(
embedding_model_id=embedding_model.id, db_session=db_session
)

View File

@ -291,7 +291,7 @@ def cancel_indexing_attempts_past_model(
db_session.commit()
def count_unique_cc_pairs_with_index_attempts(
def count_unique_cc_pairs_with_successful_index_attempts(
embedding_model_id: int | None,
db_session: Session,
) -> int:
@ -299,12 +299,7 @@ def count_unique_cc_pairs_with_index_attempts(
db_session.query(IndexAttempt.connector_id, IndexAttempt.credential_id)
.filter(
IndexAttempt.embedding_model_id == embedding_model_id,
# Should not be able to hang since indexing jobs expire after a limit
# It will then be marked failed, and the next cycle it will be in a completed state
or_(
IndexAttempt.status == IndexingStatus.SUCCESS,
IndexAttempt.status == IndexingStatus.FAILED,
),
IndexAttempt.status == IndexingStatus.SUCCESS,
)
.distinct()
.count()

View File

@ -1,6 +1,7 @@
from dataclasses import dataclass
from dataclasses import fields
from datetime import datetime
from typing import TYPE_CHECKING
from pydantic import BaseModel
@ -9,6 +10,9 @@ from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.utils.logger import setup_logger
if TYPE_CHECKING:
from danswer.db.models import EmbeddingModel
logger = setup_logger()
@ -130,3 +134,13 @@ class EmbeddingModelDetail(BaseModel):
normalize: bool
query_prefix: str | None
passage_prefix: str | None
@classmethod
def from_model(cls, embedding_model: "EmbeddingModel") -> "EmbeddingModelDetail":
return cls(
model_name=embedding_model.model_name,
model_dim=embedding_model.model_dim,
normalize=embedding_model.normalize,
query_prefix=embedding_model.query_prefix,
passage_prefix=embedding_model.passage_prefix,
)

View File

@ -11,6 +11,7 @@ from danswer.db.models import AllowedAnswerFilters
from danswer.db.models import ChannelConfig
from danswer.db.models import SlackBotConfig as SlackBotConfigModel
from danswer.db.models import SlackBotResponseType
from danswer.indexing.models import EmbeddingModelDetail
from danswer.server.features.persona.models import PersonaSnapshot
@ -125,10 +126,6 @@ class SlackBotConfig(BaseModel):
)
class ModelVersionResponse(BaseModel):
model_name: str | None # None only applicable to secondary index
class FullModelVersionResponse(BaseModel):
current_model_name: str
secondary_model_name: str | None
current_model: EmbeddingModelDetail
secondary_model: EmbeddingModelDetail | None

View File

@ -20,7 +20,6 @@ from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import EmbeddingModelDetail
from danswer.server.manage.models import FullModelVersionResponse
from danswer.server.manage.models import ModelVersionResponse
from danswer.server.models import IdReturn
from danswer.utils.logger import setup_logger
@ -115,21 +114,21 @@ def cancel_new_embedding(
def get_current_embedding_model(
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ModelVersionResponse:
) -> EmbeddingModelDetail:
current_model = get_current_db_embedding_model(db_session)
return ModelVersionResponse(model_name=current_model.model_name)
return EmbeddingModelDetail.from_model(current_model)
@router.get("/get-secondary-embedding-model")
def get_secondary_embedding_model(
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ModelVersionResponse:
) -> EmbeddingModelDetail | None:
next_model = get_secondary_db_embedding_model(db_session)
if not next_model:
return None
return ModelVersionResponse(
model_name=next_model.model_name if next_model else None
)
return EmbeddingModelDetail.from_model(next_model)
@router.get("/get-embedding-models")
@ -140,6 +139,8 @@ def get_embedding_models(
current_model = get_current_db_embedding_model(db_session)
next_model = get_secondary_db_embedding_model(db_session)
return FullModelVersionResponse(
current_model_name=current_model.model_name,
secondary_model_name=next_model.model_name if next_model else None,
current_model=EmbeddingModelDetail.from_model(current_model),
secondary_model=EmbeddingModelDetail.from_model(next_model)
if next_model
else None,
)

View File

@ -0,0 +1,116 @@
import {
BooleanFormField,
TextFormField,
} from "@/components/admin/connectors/Field";
import { Button, Divider, Text } from "@tremor/react";
import { Form, Formik } from "formik";
import * as Yup from "yup";
import { EmbeddingModelDescriptor } from "./embeddingModels";
export function CustomModelForm({
onSubmit,
}: {
onSubmit: (model: EmbeddingModelDescriptor) => void;
}) {
return (
<div>
<Formik
initialValues={{
model_name: "",
model_dim: "",
query_prefix: "",
passage_prefix: "",
normalize: true,
}}
validationSchema={Yup.object().shape({
model_name: Yup.string().required(
"Please enter the name of the Embedding Model"
),
model_dim: Yup.number().required(
"Please enter the dimensionality of the embeddings generated by the model"
),
query_prefix: Yup.string(),
passage_prefix: Yup.string(),
normalize: Yup.boolean().required(),
})}
onSubmit={async (values, formikHelpers) => {
onSubmit({ ...values, model_dim: parseInt(values.model_dim) });
}}
>
{({ isSubmitting, setFieldValue }) => (
<Form>
<TextFormField
name="model_name"
label="Name:"
subtext="The name of the model on Hugging Face"
placeholder="E.g. 'intfloat/e5-base-v2'"
autoCompleteDisabled={true}
/>
<TextFormField
name="model_dim"
label="Model Dimension:"
subtext="The dimensionality of the embeddings generated by the model"
placeholder="E.g. '768'"
autoCompleteDisabled={true}
onChange={(e) => {
const value = e.target.value;
// Allow only integer values
if (value === "" || /^[0-9]+$/.test(value)) {
setFieldValue("model_dim", value);
}
}}
/>
<TextFormField
name="query_prefix"
label="[Optional] Query Prefix:"
subtext={
<>
The prefix specified by the model creators which should be
prepended to <i>queries</i> before passing them to the model.
Many models do not have this, in which case this should be
left empty.
</>
}
placeholder="E.g. 'query: '"
autoCompleteDisabled={true}
/>
<TextFormField
name="passage_prefix"
label="[Optional] Passage Prefix:"
subtext={
<>
The prefix specified by the model creators which should be
prepended to <i>passages</i> before passing them to the model.
Many models do not have this, in which case this should be
left empty.
</>
}
placeholder="E.g. 'passage: '"
autoCompleteDisabled={true}
/>
<BooleanFormField
name="normalize"
label="Normalize Embeddings"
subtext="Whether or not to normalize the embeddings generated by the model. When in doubt, leave this checked."
/>
<div className="flex mt-6">
<Button
type="submit"
disabled={isSubmitting}
className="w-64 mx-auto"
>
Choose
</Button>
</div>
</Form>
)}
</Formik>
</div>
);
}

View File

@ -1,18 +1,21 @@
import { Modal } from "@/components/Modal";
import { Button, Text } from "@tremor/react";
import { Button, Text, Callout } from "@tremor/react";
import { EmbeddingModelDescriptor } from "./embeddingModels";
export function ModelSelectionConfirmaion({
selectedModel,
isCustom,
onConfirm,
}: {
selectedModel: string;
selectedModel: EmbeddingModelDescriptor;
isCustom: boolean;
onConfirm: () => void;
}) {
return (
<div className="mb-4">
<Text className="text-lg mb-4">
You have selected: <b>{selectedModel}</b>. Are you sure you want to
update to this new embedding model?
You have selected: <b>{selectedModel.model_name}</b>. Are you sure you
want to update to this new embedding model?
</Text>
<Text className="text-lg mb-2">
We will re-index all your documents in the background so you will be
@ -25,6 +28,18 @@ export function ModelSelectionConfirmaion({
normal. If you are self-hosting, we recommend that you allocate at least
16GB of RAM to Danswer during this process.
</Text>
{isCustom && (
<Callout title="IMPORTANT" color="yellow" className="mt-4">
We&apos;ve detected that this is a custom-specified embedding model.
Since we have to download the model files before verifying the
configuration&apos;s correctness, we won&apos;t be able to let you
know if the configuration is valid until <b>after</b> we start
re-indexing your documents. If there is an issue, it will show up on
this page as an indexing error on this page after clicking Confirm.
</Callout>
)}
<div className="flex mt-8">
<Button className="mx-auto" color="green" onClick={onConfirm}>
Confirm
@ -36,10 +51,12 @@ export function ModelSelectionConfirmaion({
export function ModelSelectionConfirmaionModal({
selectedModel,
isCustom,
onConfirm,
onCancel,
}: {
selectedModel: string;
selectedModel: EmbeddingModelDescriptor;
isCustom: boolean;
onConfirm: () => void;
onCancel: () => void;
}) {
@ -48,6 +65,7 @@ export function ModelSelectionConfirmaionModal({
<div>
<ModelSelectionConfirmaion
selectedModel={selectedModel}
isCustom={isCustom}
onConfirm={onConfirm}
/>
</div>

View File

@ -1,14 +1,18 @@
import { DefaultDropdown, StringOrNumberOption } from "@/components/Dropdown";
import { Title, Text } from "@tremor/react";
import { FullEmbeddingModelDescriptor } from "./embeddingModels";
import { Title, Text, Divider, Card } from "@tremor/react";
import {
EmbeddingModelDescriptor,
FullEmbeddingModelDescriptor,
} from "./embeddingModels";
import { FiStar } from "react-icons/fi";
import { CustomModelForm } from "./CustomModelForm";
export function ModelOption({
model,
onSelect,
}: {
model: FullEmbeddingModelDescriptor;
onSelect?: (modelName: string) => void;
onSelect?: (model: EmbeddingModelDescriptor) => void;
}) {
return (
<div
@ -20,7 +24,11 @@ export function ModelOption({
{model.isDefault && <FiStar className="my-auto mr-1 text-accent" />}
{model.model_name}
</div>
<div className="text-sm mt-1 mx-1">{model.description}</div>
<div className="text-sm mt-1 mx-1">
{model.description
? model.description
: "Custom model—no description is available."}
</div>
{model.link && (
<a
target="_blank"
@ -47,7 +55,7 @@ export function ModelOption({
hover:bg-hover
text-sm
mt-auto`}
onClick={() => onSelect(model.model_name)}
onClick={() => onSelect(model)}
>
Select Model
</div>
@ -61,17 +69,19 @@ export function ModelSelector({
setSelectedModel,
}: {
modelOptions: FullEmbeddingModelDescriptor[];
setSelectedModel: (modelName: string) => void;
setSelectedModel: (model: EmbeddingModelDescriptor) => void;
}) {
return (
<div className="flex flex-wrap gap-4">
{modelOptions.map((modelOption) => (
<ModelOption
key={modelOption.model_name}
model={modelOption}
onSelect={setSelectedModel}
/>
))}
<div>
<div className="flex flex-wrap gap-4">
{modelOptions.map((modelOption) => (
<ModelOption
key={modelOption.model_name}
model={modelOption}
onSelect={setSelectedModel}
/>
))}
</div>
</div>
);
}

View File

@ -1,14 +1,14 @@
import { PageSelector } from "@/components/PageSelector";
import { CCPairStatus, IndexAttemptStatus } from "@/components/Status";
import { ConnectorIndexingStatus, ValidStatuses } from "@/lib/types";
import { IndexAttemptStatus } from "@/components/Status";
import { ConnectorIndexingStatus } from "@/lib/types";
import {
Button,
Table,
TableBody,
TableCell,
TableHead,
TableHeaderCell,
TableRow,
Text,
} from "@tremor/react";
import Link from "next/link";
import { useState } from "react";
@ -30,6 +30,7 @@ export function ReindexingProgressTable({
<TableHeaderCell>Connector Name</TableHeaderCell>
<TableHeaderCell>Status</TableHeaderCell>
<TableHeaderCell>Docs Re-Indexed</TableHeaderCell>
<TableHeaderCell>Error Message</TableHeaderCell>
</TableRow>
</TableHead>
<TableBody>
@ -58,6 +59,13 @@ export function ReindexingProgressTable({
{reindexingProgress?.latest_index_attempt
?.total_docs_indexed || "-"}
</TableCell>
<TableCell>
<div>
<Text className="flex flex-wrap whitespace-normal">
{reindexingProgress.error_msg || "-"}
</Text>
</div>
</TableCell>
</TableRow>
);
})}

View File

@ -76,3 +76,12 @@ export function checkModelNameIsValid(modelName: string | undefined | null) {
}
return true;
}
export function fillOutEmeddingModelDescriptor(
embeddingModel: EmbeddingModelDescriptor | FullEmbeddingModelDescriptor
): FullEmbeddingModelDescriptor {
return {
...embeddingModel,
description: "",
};
}

View File

@ -6,7 +6,7 @@ import { KeyIcon, TrashIcon } from "@/components/icons/icons";
import { ApiKeyForm } from "@/components/openai/ApiKeyForm";
import { GEN_AI_API_KEY_URL } from "@/components/openai/constants";
import { errorHandlingFetcher, fetcher } from "@/lib/fetcher";
import { Button, Divider, Text, Title } from "@tremor/react";
import { Button, Card, Divider, Text, Title } from "@tremor/react";
import { FiCpu, FiPackage } from "react-icons/fi";
import useSWR, { mutate } from "swr";
import { ModelOption, ModelSelector } from "./ModelSelector";
@ -16,17 +16,18 @@ import { ReindexingProgressTable } from "./ReindexingProgressTable";
import { Modal } from "@/components/Modal";
import {
AVAILABLE_MODELS,
EmbeddingModelResponse,
EmbeddingModelDescriptor,
INVALID_OLD_MODEL,
fillOutEmeddingModelDescriptor,
} from "./embeddingModels";
import { ErrorCallout } from "@/components/ErrorCallout";
import { Connector, ConnectorIndexingStatus } from "@/lib/types";
import Link from "next/link";
import { CustomModelForm } from "./CustomModelForm";
function Main() {
const [tentativeNewEmbeddingModel, setTentativeNewEmbeddingModel] = useState<
string | null
>(null);
const [tentativeNewEmbeddingModel, setTentativeNewEmbeddingModel] =
useState<EmbeddingModelDescriptor | null>(null);
const [isCancelling, setIsCancelling] = useState<boolean>(false);
const [showAddConnectorPopup, setShowAddConnectorPopup] =
useState<boolean>(false);
@ -35,16 +36,16 @@ function Main() {
data: currentEmeddingModel,
isLoading: isLoadingCurrentModel,
error: currentEmeddingModelError,
} = useSWR<EmbeddingModelResponse>(
} = useSWR<EmbeddingModelDescriptor>(
"/api/secondary-index/get-current-embedding-model",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
const {
data: futureEmeddingModel,
data: futureEmbeddingModel,
isLoading: isLoadingFutureModel,
error: futureEmeddingModelError,
} = useSWR<EmbeddingModelResponse>(
} = useSWR<EmbeddingModelDescriptor | null>(
"/api/secondary-index/get-secondary-embedding-model",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
@ -63,24 +64,20 @@ function Main() {
{ refreshInterval: 5000 } // 5 seconds
);
const onSelect = async (modelName: string) => {
const onSelect = async (model: EmbeddingModelDescriptor) => {
if (currentEmeddingModel?.model_name === INVALID_OLD_MODEL) {
await onConfirm(modelName);
await onConfirm(model);
} else {
setTentativeNewEmbeddingModel(modelName);
setTentativeNewEmbeddingModel(model);
}
};
const onConfirm = async (modelName: string) => {
const modelDescriptor = AVAILABLE_MODELS.find(
(model) => model.model_name === modelName
);
const onConfirm = async (model: EmbeddingModelDescriptor) => {
const response = await fetch(
"/api/secondary-index/set-new-embedding-model",
{
method: "POST",
body: JSON.stringify(modelDescriptor),
body: JSON.stringify(model),
headers: {
"Content-Type": "application/json",
},
@ -120,26 +117,33 @@ function Main() {
if (
currentEmeddingModelError ||
!currentEmeddingModel ||
futureEmeddingModelError ||
!futureEmeddingModel
futureEmeddingModelError
) {
return <ErrorCallout errorTitle="Failed to fetch embedding model status" />;
}
const currentModelName = currentEmeddingModel.model_name;
const currentModel = AVAILABLE_MODELS.find(
(model) => model.model_name === currentModelName
);
const currentModel =
AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) ||
fillOutEmeddingModelDescriptor(currentEmeddingModel);
const newModelSelection = AVAILABLE_MODELS.find(
(model) => model.model_name === futureEmeddingModel.model_name
);
const newModelSelection = futureEmbeddingModel
? AVAILABLE_MODELS.find(
(model) => model.model_name === futureEmbeddingModel.model_name
) || fillOutEmeddingModelDescriptor(futureEmbeddingModel)
: null;
return (
<div>
{tentativeNewEmbeddingModel && (
<ModelSelectionConfirmaionModal
selectedModel={tentativeNewEmbeddingModel}
isCustom={
AVAILABLE_MODELS.find(
(model) =>
model.model_name === tentativeNewEmbeddingModel.model_name
) === undefined
}
onConfirm={() => onConfirm(tentativeNewEmbeddingModel)}
onCancel={() => setTentativeNewEmbeddingModel(null)}
/>
@ -243,12 +247,49 @@ function Main() {
</>
)}
<Text className="mb-4">
Below are a curated selection of quality models that we recommend
you choose from.
</Text>
<ModelSelector
modelOptions={AVAILABLE_MODELS.filter(
(modelOption) => modelOption.model_name !== currentModelName
)}
setSelectedModel={onSelect}
/>
<Text className="mt-6">
Alternatively, (if you know what you&apos;re doing) you can
specify a{" "}
<a
target="_blank"
href="https://www.sbert.net/"
className="text-link"
>
SentenceTransformers
</a>
-compatible model of your choice below. The rough list of
supported models can be found{" "}
<a
target="_blank"
href="https://huggingface.co/models?library=sentence-transformers&sort=trending"
className="text-link"
>
here
</a>
.
<br />
<b>NOTE:</b> not all models listed will work with Danswer, since
some have unique interfaces or special requirements. If in doubt,
reach out to the Danswer team.
</Text>
<div className="w-full flex">
<Card className="mt-4 2xl:w-4/6 mx-auto">
<CustomModelForm onSubmit={onSelect} />
</Card>
</div>
</div>
) : (
connectors &&
@ -272,10 +313,10 @@ function Main() {
<Text className="my-4">
The table below shows the re-indexing progress of all existing
connectors. Once all connectors have been re-indexed, the new
model will be used for all search queries. Until then, we will
use the old model so that no downtime is necessary during this
transition.
connectors. Once all connectors have been re-indexed
successfully, the new model will be used for all search
queries. Until then, we will use the old model so that no
downtime is necessary during this transition.
</Text>
{isLoadingOngoingReIndexingStatus ? (