Update embedding interface (#2205)

* squash

* simplify interface

* some updates to typing

* cloud provider type

* update typing to be even clearer

* push local commit (squash)

* cleaner interfaces

* another quick pass

* squash

* cleaner alembic

* cleaner

* remove trailing whitespace

* add sequence

* quick circle back to double check

* update

* update naming

* update naming
This commit is contained in:
pablodanswer 2024-08-22 20:52:02 -07:00 committed by GitHub
parent 7da6d33451
commit e89dc67e5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 295 additions and 147 deletions

View File

@ -0,0 +1,163 @@
"""embedding provider by provider type
Revision ID: f17bf3b0d9f1
Revises: ee3f4b47fad5
Create Date: 2024-08-21 13:13:31.120460
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f17bf3b0d9f1"
down_revision = "ee3f4b47fad5"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Re-key ``embedding_provider`` on ``provider_type`` instead of integer ``id``.

    After this runs, ``embedding_provider`` is identified by a string
    ``provider_type`` (the upper-cased former ``name``) and
    ``embedding_model`` references it through its own ``provider_type``
    column rather than ``cloud_provider_id``.

    The statements are order-dependent: values are backfilled before the
    NOT NULL / primary-key changes, and the provider type is copied onto
    ``embedding_model`` before the ``id`` and ``cloud_provider_id`` columns
    it is looked up through are dropped.
    """
    # Add provider_type column to embedding_provider.
    # It starts nullable so existing rows can be backfilled first.
    op.add_column(
        "embedding_provider",
        sa.Column("provider_type", sa.String(50), nullable=True),
    )
    # Backfill provider_type from the existing name, upper-cased
    op.execute("UPDATE embedding_provider SET provider_type = UPPER(name)")
    # Every row now has a value, so provider_type can be made not nullable
    op.alter_column("embedding_provider", "provider_type", nullable=False)
    # Drop the foreign key constraint in embedding_model table; it is
    # recreated against provider_type at the end of this function
    op.drop_constraint(
        "fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
    )
    # Drop the existing primary key constraint
    op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
    # Create a new primary key constraint on provider_type
    op.create_primary_key(
        "embedding_provider_pkey", "embedding_provider", ["provider_type"]
    )
    # Add provider_type column to embedding_model
    op.add_column(
        "embedding_model",
        sa.Column("provider_type", sa.String(50), nullable=True),
    )
    # Update provider_type for existing embedding models by following the old
    # cloud_provider_id -> embedding_provider.id link. Models with no matching
    # provider row (including a NULL cloud_provider_id) get NULL.
    op.execute(
        """
UPDATE embedding_model
SET provider_type = (
SELECT provider_type
FROM embedding_provider
WHERE embedding_provider.id = embedding_model.cloud_provider_id
)
"""
    )
    # Drop the old id column from embedding_provider
    op.drop_column("embedding_provider", "id")
    # Drop the name column from embedding_provider
    op.drop_column("embedding_provider", "name")
    # Drop the default_model_id column from embedding_provider
    op.drop_column("embedding_provider", "default_model_id")
    # Drop the old cloud_provider_id column from embedding_model
    op.drop_column("embedding_model", "cloud_provider_id")
    # Create the new foreign key constraint:
    # embedding_model.provider_type -> embedding_provider.provider_type
    op.create_foreign_key(
        "fk_embedding_model_cloud_provider",
        "embedding_model",
        "embedding_provider",
        ["provider_type"],
        ["provider_type"],
    )
def downgrade() -> None:
    """Restore the integer-keyed ``embedding_provider`` schema.

    Reverses ``upgrade()``: re-adds ``id``, ``name`` and ``default_model_id``
    on ``embedding_provider`` and ``cloud_provider_id`` on ``embedding_model``,
    moves the primary key back to ``id`` and points the foreign key at it.

    Provider ids are freshly allocated from ``embedding_provider_id_seq``, so
    they are not guaranteed to equal the pre-upgrade ids, and
    ``default_model_id`` is re-added without any values.

    NOTE(review): the ORM model prior to this revision declared ``name`` as
    unique and ``default_model_id`` as a foreign key to ``embedding_model.id``;
    neither constraint is recreated here -- confirm whether that is acceptable.
    """
    # Drop the foreign key constraint in embedding_model table
    op.drop_constraint(
        "fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
    )
    # Add back the cloud_provider_id column to embedding_model
    op.add_column(
        "embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
    )
    # Add back the id column; nullable until every row has been given a value
    op.add_column("embedding_provider", sa.Column("id", sa.Integer(), nullable=True))
    # Assign incrementing IDs to embedding providers
    op.execute(
        """
CREATE SEQUENCE IF NOT EXISTS embedding_provider_id_seq;"""
    )
    op.execute(
        """
UPDATE embedding_provider SET id = nextval('embedding_provider_id_seq');
"""
    )
    # Wire the sequence to the column so id auto-generates again, as an
    # integer primary key is expected to: without a default, an INSERT that
    # omits id would fail once id becomes the (NOT NULL) primary key below.
    op.execute(
        "ALTER TABLE embedding_provider "
        "ALTER COLUMN id SET DEFAULT nextval('embedding_provider_id_seq')"
    )
    # Tie the sequence's lifetime to the column so a later drop of id
    # (e.g. re-running upgrade) removes the sequence instead of orphaning it.
    op.execute(
        "ALTER SEQUENCE embedding_provider_id_seq OWNED BY embedding_provider.id"
    )
    # Update cloud_provider_id based on provider_type
    op.execute(
        """
UPDATE embedding_model
SET cloud_provider_id = CASE
WHEN provider_type IS NULL THEN NULL
ELSE (
SELECT id
FROM embedding_provider
WHERE embedding_provider.provider_type = embedding_model.provider_type
)
END
"""
    )
    # Drop the provider_type column from embedding_model
    op.drop_column("embedding_model", "provider_type")
    # Add back the columns to embedding_provider
    op.add_column(
        "embedding_provider", sa.Column("name", sa.String(50), nullable=True)
    )
    op.add_column(
        "embedding_provider", sa.Column("default_model_id", sa.Integer(), nullable=True)
    )
    # Drop the existing primary key constraint on provider_type
    op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary")
    # Create the original primary key constraint on id
    op.create_primary_key("embedding_provider_pkey", "embedding_provider", ["id"])
    # Update name with existing provider_type values: the four known provider
    # types get their display casing back, anything else keeps its
    # provider_type value as the name
    op.execute(
        """
UPDATE embedding_provider
SET name = CASE
WHEN provider_type = 'OPENAI' THEN 'OpenAI'
WHEN provider_type = 'COHERE' THEN 'Cohere'
WHEN provider_type = 'GOOGLE' THEN 'Google'
WHEN provider_type = 'VOYAGE' THEN 'Voyage'
ELSE provider_type
END
"""
    )
    # Drop the provider_type column from embedding_provider
    op.drop_column("embedding_provider", "provider_type")
    # Recreate the foreign key constraint in embedding_model table
    op.create_foreign_key(
        "fk_embedding_model_cloud_provider",
        "embedding_model",
        "embedding_provider",
        ["cloud_provider_id"],
        ["id"],
    )

View File

@ -378,7 +378,7 @@ def update_loop(
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if db_embedding_model.cloud_provider_id is None:
if db_embedding_model.provider_type is None:
logger.notice("Running a first inference to warm up embedding model")
warm_up_bi_encoder(
embedding_model=db_embedding_model,

View File

@ -469,7 +469,7 @@ if __name__ == "__main__":
# or the tokens have updated (set up for the first time)
with Session(get_sqlalchemy_engine()) as db_session:
embedding_model = get_current_db_embedding_model(db_session)
if embedding_model.cloud_provider_id is None:
if embedding_model.provider_type is None:
warm_up_bi_encoder(
embedding_model=embedding_model,
model_server_host=MODEL_SERVER_HOST,

View File

@ -14,32 +14,34 @@ from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelCreateRequest
from danswer.indexing.models import EmbeddingModelDetail
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)
from danswer.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
logger = setup_logger()
def create_embedding_model(
model_details: EmbeddingModelDetail,
create_embed_model_details: EmbeddingModelCreateRequest,
db_session: Session,
status: IndexModelStatus = IndexModelStatus.FUTURE,
) -> EmbeddingModel:
embedding_model = EmbeddingModel(
model_name=model_details.model_name,
model_dim=model_details.model_dim,
normalize=model_details.normalize,
query_prefix=model_details.query_prefix,
passage_prefix=model_details.passage_prefix,
model_name=create_embed_model_details.model_name,
model_dim=create_embed_model_details.model_dim,
normalize=create_embed_model_details.normalize,
query_prefix=create_embed_model_details.query_prefix,
passage_prefix=create_embed_model_details.passage_prefix,
status=status,
cloud_provider_id=model_details.cloud_provider_id,
provider_type=create_embed_model_details.provider_type,
# Every single embedding model except the initial one from migrations has this name
# The initial one from migration is called "danswer_chunk"
index_name=model_details.index_name,
index_name=create_embed_model_details.index_name,
)
db_session.add(embedding_model)
@ -48,14 +50,14 @@ def create_embedding_model(
return embedding_model
def get_model_id_from_name(
db_session: Session, embedding_provider_name: str
) -> int | None:
def get_embedding_provider_from_provider_type(
db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProvider | None:
query = select(CloudEmbeddingProvider).where(
CloudEmbeddingProvider.name == embedding_provider_name
CloudEmbeddingProvider.provider_type == provider_type
)
provider = db_session.execute(query).scalars().first()
return provider.id if provider else None
return provider if provider else None
def get_current_db_embedding_provider(
@ -65,14 +67,12 @@ def get_current_db_embedding_provider(
get_current_db_embedding_model(db_session=db_session)
)
if (
current_embedding_model is None
or current_embedding_model.cloud_provider_id is None
):
if current_embedding_model is None or current_embedding_model.provider_type is None:
return None
embedding_provider = fetch_embedding_provider(
db_session=db_session, provider_id=current_embedding_model.cloud_provider_id
db_session=db_session,
provider_type=current_embedding_model.provider_type,
)
if embedding_provider is None:
raise RuntimeError("No embedding provider exists for this model.")

View File

@ -12,6 +12,7 @@ from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from shared_configs.enums import EmbeddingProvider
def update_group_llm_provider_relationships__no_commit(
@ -41,7 +42,7 @@ def upsert_cloud_embedding_provider(
) -> CloudEmbeddingProvider:
existing_provider = (
db_session.query(CloudEmbeddingProviderModel)
.filter_by(name=provider.name)
.filter_by(provider_type=provider.provider_type)
.first()
)
if existing_provider:
@ -124,11 +125,11 @@ def fetch_existing_llm_providers(
def fetch_embedding_provider(
db_session: Session, provider_id: int
db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProviderModel | None:
return db_session.scalar(
select(CloudEmbeddingProviderModel).where(
CloudEmbeddingProviderModel.id == provider_id
CloudEmbeddingProviderModel.provider_type == provider_type
)
)
@ -154,11 +155,11 @@ def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider |
def remove_embedding_provider(
db_session: Session, embedding_provider_name: str
db_session: Session, provider_type: EmbeddingProvider
) -> None:
db_session.execute(
delete(CloudEmbeddingProviderModel).where(
CloudEmbeddingProviderModel.name == embedding_provider_name
CloudEmbeddingProviderModel.provider_type == provider_type
)
)

View File

@ -558,13 +558,14 @@ class EmbeddingModel(Base):
index_name: Mapped[str] = mapped_column(String)
# New field for cloud provider relationship
cloud_provider_id: Mapped[int | None] = mapped_column(
ForeignKey("embedding_provider.id")
provider_type: Mapped[EmbeddingProvider | None] = mapped_column(
ForeignKey("embedding_provider.provider_type"), nullable=True
)
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
"CloudEmbeddingProvider",
back_populates="embedding_models",
foreign_keys=[cloud_provider_id],
foreign_keys=[provider_type],
)
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
@ -588,15 +589,7 @@ class EmbeddingModel(Base):
def __repr__(self) -> str:
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
@property
def provider_type(self) -> EmbeddingProvider | None:
return (
EmbeddingProvider(self.cloud_provider.name.lower())
if self.cloud_provider is not None
else None
)
cloud_provider='{self.cloud_provider.provider_type if self.cloud_provider else 'None'}')>"
@property
def api_key(self) -> str | None:
@ -1073,24 +1066,18 @@ class LLMProvider(Base):
class CloudEmbeddingProvider(Base):
__tablename__ = "embedding_provider"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
api_key: Mapped[str | None] = mapped_column(EncryptedString())
default_model_id: Mapped[int | None] = mapped_column(
Integer, ForeignKey("embedding_model.id"), nullable=True
provider_type: Mapped[EmbeddingProvider] = mapped_column(
Enum(EmbeddingProvider), primary_key=True
)
api_key: Mapped[str | None] = mapped_column(EncryptedString())
embedding_models: Mapped[list["EmbeddingModel"]] = relationship(
"EmbeddingModel",
back_populates="cloud_provider",
foreign_keys="EmbeddingModel.cloud_provider_id",
)
default_model: Mapped["EmbeddingModel"] = relationship(
"EmbeddingModel", foreign_keys=[default_model_id]
foreign_keys="EmbeddingModel.provider_type",
)
def __repr__(self) -> str:
return f"<EmbeddingProvider(name='{self.name}')>"
return f"<EmbeddingProvider(type='{self.provider_type}')>"
class DocumentSet(Base):

View File

@ -5,6 +5,7 @@ from pydantic import BaseModel
from danswer.access.models import DocumentAccess
from danswer.connectors.models import Document
from danswer.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
from shared_configs.model_server_models import Embedding
if TYPE_CHECKING:
@ -99,9 +100,7 @@ class EmbeddingModelDetail(BaseModel):
normalize: bool
query_prefix: str | None
passage_prefix: str | None
cloud_provider_id: int | None = None
cloud_provider_name: str | None = None
index_name: str | None = None
provider_type: EmbeddingProvider | None = None
@classmethod
def from_model(
@ -114,6 +113,9 @@ class EmbeddingModelDetail(BaseModel):
normalize=embedding_model.normalize,
query_prefix=embedding_model.query_prefix,
passage_prefix=embedding_model.passage_prefix,
cloud_provider_id=embedding_model.cloud_provider_id,
index_name=embedding_model.index_name,
provider_type=embedding_model.provider_type,
)
class EmbeddingModelCreateRequest(EmbeddingModelDetail):
index_name: str

View File

@ -343,7 +343,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
setup_vespa(document_index, db_embedding_model, secondary_db_embedding_model)
logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
if db_embedding_model.cloud_provider_id is None:
if db_embedding_model.provider_type is None:
warm_up_bi_encoder(
embedding_model=db_embedding_model,
model_server_host=MODEL_SERVER_HOST,

View File

@ -17,6 +17,7 @@ from danswer.server.manage.embedding.models import TestEmbeddingRequest
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
logger = setup_logger()
@ -36,7 +37,7 @@ def test_embedding_configuration(
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
api_key=test_llm_request.api_key,
provider_type=test_llm_request.provider,
provider_type=test_llm_request.provider_type,
normalize=False,
query_prefix=None,
passage_prefix=None,
@ -66,22 +67,22 @@ def list_embedding_providers(
]
@admin_router.delete("/embedding-provider/{embedding_provider_name}")
@admin_router.delete("/embedding-provider/{provider_type}")
def delete_embedding_provider(
embedding_provider_name: str,
provider_type: EmbeddingProvider,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
embedding_provider = get_current_db_embedding_provider(db_session=db_session)
if (
embedding_provider is not None
and embedding_provider_name == embedding_provider.name
and provider_type == embedding_provider.provider_type
):
raise HTTPException(
status_code=400, detail="You can't delete a currently active model"
)
remove_embedding_provider(db_session, embedding_provider_name)
remove_embedding_provider(db_session, provider_type=provider_type)
@admin_router.put("/embedding-provider")

View File

@ -9,29 +9,24 @@ if TYPE_CHECKING:
class TestEmbeddingRequest(BaseModel):
provider: EmbeddingProvider
provider_type: EmbeddingProvider
api_key: str | None = None
class CloudEmbeddingProvider(BaseModel):
name: str
provider_type: EmbeddingProvider
api_key: str | None = None
default_model_id: int | None = None
id: int
@classmethod
def from_request(
cls, cloud_provider_model: "CloudEmbeddingProviderModel"
) -> "CloudEmbeddingProvider":
return cls(
id=cloud_provider_model.id,
name=cloud_provider_model.name,
provider_type=cloud_provider_model.provider_type,
api_key=cloud_provider_model.api_key,
default_model_id=cloud_provider_model.default_model_id,
)
class CloudEmbeddingProviderCreationRequest(BaseModel):
name: str
provider_type: EmbeddingProvider
api_key: str | None = None
default_model_id: int | None = None

View File

@ -11,7 +11,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.embedding_model import create_embedding_model
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_model_id_from_name
from danswer.db.embedding_model import get_embedding_provider_from_provider_type
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.embedding_model import update_embedding_model_status
from danswer.db.engine import get_session
@ -19,6 +19,7 @@ from danswer.db.index_attempt import expire_index_attempts
from danswer.db.models import IndexModelStatus
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import EmbeddingModelCreateRequest
from danswer.indexing.models import EmbeddingModelDetail
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.search.models import SavedSearchSettings
@ -42,30 +43,32 @@ def set_new_embedding_model(
"""Creates a new EmbeddingModel row and cancels the previous secondary indexing if any
Gives an error if the same model name is used as the current or secondary index
"""
current_model = get_current_db_embedding_model(db_session)
if embed_model_details.cloud_provider_name is not None:
cloud_id = get_model_id_from_name(
db_session, embed_model_details.cloud_provider_name
# Validate cloud provider exists
if embed_model_details.provider_type is not None:
cloud_provider = get_embedding_provider_from_provider_type(
db_session, provider_type=embed_model_details.provider_type
)
if cloud_id is None:
if cloud_provider is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="No ID exists for given provider name",
detail=f"No embedding provider exists for cloud embedding type {embed_model_details.provider_type}",
)
embed_model_details.cloud_provider_id = cloud_id
current_model = get_current_db_embedding_model(db_session)
embed_model_details.index_name = (
f"danswer_chunk_{clean_model_name(embed_model_details.model_name)}"
)
# account for same model name being indexed with two different configurations
# We define index name here
index_name = f"danswer_chunk_{clean_model_name(embed_model_details.model_name)}"
if (
embed_model_details.model_name == current_model.model_name
and not current_model.index_name.endswith(ALT_INDEX_SUFFIX)
):
embed_model_details.index_name += ALT_INDEX_SUFFIX
index_name += ALT_INDEX_SUFFIX
create_embed_model_details = EmbeddingModelCreateRequest(
**embed_model_details.dict(), index_name=index_name
)
secondary_model = get_secondary_db_embedding_model(db_session)
@ -89,8 +92,7 @@ def set_new_embedding_model(
)
new_model = create_embedding_model(
model_details=embed_model_details,
db_session=db_session,
create_embed_model_details=create_embed_model_details, db_session=db_session
)
# Ensure Vespa has the new index immediately

View File

@ -3,7 +3,6 @@ import { Modal } from "@/components/Modal";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { ConnectorIndexingStatus } from "@/lib/types";
import { Button, Text, Title } from "@tremor/react";
import Link from "next/link";
import { useState } from "react";
import useSWR, { mutate } from "swr";
import { ReindexingProgressTable } from "../../../../components/embedding/ReindexingProgressTable";

View File

@ -79,7 +79,7 @@ function Main() {
(provider) =>
provider.embedding_models.map((model) => ({
...model,
cloud_provider_id: provider.id,
provider_type: provider.provider_type,
model_name: model.model_name, // Ensure model_name is set for consistency
}))
);

View File

@ -11,6 +11,7 @@ import {
INVALID_OLD_MODEL,
HostedEmbeddingModel,
EmbeddingModelDescriptor,
EmbeddingProvider,
} from "../../../components/embedding/interfaces";
import { Connector } from "@/lib/connectors/connectors";
import OpenEmbeddingPage from "./pages/OpenEmbeddingPage";
@ -28,8 +29,7 @@ import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../configuration/llm/constants";
export interface EmbeddingDetails {
api_key: string;
custom_config: any;
default_model_id?: number;
name: string;
provider_type: EmbeddingProvider;
}
export function EmbeddingModelSelection({
@ -122,28 +122,28 @@ export function EmbeddingModelSelection({
};
const clientsideAddProvider = (provider: CloudEmbeddingProvider) => {
const providerName = provider.name;
const providerType = provider.provider_type;
setNewEnabledProviders((newEnabledProviders) => [
...newEnabledProviders,
providerName,
providerType,
]);
setNewUnenabledProviders((newUnenabledProviders) =>
newUnenabledProviders.filter(
(givenProvidername) => givenProvidername != providerName
(givenProviderType) => givenProviderType != providerType
)
);
};
const clientsideRemoveProvider = (provider: CloudEmbeddingProvider) => {
const providerName = provider.name;
const providerType = provider.provider_type;
setNewEnabledProviders((newEnabledProviders) =>
newEnabledProviders.filter(
(givenProvidername) => givenProvidername != providerName
(givenProviderType) => givenProviderType != providerType
)
);
setNewUnenabledProviders((newUnenabledProviders) => [
...newUnenabledProviders,
providerName,
providerType,
]);
};
@ -191,7 +191,7 @@ export function EmbeddingModelSelection({
)}
{changeCredentialsProvider && (
<ChangeCredentialsModal
useFileUpload={changeCredentialsProvider.name == "Google"}
useFileUpload={changeCredentialsProvider.provider_type == "Google"}
onDeleted={() => {
clientsideRemoveProvider(changeCredentialsProvider);
setChangeCredentialsProvider(null);

View File

@ -74,7 +74,7 @@ export function ChangeCredentialsModal({
try {
const response = await fetch(
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.name}`,
`${EMBEDDING_PROVIDERS_ADMIN_URL}/${provider.provider_type}`,
{
method: "DELETE",
}
@ -99,19 +99,12 @@ export function ChangeCredentialsModal({
const handleSubmit = async () => {
setTestError("");
try {
const body = JSON.stringify({
api_key: apiKey,
provider: provider.name.toLowerCase().split(" ")[0],
default_model_id: provider.name,
});
const testResponse = await fetch("/api/admin/embedding/test-embedding", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider: provider.name.toLowerCase().split(" ")[0],
provider_type: provider.provider_type.toLowerCase().split(" ")[0],
api_key: apiKey,
}),
});
@ -125,7 +118,7 @@ export function ChangeCredentialsModal({
method: "PUT",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
name: provider.name,
provider_type: provider.provider_type.toLowerCase().split(" ")[0],
api_key: apiKey,
is_default_provider: false,
is_configured: true,
@ -151,7 +144,7 @@ export function ChangeCredentialsModal({
<Modal
width="max-w-3xl"
icon={provider.icon}
title={`Modify your ${provider.name} key`}
title={`Modify your ${provider.provider_type} key`}
onOutsideClick={onCancel}
>
<div className="mb-4">

View File

@ -15,13 +15,13 @@ export function DeleteCredentialsModal({
return (
<Modal
width="max-w-3xl"
title={`Nuke ${modelProvider.name} Credentials?`}
title={`Delete ${modelProvider.provider_type} Credentials?`}
onOutsideClick={onCancel}
>
<div className="mb-4">
<Text className="text-lg mb-2">
You&apos;re about to delete your {modelProvider.name} credentials. Are
you sure?
You&apos;re about to delete your {modelProvider.provider_type}{" "}
credentials. Are you sure?
</Text>
<Callout
title="Point of No Return"

View File

@ -19,24 +19,24 @@ export function ProviderCreationModal({
onCancel: () => void;
existingProvider?: CloudEmbeddingProvider;
}) {
const useFileUpload = selectedProvider.name == "Google";
const useFileUpload = selectedProvider.provider_type == "Google";
const [isProcessing, setIsProcessing] = useState(false);
const [errorMsg, setErrorMsg] = useState<string>("");
const [fileName, setFileName] = useState<string>("");
const initialValues = {
name: existingProvider?.name || selectedProvider.name,
provider_type:
existingProvider?.provider_type || selectedProvider.provider_type,
api_key: existingProvider?.api_key || "",
custom_config: existingProvider?.custom_config
? Object.entries(existingProvider.custom_config)
: [],
default_model_name: "",
model_id: 0,
};
const validationSchema = Yup.object({
name: Yup.string().required("Name is required"),
provider_type: Yup.string().required("Provider type is required"),
api_key: useFileUpload
? Yup.string()
: Yup.string().required("API Key is required"),
@ -76,7 +76,6 @@ export function ProviderCreationModal({
) => {
setIsProcessing(true);
setErrorMsg("");
try {
const customConfig = Object.fromEntries(values.custom_config);
@ -86,7 +85,7 @@ export function ProviderCreationModal({
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider: values.name.toLowerCase().split(" ")[0],
provider_type: values.provider_type.toLowerCase().split(" ")[0],
api_key: values.api_key,
}),
}
@ -105,6 +104,7 @@ export function ProviderCreationModal({
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
...values,
provider_type: values.provider_type.toLowerCase().split(" ")[0],
custom_config: customConfig,
is_default_provider: false,
is_configured: true,
@ -134,7 +134,7 @@ export function ProviderCreationModal({
return (
<Modal
width="max-w-3xl"
title={`Configure ${selectedProvider.name}`}
title={`Configure ${selectedProvider.provider_type}`}
onOutsideClick={onCancel}
icon={selectedProvider.icon}
>

View File

@ -39,12 +39,12 @@ export default function CloudEmbeddingPage({
React.SetStateAction<CloudEmbeddingProvider | null>
>;
}) {
function hasNameInArray(
arr: Array<{ name: string }>,
function hasProviderTypeinArray(
arr: Array<{ provider_type: string }>,
searchName: string
): boolean {
return arr.some(
(item) => item.name.toLowerCase() === searchName.toLowerCase()
(item) => item.provider_type.toLowerCase() === searchName.toLowerCase()
);
}
@ -52,10 +52,13 @@ export default function CloudEmbeddingPage({
(model) => ({
...model,
configured:
!newUnenabledProviders.includes(model.name) &&
(newEnabledProviders.includes(model.name) ||
!newUnenabledProviders.includes(model.provider_type) &&
(newEnabledProviders.includes(model.provider_type) ||
(embeddingProviderDetails &&
hasNameInArray(embeddingProviderDetails, model.name))!),
hasProviderTypeinArray(
embeddingProviderDetails,
model.provider_type
))!),
})
);
@ -71,11 +74,12 @@ export default function CloudEmbeddingPage({
<div className="gap-4 mt-2 pb-10 flex content-start flex-wrap">
{providers.map((provider) => (
<div key={provider.name} className="mt-4 w-full">
<div key={provider.provider_type} className="mt-4 w-full">
<div className="flex items-center mb-2">
{provider.icon({ size: 40 })}
<h2 className="ml-2 mt-2 text-xl font-bold">
{provider.name} {provider.name == "Cohere" && "(recommended)"}
{provider.provider_type}{" "}
{provider.provider_type == "Cohere" && "(recommended)"}
</h2>
<HoverPopup
mainContent={

View File

@ -167,12 +167,14 @@ export default function EmbeddingForm() {
const onConfirm = async () => {
let newModel: EmbeddingModelDescriptor;
if ("cloud_provider_name" in selectedProvider) {
if ("provider_type" in selectedProvider) {
// This is a CloudEmbeddingModel
newModel = {
...selectedProvider,
model_name: selectedProvider.model_name,
cloud_provider_name: selectedProvider.cloud_provider_name,
provider_type: selectedProvider.provider_type
?.toLowerCase()
.split(" ")[0],
};
} else {
// This is an EmbeddingModelDescriptor
@ -180,7 +182,7 @@ export default function EmbeddingForm() {
...selectedProvider,
model_name: selectedProvider.model_name!,
description: "",
cloud_provider_name: null,
provider_type: null,
};
}

View File

@ -9,11 +9,15 @@ import {
VoyageIcon,
} from "@/components/icons/icons";
// Cloud Provider (not needed for hosted ones)
export enum EmbeddingProvider {
OPENAI = "OpenAI",
COHERE = "Cohere",
VOYAGE = "Voyage",
GOOGLE = "Google",
}
export interface CloudEmbeddingProvider {
id: number;
name: string;
provider_type: EmbeddingProvider;
api_key?: string;
custom_config?: Record<string, string>;
docsLink?: string;
@ -37,12 +41,11 @@ export interface EmbeddingModelDescriptor {
normalize: boolean;
query_prefix: string;
passage_prefix: string;
cloud_provider_name?: string | null;
provider_type?: string | null;
description: string;
}
export interface CloudEmbeddingModel extends EmbeddingModelDescriptor {
cloud_provider_name: string | null;
pricePerMillion: number;
enabled?: boolean;
mtebScore: number;
@ -124,8 +127,7 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
{
id: 1,
name: "Cohere",
provider_type: EmbeddingProvider.COHERE,
website: "https://cohere.ai",
icon: CohereIcon,
docsLink:
@ -136,8 +138,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
costslink: "https://cohere.com/pricing",
embedding_models: [
{
provider_type: EmbeddingProvider.COHERE,
model_name: "embed-english-v3.0",
cloud_provider_name: "Cohere",
description:
"Cohere's English embedding model. Good performance for English-language tasks.",
pricePerMillion: 0.1,
@ -151,7 +153,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
},
{
model_name: "embed-english-light-v3.0",
cloud_provider_name: "Cohere",
provider_type: EmbeddingProvider.COHERE,
description:
"Cohere's lightweight English embedding model. Faster and more efficient for simpler tasks.",
pricePerMillion: 0.1,
@ -166,8 +168,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
],
},
{
id: 0,
name: "OpenAI",
provider_type: EmbeddingProvider.OPENAI,
website: "https://openai.com",
icon: OpenAIIcon,
description: "AI industry leader known for ChatGPT and DALL-E",
@ -177,8 +178,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
costslink: "https://openai.com/pricing",
embedding_models: [
{
provider_type: EmbeddingProvider.OPENAI,
model_name: "text-embedding-3-large",
cloud_provider_name: "OpenAI",
description:
"OpenAI's large embedding model. Best performance, but more expensive.",
pricePerMillion: 0.13,
@ -191,8 +192,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
enabled: false,
},
{
provider_type: EmbeddingProvider.OPENAI,
model_name: "text-embedding-3-small",
cloud_provider_name: "OpenAI",
model_dim: 1536,
normalize: false,
query_prefix: "",
@ -208,8 +209,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
},
{
id: 2,
name: "Google",
provider_type: EmbeddingProvider.GOOGLE,
website: "https://ai.google",
icon: GoogleIcon,
docsLink:
@ -220,7 +220,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
costslink: "https://cloud.google.com/vertex-ai/pricing",
embedding_models: [
{
cloud_provider_name: "Google",
provider_type: EmbeddingProvider.GOOGLE,
model_name: "text-embedding-004",
description: "Google's most recent text embedding model.",
pricePerMillion: 0.025,
@ -233,7 +233,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
passage_prefix: "",
},
{
cloud_provider_name: "Google",
provider_type: EmbeddingProvider.GOOGLE,
model_name: "textembedding-gecko@003",
description: "Google's Gecko embedding model. Powerful and efficient.",
pricePerMillion: 0.025,
@ -248,8 +248,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
],
},
{
id: 3,
name: "Voyage",
provider_type: EmbeddingProvider.VOYAGE,
website: "https://www.voyageai.com",
icon: VoyageIcon,
description: "Advanced NLP research startup born from Stanford AI Labs",
@ -259,7 +258,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
costslink: "https://www.voyageai.com/pricing",
embedding_models: [
{
cloud_provider_name: "Voyage",
provider_type: EmbeddingProvider.VOYAGE,
model_name: "voyage-large-2-instruct",
description:
"Voyage's large embedding model. High performance with instruction fine-tuning.",
@ -273,7 +272,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
passage_prefix: "",
},
{
cloud_provider_name: "Voyage",
provider_type: EmbeddingProvider.VOYAGE,
model_name: "voyage-light-2-instruct",
description:
"Voyage's lightweight embedding model. Good balance of performance and efficiency.",