mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-01 00:18:18 +02:00
Simpler azure embedding (#2751)
* functional but janky * nit * adapt for azure * nit * minor updates * nits * nit * nit * ensure access to litellm * k
This commit is contained in:
parent
02cc211e91
commit
e022e77b6d
@ -0,0 +1,30 @@
|
||||
"""add api_version and deployment_name to search settings
|
||||
|
||||
Revision ID: 5d12a446f5c0
|
||||
Revises: e4334d5b33ba
|
||||
Create Date: 2024-10-08 15:56:07.975636
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5d12a446f5c0"
|
||||
down_revision = "e4334d5b33ba"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("api_version", sa.String(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"embedding_provider", sa.Column("deployment_name", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("embedding_provider", "deployment_name")
|
||||
op.drop_column("embedding_provider", "api_version")
|
@ -53,7 +53,6 @@ MASK_CREDENTIAL_PREFIX = (
|
||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||
)
|
||||
|
||||
|
||||
SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
@ -615,6 +615,7 @@ class SearchSettings(Base):
|
||||
normalize: Mapped[bool] = mapped_column(Boolean)
|
||||
query_prefix: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
passage_prefix: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
status: Mapped[IndexModelStatus] = mapped_column(
|
||||
Enum(IndexModelStatus, native_enum=False)
|
||||
)
|
||||
@ -670,6 +671,20 @@ class SearchSettings(Base):
|
||||
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
|
||||
cloud_provider='{self.cloud_provider.provider_type if self.cloud_provider else 'None'}')>"
|
||||
|
||||
@property
|
||||
def api_version(self) -> str | None:
|
||||
return (
|
||||
self.cloud_provider.api_version if self.cloud_provider is not None else None
|
||||
)
|
||||
|
||||
@property
|
||||
def deployment_name(self) -> str | None:
|
||||
return (
|
||||
self.cloud_provider.deployment_name
|
||||
if self.cloud_provider is not None
|
||||
else None
|
||||
)
|
||||
|
||||
@property
|
||||
def api_url(self) -> str | None:
|
||||
return self.cloud_provider.api_url if self.cloud_provider is not None else None
|
||||
@ -1164,6 +1179,9 @@ class CloudEmbeddingProvider(Base):
|
||||
)
|
||||
api_url: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString())
|
||||
api_version: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
search_settings: Mapped[list["SearchSettings"]] = relationship(
|
||||
"SearchSettings",
|
||||
back_populates="cloud_provider",
|
||||
|
@ -32,6 +32,8 @@ class IndexingEmbedder(ABC):
|
||||
provider_type: EmbeddingProvider | None,
|
||||
api_key: str | None,
|
||||
api_url: str | None,
|
||||
api_version: str | None,
|
||||
deployment_name: str | None,
|
||||
heartbeat: Heartbeat | None,
|
||||
):
|
||||
self.model_name = model_name
|
||||
@ -41,6 +43,8 @@ class IndexingEmbedder(ABC):
|
||||
self.provider_type = provider_type
|
||||
self.api_key = api_key
|
||||
self.api_url = api_url
|
||||
self.api_version = api_version
|
||||
self.deployment_name = deployment_name
|
||||
|
||||
self.embedding_model = EmbeddingModel(
|
||||
model_name=model_name,
|
||||
@ -50,6 +54,8 @@ class IndexingEmbedder(ABC):
|
||||
api_key=api_key,
|
||||
provider_type=provider_type,
|
||||
api_url=api_url,
|
||||
api_version=api_version,
|
||||
deployment_name=deployment_name,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
@ -75,6 +81,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
provider_type: EmbeddingProvider | None = None,
|
||||
api_key: str | None = None,
|
||||
api_url: str | None = None,
|
||||
api_version: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
heartbeat: Heartbeat | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
@ -85,6 +93,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
provider_type,
|
||||
api_key,
|
||||
api_url,
|
||||
api_version,
|
||||
deployment_name,
|
||||
heartbeat,
|
||||
)
|
||||
|
||||
@ -193,5 +203,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
provider_type=search_settings.provider_type,
|
||||
api_key=search_settings.api_key,
|
||||
api_url=search_settings.api_url,
|
||||
api_version=search_settings.api_version,
|
||||
deployment_name=search_settings.deployment_name,
|
||||
heartbeat=heartbeat,
|
||||
)
|
||||
|
@ -97,6 +97,8 @@ class EmbeddingModel:
|
||||
provider_type: EmbeddingProvider | None,
|
||||
retrim_content: bool = False,
|
||||
heartbeat: Heartbeat | None = None,
|
||||
api_version: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.provider_type = provider_type
|
||||
@ -106,6 +108,8 @@ class EmbeddingModel:
|
||||
self.model_name = model_name
|
||||
self.retrim_content = retrim_content
|
||||
self.api_url = api_url
|
||||
self.api_version = api_version
|
||||
self.deployment_name = deployment_name
|
||||
self.tokenizer = get_tokenizer(
|
||||
model_name=model_name, provider_type=provider_type
|
||||
)
|
||||
@ -157,6 +161,8 @@ class EmbeddingModel:
|
||||
embed_request = EmbedRequest(
|
||||
model_name=self.model_name,
|
||||
texts=text_batch,
|
||||
api_version=self.api_version,
|
||||
deployment_name=self.deployment_name,
|
||||
max_context_length=max_seq_length,
|
||||
normalize_embeddings=self.normalize,
|
||||
api_key=self.api_key,
|
||||
@ -239,6 +245,8 @@ class EmbeddingModel:
|
||||
provider_type=search_settings.provider_type,
|
||||
api_url=search_settings.api_url,
|
||||
retrim_content=retrim_content,
|
||||
api_version=search_settings.api_version,
|
||||
deployment_name=search_settings.deployment_name,
|
||||
)
|
||||
|
||||
|
||||
|
@ -43,6 +43,8 @@ def test_embedding_configuration(
|
||||
api_url=test_llm_request.api_url,
|
||||
provider_type=test_llm_request.provider_type,
|
||||
model_name=test_llm_request.model_name,
|
||||
api_version=test_llm_request.api_version,
|
||||
deployment_name=test_llm_request.deployment_name,
|
||||
normalize=False,
|
||||
query_prefix=None,
|
||||
passage_prefix=None,
|
||||
|
@ -17,6 +17,8 @@ class TestEmbeddingRequest(BaseModel):
|
||||
api_key: str | None = None
|
||||
api_url: str | None = None
|
||||
model_name: str | None = None
|
||||
api_version: str | None = None
|
||||
deployment_name: str | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
@ -26,6 +28,8 @@ class CloudEmbeddingProvider(BaseModel):
|
||||
provider_type: EmbeddingProvider
|
||||
api_key: str | None = None
|
||||
api_url: str | None = None
|
||||
api_version: str | None = None
|
||||
deployment_name: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_request(
|
||||
@ -35,6 +39,8 @@ class CloudEmbeddingProvider(BaseModel):
|
||||
provider_type=cloud_provider_model.provider_type,
|
||||
api_key=cloud_provider_model.api_key,
|
||||
api_url=cloud_provider_model.api_url,
|
||||
api_version=cloud_provider_model.api_version,
|
||||
deployment_name=cloud_provider_model.deployment_name,
|
||||
)
|
||||
|
||||
|
||||
@ -42,3 +48,5 @@ class CloudEmbeddingProviderCreationRequest(BaseModel):
|
||||
provider_type: EmbeddingProvider
|
||||
api_key: str | None = None
|
||||
api_url: str | None = None
|
||||
api_version: str | None = None
|
||||
deployment_name: str | None = None
|
||||
|
@ -10,6 +10,7 @@ from cohere import Client as CohereClient
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
from litellm import embedding
|
||||
from retry import retry
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
@ -54,7 +55,11 @@ _COHERE_MAX_INPUT_LEN = 96
|
||||
|
||||
|
||||
def _initialize_client(
|
||||
api_key: str, provider: EmbeddingProvider, model: str | None = None
|
||||
api_key: str,
|
||||
provider: EmbeddingProvider,
|
||||
model: str | None = None,
|
||||
api_url: str | None = None,
|
||||
api_version: str | None = None,
|
||||
) -> Any:
|
||||
if provider == EmbeddingProvider.OPENAI:
|
||||
return openai.OpenAI(api_key=api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
|
||||
@ -69,6 +74,8 @@ def _initialize_client(
|
||||
project_id = json.loads(api_key)["project_id"]
|
||||
vertexai.init(project=project_id, credentials=credentials)
|
||||
return TextEmbeddingModel.from_pretrained(model or DEFAULT_VERTEX_MODEL)
|
||||
elif provider == EmbeddingProvider.AZURE:
|
||||
return {"api_key": api_key, "api_url": api_url, "api_version": api_version}
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
|
||||
@ -78,11 +85,15 @@ class CloudEmbedding:
|
||||
self,
|
||||
api_key: str,
|
||||
provider: EmbeddingProvider,
|
||||
api_url: str | None = None,
|
||||
api_version: str | None = None,
|
||||
# Only for Google as is needed on client setup
|
||||
model: str | None = None,
|
||||
) -> None:
|
||||
self.provider = provider
|
||||
self.client = _initialize_client(api_key, self.provider, model)
|
||||
self.client = _initialize_client(
|
||||
api_key, self.provider, model, api_url, api_version
|
||||
)
|
||||
|
||||
def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
|
||||
if not model:
|
||||
@ -144,6 +155,18 @@ class CloudEmbedding:
|
||||
)
|
||||
return response.embeddings
|
||||
|
||||
def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]:
|
||||
response = embedding(
|
||||
model=model,
|
||||
input=texts,
|
||||
api_key=self.client["api_key"],
|
||||
api_base=self.client["api_url"],
|
||||
api_version=self.client["api_version"],
|
||||
)
|
||||
embeddings = [embedding["embedding"] for embedding in response.data]
|
||||
|
||||
return embeddings
|
||||
|
||||
def _embed_vertex(
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[Embedding]:
|
||||
@ -169,10 +192,13 @@ class CloudEmbedding:
|
||||
texts: list[str],
|
||||
text_type: EmbedTextType,
|
||||
model_name: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
) -> list[Embedding]:
|
||||
try:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return self._embed_openai(texts, model_name)
|
||||
elif self.provider == EmbeddingProvider.AZURE:
|
||||
return self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return self._embed_cohere(texts, model_name, embedding_type)
|
||||
@ -190,10 +216,14 @@ class CloudEmbedding:
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
api_key: str, provider: EmbeddingProvider, model: str | None = None
|
||||
api_key: str,
|
||||
provider: EmbeddingProvider,
|
||||
model: str | None = None,
|
||||
api_url: str | None = None,
|
||||
api_version: str | None = None,
|
||||
) -> "CloudEmbedding":
|
||||
logger.debug(f"Creating Embedding instance for provider: {provider}")
|
||||
return CloudEmbedding(api_key, provider, model)
|
||||
return CloudEmbedding(api_key, provider, model, api_url, api_version)
|
||||
|
||||
|
||||
def get_embedding_model(
|
||||
@ -260,12 +290,14 @@ def embed_text(
|
||||
texts: list[str],
|
||||
text_type: EmbedTextType,
|
||||
model_name: str | None,
|
||||
deployment_name: str | None,
|
||||
max_context_length: int,
|
||||
normalize_embeddings: bool,
|
||||
api_key: str | None,
|
||||
provider_type: EmbeddingProvider | None,
|
||||
prefix: str | None,
|
||||
api_url: str | None,
|
||||
api_version: str | None,
|
||||
) -> list[Embedding]:
|
||||
logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")
|
||||
|
||||
@ -307,11 +339,16 @@ def embed_text(
|
||||
)
|
||||
|
||||
cloud_model = CloudEmbedding(
|
||||
api_key=api_key, provider=provider_type, model=model_name
|
||||
api_key=api_key,
|
||||
provider=provider_type,
|
||||
model=model_name,
|
||||
api_url=api_url,
|
||||
api_version=api_version,
|
||||
)
|
||||
embeddings = cloud_model.embed(
|
||||
texts=texts,
|
||||
model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
text_type=text_type,
|
||||
)
|
||||
|
||||
@ -405,12 +442,14 @@ async def process_embed_request(
|
||||
embeddings = embed_text(
|
||||
texts=embed_request.texts,
|
||||
model_name=embed_request.model_name,
|
||||
deployment_name=embed_request.deployment_name,
|
||||
max_context_length=embed_request.max_context_length,
|
||||
normalize_embeddings=embed_request.normalize_embeddings,
|
||||
api_key=embed_request.api_key,
|
||||
provider_type=embed_request.provider_type,
|
||||
text_type=embed_request.text_type,
|
||||
api_url=embed_request.api_url,
|
||||
api_version=embed_request.api_version,
|
||||
prefix=prefix,
|
||||
)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
|
@ -12,3 +12,4 @@ torch==2.2.0
|
||||
transformers==4.39.2
|
||||
uvicorn==0.21.1
|
||||
voyageai==0.2.3
|
||||
litellm==1.48.7
|
||||
|
@ -7,6 +7,7 @@ class EmbeddingProvider(str, Enum):
|
||||
VOYAGE = "voyage"
|
||||
GOOGLE = "google"
|
||||
LITELLM = "litellm"
|
||||
AZURE = "azure"
|
||||
|
||||
|
||||
class RerankerProvider(str, Enum):
|
||||
|
@ -20,6 +20,7 @@ class EmbedRequest(BaseModel):
|
||||
texts: list[str]
|
||||
# Can be none for cloud embedding model requests, error handling logic exists for other cases
|
||||
model_name: str | None = None
|
||||
deployment_name: str | None = None
|
||||
max_context_length: int
|
||||
normalize_embeddings: bool
|
||||
api_key: str | None = None
|
||||
@ -28,7 +29,7 @@ class EmbedRequest(BaseModel):
|
||||
manual_query_prefix: str | None = None
|
||||
manual_passage_prefix: str | None = None
|
||||
api_url: str | None = None
|
||||
|
||||
api_version: str | None = None
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
|
@ -27,10 +27,13 @@ import {
|
||||
EMBEDDING_MODELS_ADMIN_URL,
|
||||
EMBEDDING_PROVIDERS_ADMIN_URL,
|
||||
} from "../configuration/llm/constants";
|
||||
import { AdvancedSearchConfiguration } from "./interfaces";
|
||||
|
||||
export interface EmbeddingDetails {
|
||||
api_key?: string;
|
||||
api_url?: string;
|
||||
api_version?: string;
|
||||
deployment_name?: string;
|
||||
custom_config: any;
|
||||
provider_type: EmbeddingProvider;
|
||||
}
|
||||
@ -41,6 +44,8 @@ export function EmbeddingModelSelection({
|
||||
updateSelectedProvider,
|
||||
modelTab,
|
||||
setModelTab,
|
||||
updateCurrentModel,
|
||||
advancedEmbeddingDetails,
|
||||
}: {
|
||||
modelTab: "open" | "cloud" | null;
|
||||
setModelTab: Dispatch<SetStateAction<"open" | "cloud" | null>>;
|
||||
@ -49,6 +54,11 @@ export function EmbeddingModelSelection({
|
||||
updateSelectedProvider: (
|
||||
model: CloudEmbeddingModel | HostedEmbeddingModel
|
||||
) => void;
|
||||
updateCurrentModel: (
|
||||
newModel: string,
|
||||
provider_type: EmbeddingProvider
|
||||
) => void;
|
||||
advancedEmbeddingDetails: AdvancedSearchConfiguration;
|
||||
}) {
|
||||
// Cloud Provider based modals
|
||||
const [showTentativeProvider, setShowTentativeProvider] =
|
||||
@ -72,12 +82,6 @@ export function EmbeddingModelSelection({
|
||||
const [showTentativeOpenProvider, setShowTentativeOpenProvider] =
|
||||
useState<HostedEmbeddingModel | null>(null);
|
||||
|
||||
// Enabled / unenabled providers
|
||||
const [newEnabledProviders, setNewEnabledProviders] = useState<string[]>([]);
|
||||
const [newUnenabledProviders, setNewUnenabledProviders] = useState<string[]>(
|
||||
[]
|
||||
);
|
||||
|
||||
const [showDeleteCredentialsModal, setShowDeleteCredentialsModal] =
|
||||
useState<boolean>(false);
|
||||
|
||||
@ -90,7 +94,10 @@ export function EmbeddingModelSelection({
|
||||
{ refreshInterval: 5000 } // 5 seconds
|
||||
);
|
||||
|
||||
const { data: embeddingProviderDetails } = useSWR<EmbeddingDetails[]>(
|
||||
const {
|
||||
data: embeddingProviderDetails,
|
||||
mutate: mutateEmbeddingProviderDetails,
|
||||
} = useSWR<EmbeddingDetails[]>(
|
||||
EMBEDDING_PROVIDERS_ADMIN_URL,
|
||||
errorHandlingFetcher,
|
||||
{ refreshInterval: 5000 } // 5 seconds
|
||||
@ -132,32 +139,6 @@ export function EmbeddingModelSelection({
|
||||
}
|
||||
};
|
||||
|
||||
const clientsideAddProvider = (provider: CloudEmbeddingProvider) => {
|
||||
const providerType = provider.provider_type;
|
||||
setNewEnabledProviders((newEnabledProviders) => [
|
||||
...newEnabledProviders,
|
||||
providerType,
|
||||
]);
|
||||
setNewUnenabledProviders((newUnenabledProviders) =>
|
||||
newUnenabledProviders.filter(
|
||||
(givenProviderType) => givenProviderType != providerType
|
||||
)
|
||||
);
|
||||
};
|
||||
|
||||
const clientsideRemoveProvider = (provider: CloudEmbeddingProvider) => {
|
||||
const providerType = provider.provider_type;
|
||||
setNewEnabledProviders((newEnabledProviders) =>
|
||||
newEnabledProviders.filter(
|
||||
(givenProviderType) => givenProviderType != providerType
|
||||
)
|
||||
);
|
||||
setNewUnenabledProviders((newUnenabledProviders) => [
|
||||
...newUnenabledProviders,
|
||||
providerType,
|
||||
]);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="p-2">
|
||||
{alreadySelectedModel && (
|
||||
@ -186,14 +167,16 @@ export function EmbeddingModelSelection({
|
||||
|
||||
{showTentativeProvider && (
|
||||
<ProviderCreationModal
|
||||
updateCurrentModel={updateCurrentModel}
|
||||
isProxy={showTentativeProvider.provider_type == "LiteLLM"}
|
||||
isAzure={showTentativeProvider.provider_type == "Azure"}
|
||||
selectedProvider={showTentativeProvider}
|
||||
onConfirm={() => {
|
||||
setShowTentativeProvider(showUnconfiguredProvider);
|
||||
clientsideAddProvider(showTentativeProvider);
|
||||
if (showModelInQueue) {
|
||||
setShowTentativeModel(showModelInQueue);
|
||||
}
|
||||
mutateEmbeddingProviderDetails();
|
||||
}}
|
||||
onCancel={() => {
|
||||
setShowModelInQueue(null);
|
||||
@ -205,10 +188,11 @@ export function EmbeddingModelSelection({
|
||||
{changeCredentialsProvider && (
|
||||
<ChangeCredentialsModal
|
||||
isProxy={changeCredentialsProvider.provider_type == "LiteLLM"}
|
||||
isAzure={changeCredentialsProvider.provider_type == "Azure"}
|
||||
useFileUpload={changeCredentialsProvider.provider_type == "Google"}
|
||||
onDeleted={() => {
|
||||
clientsideRemoveProvider(changeCredentialsProvider);
|
||||
setChangeCredentialsProvider(null);
|
||||
mutateEmbeddingProviderDetails();
|
||||
}}
|
||||
provider={changeCredentialsProvider}
|
||||
onConfirm={() => setChangeCredentialsProvider(null)}
|
||||
@ -236,12 +220,13 @@ export function EmbeddingModelSelection({
|
||||
modelProvider={showTentativeProvider!}
|
||||
onConfirm={() => {
|
||||
setShowDeleteCredentialsModal(false);
|
||||
mutateEmbeddingProviderDetails();
|
||||
}}
|
||||
onCancel={() => setShowDeleteCredentialsModal(false)}
|
||||
/>
|
||||
)}
|
||||
|
||||
<p className="t mb-4">
|
||||
<p className="mb-4">
|
||||
Select from cloud, self-hosted models, or continue with your current
|
||||
embedding model.
|
||||
</p>
|
||||
@ -291,14 +276,13 @@ export function EmbeddingModelSelection({
|
||||
|
||||
{modelTab == "cloud" && (
|
||||
<CloudEmbeddingPage
|
||||
advancedEmbeddingDetails={advancedEmbeddingDetails}
|
||||
embeddingModelDetails={embeddingModelDetails}
|
||||
setShowModelInQueue={setShowModelInQueue}
|
||||
setShowTentativeModel={setShowTentativeModel}
|
||||
currentModel={selectedProvider || currentEmbeddingModel}
|
||||
setAlreadySelectedModel={setAlreadySelectedModel}
|
||||
embeddingProviderDetails={embeddingProviderDetails}
|
||||
newEnabledProviders={newEnabledProviders}
|
||||
newUnenabledProviders={newUnenabledProviders}
|
||||
setShowTentativeProvider={setShowTentativeProvider}
|
||||
setChangeCredentialsProvider={setChangeCredentialsProvider}
|
||||
/>
|
||||
|
@ -16,6 +16,7 @@ export function ChangeCredentialsModal({
|
||||
onDeleted,
|
||||
useFileUpload,
|
||||
isProxy = false,
|
||||
isAzure = false,
|
||||
}: {
|
||||
provider: CloudEmbeddingProvider;
|
||||
onConfirm: () => void;
|
||||
@ -23,6 +24,7 @@ export function ChangeCredentialsModal({
|
||||
onDeleted: () => void;
|
||||
useFileUpload: boolean;
|
||||
isProxy?: boolean;
|
||||
isAzure?: boolean;
|
||||
}) {
|
||||
const [apiKey, setApiKey] = useState("");
|
||||
const [apiUrl, setApiUrl] = useState("");
|
||||
@ -151,7 +153,6 @@ export function ChangeCredentialsModal({
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal
|
||||
width="max-w-3xl"
|
||||
@ -160,133 +161,131 @@ export function ChangeCredentialsModal({
|
||||
onOutsideClick={onCancel}
|
||||
>
|
||||
<>
|
||||
<p className="mb-4">
|
||||
You can modify your configuration by providing a new API key
|
||||
{isProxy ? " or API URL." : "."}
|
||||
</p>
|
||||
{!isAzure && (
|
||||
<>
|
||||
<p className="mb-4">
|
||||
You can modify your configuration by providing a new API key
|
||||
{isProxy ? " or API URL." : "."}
|
||||
</p>
|
||||
|
||||
<div className="mb-4 flex flex-col gap-y-2">
|
||||
<Label className="mt-2">API Key</Label>
|
||||
{useFileUpload ? (
|
||||
<>
|
||||
<Label className="mt-2">Upload JSON File</Label>
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
accept=".json"
|
||||
onChange={handleFileUpload}
|
||||
className="text-lg w-full p-1"
|
||||
/>
|
||||
{fileName && <p>Uploaded file: {fileName}</p>}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<input
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
bg-background-emphasis
|
||||
`}
|
||||
value={apiKey}
|
||||
onChange={(e: any) => setApiKey(e.target.value)}
|
||||
placeholder="Paste your API key here"
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
<div className="mb-4 flex flex-col gap-y-2">
|
||||
<Label className="mt-2">API Key</Label>
|
||||
{useFileUpload ? (
|
||||
<>
|
||||
<Label className="mt-2">Upload JSON File</Label>
|
||||
<input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
accept=".json"
|
||||
onChange={handleFileUpload}
|
||||
className="text-lg w-full p-1"
|
||||
/>
|
||||
{fileName && <p>Uploaded file: {fileName}</p>}
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<input
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
bg-background-emphasis
|
||||
`}
|
||||
value={apiKey}
|
||||
onChange={(e: any) => setApiKey(e.target.value)}
|
||||
placeholder="Paste your API key here"
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
{isProxy && (
|
||||
<>
|
||||
<Label className="mt-2">API URL</Label>
|
||||
{isProxy && (
|
||||
<>
|
||||
<Label className="mt-2">API URL</Label>
|
||||
|
||||
<input
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
bg-background-emphasis
|
||||
`}
|
||||
value={apiUrl}
|
||||
onChange={(e: any) => setApiUrl(e.target.value)}
|
||||
placeholder="Paste your API URL here"
|
||||
/>
|
||||
<input
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
bg-background-emphasis
|
||||
`}
|
||||
value={apiUrl}
|
||||
onChange={(e: any) => setApiUrl(e.target.value)}
|
||||
placeholder="Paste your API URL here"
|
||||
/>
|
||||
|
||||
{deletionError && (
|
||||
<Callout title="Error" color="red" className="mt-4">
|
||||
{deletionError}
|
||||
{deletionError && (
|
||||
<Callout title="Error" color="red" className="mt-4">
|
||||
{deletionError}
|
||||
</Callout>
|
||||
)}
|
||||
|
||||
<div>
|
||||
<Label className="mt-2">Test Model</Label>
|
||||
<p>
|
||||
Since you are using a liteLLM proxy, we'll need a
|
||||
model name to test the connection with.
|
||||
</p>
|
||||
</div>
|
||||
<input
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
bg-background-emphasis
|
||||
`}
|
||||
value={modelName}
|
||||
onChange={(e: any) => setModelName(e.target.value)}
|
||||
placeholder="Paste your model name here"
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
{testError && (
|
||||
<Callout title="Error" color="red" className="my-4">
|
||||
{testError}
|
||||
</Callout>
|
||||
)}
|
||||
|
||||
<div>
|
||||
<Label className="mt-2">Test Model</Label>
|
||||
<p>
|
||||
Since you are using a liteLLM proxy, we'll need a model
|
||||
name to test the connection with.
|
||||
</p>
|
||||
</div>
|
||||
<input
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
bg-background-emphasis
|
||||
`}
|
||||
value={modelName}
|
||||
onChange={(e: any) => setModelName(e.target.value)}
|
||||
placeholder="Paste your API URL here"
|
||||
/>
|
||||
<Button
|
||||
className="mr-auto mt-4"
|
||||
color="blue"
|
||||
onClick={() => handleSubmit()}
|
||||
disabled={!apiKey}
|
||||
>
|
||||
Update Configuration
|
||||
</Button>
|
||||
|
||||
{deletionError && (
|
||||
<Callout title="Error" color="red" className="mt-4">
|
||||
{deletionError}
|
||||
</Callout>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
<Divider />
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
{testError && (
|
||||
<Callout title="Error" color="red" className="my-4">
|
||||
{testError}
|
||||
</Callout>
|
||||
)}
|
||||
<Subtitle className="mt-4 font-bold text-lg mb-2">
|
||||
You can delete your configuration.
|
||||
</Subtitle>
|
||||
<Text className="mb-2">
|
||||
This is only possible if you have already switched to a different
|
||||
embedding type!
|
||||
</Text>
|
||||
|
||||
<Button
|
||||
className="mr-auto mt-4"
|
||||
color="blue"
|
||||
onClick={() => handleSubmit()}
|
||||
disabled={!apiKey}
|
||||
>
|
||||
Update Configuration
|
||||
</Button>
|
||||
|
||||
<Divider />
|
||||
|
||||
<Subtitle className="mt-4 font-bold text-lg mb-2">
|
||||
You can also delete your configuration.
|
||||
</Subtitle>
|
||||
<Text className="mb-2">
|
||||
This is only possible if you have already switched to a different
|
||||
embedding type!
|
||||
</Text>
|
||||
|
||||
<Button className="mr-auto" onClick={handleDelete} color="red">
|
||||
Delete Configuration
|
||||
</Button>
|
||||
{deletionError && (
|
||||
<Callout title="Error" color="red" className="mt-4">
|
||||
{deletionError}
|
||||
</Callout>
|
||||
)}
|
||||
</div>
|
||||
<Button className="mr-auto" onClick={handleDelete} color="red">
|
||||
Delete Configuration
|
||||
</Button>
|
||||
{deletionError && (
|
||||
<Callout title="Error" color="red" className="mt-4">
|
||||
{deletionError}
|
||||
</Callout>
|
||||
)}
|
||||
</>
|
||||
</Modal>
|
||||
);
|
||||
|
@ -4,7 +4,10 @@ import { Formik, Form } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import { Label, TextFormField } from "@/components/admin/connectors/Field";
|
||||
import { LoadingAnimation } from "@/components/Loading";
|
||||
import { CloudEmbeddingProvider } from "../../../../components/embedding/interfaces";
|
||||
import {
|
||||
CloudEmbeddingProvider,
|
||||
EmbeddingProvider,
|
||||
} from "../../../../components/embedding/interfaces";
|
||||
import { EMBEDDING_PROVIDERS_ADMIN_URL } from "../../configuration/llm/constants";
|
||||
import { Modal } from "@/components/Modal";
|
||||
|
||||
@ -14,12 +17,19 @@ export function ProviderCreationModal({
|
||||
onCancel,
|
||||
existingProvider,
|
||||
isProxy,
|
||||
isAzure,
|
||||
updateCurrentModel,
|
||||
}: {
|
||||
updateCurrentModel: (
|
||||
newModel: string,
|
||||
provider_type: EmbeddingProvider
|
||||
) => void;
|
||||
selectedProvider: CloudEmbeddingProvider;
|
||||
onConfirm: () => void;
|
||||
onCancel: () => void;
|
||||
existingProvider?: CloudEmbeddingProvider;
|
||||
isProxy?: boolean;
|
||||
isAzure?: boolean;
|
||||
}) {
|
||||
const useFileUpload = selectedProvider.provider_type == "Google";
|
||||
|
||||
@ -41,16 +51,24 @@ export function ProviderCreationModal({
|
||||
|
||||
const validationSchema = Yup.object({
|
||||
provider_type: Yup.string().required("Provider type is required"),
|
||||
api_key: isProxy
|
||||
? Yup.string()
|
||||
: useFileUpload
|
||||
api_key:
|
||||
isProxy || isAzure
|
||||
? Yup.string()
|
||||
: Yup.string().required("API Key is required"),
|
||||
: useFileUpload
|
||||
? Yup.string()
|
||||
: Yup.string().required("API Key is required"),
|
||||
model_name: isProxy
|
||||
? Yup.string().required("Model name is required")
|
||||
: Yup.string().nullable(),
|
||||
api_url: isProxy
|
||||
? Yup.string().required("API URL is required")
|
||||
api_url:
|
||||
isProxy || isAzure
|
||||
? Yup.string().required("API URL is required")
|
||||
: Yup.string(),
|
||||
deployment_name: isAzure
|
||||
? Yup.string().required("Deployment name is required")
|
||||
: Yup.string(),
|
||||
api_version: isAzure
|
||||
? Yup.string().required("API Version is required")
|
||||
: Yup.string(),
|
||||
custom_config: Yup.array().of(Yup.array().of(Yup.string()).length(2)),
|
||||
});
|
||||
@ -101,6 +119,8 @@ export function ProviderCreationModal({
|
||||
api_key: values.api_key,
|
||||
api_url: values.api_url,
|
||||
model_name: values.model_name,
|
||||
api_version: values.api_version,
|
||||
deployment_name: values.deployment_name,
|
||||
}),
|
||||
}
|
||||
);
|
||||
@ -118,6 +138,8 @@ export function ProviderCreationModal({
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
...values,
|
||||
api_version: values.api_version,
|
||||
deployment_name: values.deployment_name,
|
||||
provider_type: values.provider_type.toLowerCase().split(" ")[0],
|
||||
custom_config: customConfig,
|
||||
is_default_provider: false,
|
||||
@ -125,6 +147,10 @@ export function ProviderCreationModal({
|
||||
}),
|
||||
});
|
||||
|
||||
if (isAzure) {
|
||||
updateCurrentModel(values.model_name, EmbeddingProvider.AZURE);
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json();
|
||||
throw new Error(
|
||||
@ -178,26 +204,45 @@ export function ProviderCreationModal({
|
||||
href={selectedProvider.apiLink}
|
||||
rel="noreferrer"
|
||||
>
|
||||
{isProxy ? "API URL" : "API KEY"}
|
||||
{isProxy || isAzure ? "API URL" : "API KEY"}
|
||||
</a>
|
||||
</Text>
|
||||
|
||||
<div className="flex w-full flex-col gap-y-6">
|
||||
{(isProxy || isAzure) && (
|
||||
<TextFormField
|
||||
name="api_url"
|
||||
label="API URL"
|
||||
placeholder="API URL"
|
||||
type="text"
|
||||
/>
|
||||
)}
|
||||
|
||||
{isProxy && (
|
||||
<>
|
||||
<TextFormField
|
||||
name="api_url"
|
||||
label="API URL"
|
||||
placeholder="API URL"
|
||||
type="text"
|
||||
/>
|
||||
<TextFormField
|
||||
name="model_name"
|
||||
label="Model Name (for testing)"
|
||||
placeholder="Model Name"
|
||||
type="text"
|
||||
/>
|
||||
</>
|
||||
<TextFormField
|
||||
name="model_name"
|
||||
label={`Model Name ${isProxy ? "(for testing)" : ""}`}
|
||||
placeholder="Model Name"
|
||||
type="text"
|
||||
/>
|
||||
)}
|
||||
|
||||
{isAzure && (
|
||||
<TextFormField
|
||||
name="deployment_name"
|
||||
label="Deployment Name"
|
||||
placeholder="Deployment Name"
|
||||
type="text"
|
||||
/>
|
||||
)}
|
||||
|
||||
{isAzure && (
|
||||
<TextFormField
|
||||
name="api_version"
|
||||
label="API Version"
|
||||
placeholder="API Version"
|
||||
type="text"
|
||||
/>
|
||||
)}
|
||||
|
||||
{useFileUpload ? (
|
||||
|
@ -10,42 +10,42 @@ import {
|
||||
EmbeddingModelDescriptor,
|
||||
EmbeddingProvider,
|
||||
LITELLM_CLOUD_PROVIDER,
|
||||
AZURE_CLOUD_PROVIDER,
|
||||
} from "../../../../components/embedding/interfaces";
|
||||
import { EmbeddingDetails } from "../EmbeddingModelSelectionForm";
|
||||
import { FiExternalLink, FiInfo, FiTrash } from "react-icons/fi";
|
||||
import { HoverPopup } from "@/components/HoverPopup";
|
||||
import { Dispatch, SetStateAction, useEffect, useState } from "react";
|
||||
import { LiteLLMModelForm } from "@/components/embedding/LiteLLMModelForm";
|
||||
import { CustomEmbeddingModelForm } from "@/components/embedding/CustomEmbeddingModelForm";
|
||||
import { deleteSearchSettings } from "./utils";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal";
|
||||
import { AdvancedSearchConfiguration } from "../interfaces";
|
||||
|
||||
export default function CloudEmbeddingPage({
|
||||
currentModel,
|
||||
embeddingProviderDetails,
|
||||
embeddingModelDetails,
|
||||
newEnabledProviders,
|
||||
newUnenabledProviders,
|
||||
setShowTentativeProvider,
|
||||
setChangeCredentialsProvider,
|
||||
setAlreadySelectedModel,
|
||||
setShowTentativeModel,
|
||||
setShowModelInQueue,
|
||||
advancedEmbeddingDetails,
|
||||
}: {
|
||||
setShowModelInQueue: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
|
||||
setShowTentativeModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
|
||||
currentModel: EmbeddingModelDescriptor | CloudEmbeddingModel;
|
||||
setAlreadySelectedModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
|
||||
newUnenabledProviders: string[];
|
||||
embeddingModelDetails?: CloudEmbeddingModel[];
|
||||
embeddingProviderDetails?: EmbeddingDetails[];
|
||||
newEnabledProviders: string[];
|
||||
setShowTentativeProvider: React.Dispatch<
|
||||
React.SetStateAction<CloudEmbeddingProvider | null>
|
||||
>;
|
||||
setChangeCredentialsProvider: React.Dispatch<
|
||||
React.SetStateAction<CloudEmbeddingProvider | null>
|
||||
>;
|
||||
advancedEmbeddingDetails: AdvancedSearchConfiguration;
|
||||
}) {
|
||||
function hasProviderTypeinArray(
|
||||
arr: Array<{ provider_type: string }>,
|
||||
@ -60,27 +60,38 @@ export default function CloudEmbeddingPage({
|
||||
(model) => ({
|
||||
...model,
|
||||
configured:
|
||||
!newUnenabledProviders.includes(model.provider_type) &&
|
||||
(newEnabledProviders.includes(model.provider_type) ||
|
||||
(embeddingProviderDetails &&
|
||||
hasProviderTypeinArray(
|
||||
embeddingProviderDetails,
|
||||
model.provider_type
|
||||
))!),
|
||||
embeddingProviderDetails &&
|
||||
hasProviderTypeinArray(embeddingProviderDetails, model.provider_type),
|
||||
})
|
||||
);
|
||||
const [liteLLMProvider, setLiteLLMProvider] = useState<
|
||||
EmbeddingDetails | undefined
|
||||
>(undefined);
|
||||
|
||||
const [azureProvider, setAzureProvider] = useState<
|
||||
EmbeddingDetails | undefined
|
||||
>(undefined);
|
||||
|
||||
useEffect(() => {
|
||||
const foundProvider = embeddingProviderDetails?.find(
|
||||
const liteLLMProvider = embeddingProviderDetails?.find(
|
||||
(provider) =>
|
||||
provider.provider_type === EmbeddingProvider.LITELLM.toLowerCase()
|
||||
);
|
||||
setLiteLLMProvider(foundProvider);
|
||||
setLiteLLMProvider(liteLLMProvider);
|
||||
const azureProvider = embeddingProviderDetails?.find(
|
||||
(provider) =>
|
||||
provider.provider_type === EmbeddingProvider.AZURE.toLowerCase()
|
||||
);
|
||||
setAzureProvider(azureProvider);
|
||||
}, [embeddingProviderDetails]);
|
||||
|
||||
const isAzureConfigured = azureProvider !== undefined;
|
||||
|
||||
// Get details of the configured Azure provider
|
||||
const azureProviderDetails = embeddingProviderDetails?.find(
|
||||
(provider) => provider.provider_type.toLowerCase() === "azure"
|
||||
);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Title className="mt-8">
|
||||
@ -248,7 +259,8 @@ export default function CloudEmbeddingPage({
|
||||
: ""
|
||||
}`}
|
||||
>
|
||||
<LiteLLMModelForm
|
||||
<CustomEmbeddingModelForm
|
||||
embeddingType={EmbeddingProvider.LITELLM}
|
||||
provider={liteLLMProvider}
|
||||
currentValues={
|
||||
currentModel.provider_type === EmbeddingProvider.LITELLM
|
||||
@ -262,6 +274,126 @@ export default function CloudEmbeddingPage({
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Text className="mt-6">
|
||||
You can also use Azure OpenAI models for embeddings. Azure requires
|
||||
separate configuration for each model.
|
||||
</Text>
|
||||
|
||||
<div key={AZURE_CLOUD_PROVIDER.provider_type} className="mt-4 w-full">
|
||||
<div className="flex items-center mb-2">
|
||||
{AZURE_CLOUD_PROVIDER.icon({ size: 40 })}
|
||||
<h2 className="ml-2 mt-2 text-xl font-bold">
|
||||
{AZURE_CLOUD_PROVIDER.provider_type}{" "}
|
||||
</h2>
|
||||
<HoverPopup
|
||||
mainContent={
|
||||
<FiInfo className="ml-2 mt-2 cursor-pointer" size={18} />
|
||||
}
|
||||
popupContent={
|
||||
<div className="text-sm text-text-800 w-52">
|
||||
<div className="my-auto">
|
||||
{AZURE_CLOUD_PROVIDER.description}
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
style="dark"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="w-full flex flex-col items-start">
|
||||
{!isAzureConfigured ? (
|
||||
<>
|
||||
<button
|
||||
onClick={() => setShowTentativeProvider(AZURE_CLOUD_PROVIDER)}
|
||||
className="mb-2 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 text-sm cursor-pointer"
|
||||
>
|
||||
Configure Azure OpenAI
|
||||
</button>
|
||||
<div className="mt-2 w-full max-w-4xl">
|
||||
<Card className="p-4 border border-gray-200 rounded-lg shadow-sm">
|
||||
<Text className="text-base font-medium mb-2">
|
||||
Configure Azure OpenAI for Embeddings
|
||||
</Text>
|
||||
<Text className="text-sm text-gray-600 mb-3">
|
||||
Click "Configure Azure OpenAI" to set up Azure
|
||||
OpenAI for embeddings.
|
||||
</Text>
|
||||
<div className="flex items-center text-sm text-gray-700">
|
||||
<FiInfo className="text-gray-400 mr-2" size={16} />
|
||||
<Text>
|
||||
You'll need: API version, base URL, API key, model
|
||||
name, and deployment name.
|
||||
</Text>
|
||||
</div>
|
||||
</Card>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="mb-6 w-full">
|
||||
<Text className="text-lg font-semibold mb-3">
|
||||
Current Azure Configuration
|
||||
</Text>
|
||||
|
||||
{azureProviderDetails ? (
|
||||
<Card className="bg-white shadow-sm border border-gray-200 rounded-lg">
|
||||
<div className="p-4 space-y-3">
|
||||
<div className="flex justify-between">
|
||||
<span className="font-medium">API Version:</span>
|
||||
<span>{azureProviderDetails.api_version}</span>
|
||||
</div>
|
||||
<div className="flex justify-between">
|
||||
<span className="font-medium">Base URL:</span>
|
||||
<span>{azureProviderDetails.api_url}</span>
|
||||
</div>
|
||||
<div className="flex justify-between">
|
||||
<span className="font-medium">Deployment Name:</span>
|
||||
<span>{azureProviderDetails.deployment_name}</span>
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
onClick={() =>
|
||||
setChangeCredentialsProvider(AZURE_CLOUD_PROVIDER)
|
||||
}
|
||||
className="mt-2 px-4 py-2 bg-red-500 text-white rounded hover:bg-red-600 text-sm"
|
||||
>
|
||||
Delete Current Azure Provider
|
||||
</button>
|
||||
</Card>
|
||||
) : (
|
||||
<Card className="bg-gray-50 border border-gray-200 rounded-lg">
|
||||
<div className="p-4 text-gray-500 text-center">
|
||||
No Azure provider has been configured yet.
|
||||
</div>
|
||||
</Card>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Card
|
||||
className={`mt-2 w-full max-w-4xl ${
|
||||
currentModel.provider_type === EmbeddingProvider.AZURE
|
||||
? "border-2 border-blue-500"
|
||||
: ""
|
||||
}`}
|
||||
>
|
||||
{azureProvider && (
|
||||
<CustomEmbeddingModelForm
|
||||
embeddingType={EmbeddingProvider.AZURE}
|
||||
provider={azureProvider}
|
||||
currentValues={
|
||||
currentModel.provider_type === EmbeddingProvider.AZURE
|
||||
? (currentModel as CloudEmbeddingModel)
|
||||
: null
|
||||
}
|
||||
setShowTentativeModel={setShowTentativeModel}
|
||||
/>
|
||||
)}
|
||||
</Card>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
@ -152,11 +152,6 @@ export default function EmbeddingForm() {
|
||||
}
|
||||
}, [currentEmbeddingModel]);
|
||||
|
||||
useEffect(() => {
|
||||
if (currentEmbeddingModel) {
|
||||
setSelectedProvider(currentEmbeddingModel);
|
||||
}
|
||||
}, [currentEmbeddingModel]);
|
||||
if (!selectedProvider) {
|
||||
return <ThreeDotsLoader />;
|
||||
}
|
||||
@ -164,10 +159,18 @@ export default function EmbeddingForm() {
|
||||
return <ErrorCallout errorTitle="Failed to fetch embedding model status" />;
|
||||
}
|
||||
|
||||
const updateCurrentModel = (newModel: string) => {
|
||||
setAdvancedEmbeddingDetails((values) => ({
|
||||
...values,
|
||||
model_name: newModel,
|
||||
}));
|
||||
};
|
||||
|
||||
const updateSearch = async () => {
|
||||
const values: SavedSearchSettings = {
|
||||
...rerankingDetails,
|
||||
...advancedEmbeddingDetails,
|
||||
...selectedProvider,
|
||||
provider_type:
|
||||
selectedProvider.provider_type?.toLowerCase() as EmbeddingProvider | null,
|
||||
};
|
||||
@ -311,11 +314,13 @@ export default function EmbeddingForm() {
|
||||
</Text>
|
||||
<Card>
|
||||
<EmbeddingModelSelection
|
||||
updateCurrentModel={updateCurrentModel}
|
||||
setModelTab={setModelTab}
|
||||
modelTab={modelTab}
|
||||
selectedProvider={selectedProvider}
|
||||
currentEmbeddingModel={currentEmbeddingModel}
|
||||
updateSelectedProvider={updateSelectedProvider}
|
||||
advancedEmbeddingDetails={advancedEmbeddingDetails}
|
||||
/>
|
||||
</Card>
|
||||
<div className="mt-4 flex w-full justify-end">
|
||||
|
@ -27,7 +27,7 @@ export async function Layout({ children }: { children: React.ReactNode }) {
|
||||
|
||||
const authTypeMetadata = results[0] as AuthTypeMetadata | null;
|
||||
const user = results[1] as User | null;
|
||||
|
||||
console.log("authTypeMetadata", authTypeMetadata);
|
||||
const authDisabled = authTypeMetadata?.authType === "disabled";
|
||||
const requiresVerification = authTypeMetadata?.requiresVerification;
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
import { CloudEmbeddingModel, CloudEmbeddingProvider } from "./interfaces";
|
||||
import { CloudEmbeddingModel, EmbeddingProvider } from "./interfaces";
|
||||
import { Formik, Form } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import { TextFormField, BooleanFormField } from "../admin/connectors/Field";
|
||||
@ -6,14 +6,16 @@ import { Dispatch, SetStateAction } from "react";
|
||||
import { Button, Text } from "@tremor/react";
|
||||
import { EmbeddingDetails } from "@/app/admin/embeddings/EmbeddingModelSelectionForm";
|
||||
|
||||
export function LiteLLMModelForm({
|
||||
export function CustomEmbeddingModelForm({
|
||||
setShowTentativeModel,
|
||||
currentValues,
|
||||
provider,
|
||||
embeddingType,
|
||||
}: {
|
||||
setShowTentativeModel: Dispatch<SetStateAction<CloudEmbeddingModel | null>>;
|
||||
currentValues: CloudEmbeddingModel | null;
|
||||
provider: EmbeddingDetails;
|
||||
embeddingType: EmbeddingProvider;
|
||||
}) {
|
||||
return (
|
||||
<div>
|
||||
@ -25,7 +27,7 @@ export function LiteLLMModelForm({
|
||||
normalize: false,
|
||||
query_prefix: "",
|
||||
passage_prefix: "",
|
||||
provider_type: "LiteLLM",
|
||||
provider_type: embeddingType,
|
||||
api_key: "",
|
||||
enabled: true,
|
||||
api_url: provider.api_url,
|
||||
@ -55,18 +57,21 @@ export function LiteLLMModelForm({
|
||||
max_tokens: Yup.number(),
|
||||
})}
|
||||
onSubmit={async (values) => {
|
||||
console.log(values);
|
||||
setShowTentativeModel(values as CloudEmbeddingModel);
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting }) => (
|
||||
{({ isSubmitting, submitForm, errors }) => (
|
||||
<Form>
|
||||
<Text className="text-xl text-text-900 font-bold mb-4">
|
||||
Add a new model to LiteLLM proxy at {provider.api_url}
|
||||
Specify details for your{" "}
|
||||
{embeddingType === EmbeddingProvider.AZURE ? "Azure" : "LiteLLM"}{" "}
|
||||
Provider's model
|
||||
</Text>
|
||||
<TextFormField
|
||||
name="model_name"
|
||||
label="Model Name:"
|
||||
subtext="The name of the LiteLLM model"
|
||||
subtext={`The name of the ${embeddingType === EmbeddingProvider.AZURE ? "Azure" : "LiteLLM"} model`}
|
||||
placeholder="e.g. 'all-MiniLM-L6-v2'"
|
||||
autoCompleteDisabled={true}
|
||||
/>
|
||||
@ -103,10 +108,13 @@ export function LiteLLMModelForm({
|
||||
|
||||
<Button
|
||||
type="submit"
|
||||
onClick={() => console.log(errors)}
|
||||
disabled={isSubmitting}
|
||||
className="w-64 mx-auto"
|
||||
>
|
||||
Configure LiteLLM Model
|
||||
Configure{" "}
|
||||
{embeddingType === EmbeddingProvider.AZURE ? "Azure" : "LiteLLM"}{" "}
|
||||
Model
|
||||
</Button>
|
||||
</Form>
|
||||
)}
|
@ -31,7 +31,7 @@ export default function EmbeddingSidebar() {
|
||||
w-[250px]
|
||||
`}
|
||||
>
|
||||
<div className="fixed h-full left-0 top-0 w-[250px]">
|
||||
<div className="fixed h-full left-0 top-0 bg-background-100 w-[250px]">
|
||||
<div className="ml-4 mr-3 flex flex gap-x-1 items-center mt-2 my-auto text-text-700 text-xl">
|
||||
<div className="mr-1 my-auto h-6 w-6">
|
||||
<Logo height={24} width={24} />
|
||||
|
@ -1,4 +1,5 @@
|
||||
import {
|
||||
AzureIcon,
|
||||
CohereIcon,
|
||||
GoogleIcon,
|
||||
IconProps,
|
||||
@ -16,6 +17,7 @@ export enum EmbeddingProvider {
|
||||
VOYAGE = "Voyage",
|
||||
GOOGLE = "Google",
|
||||
LITELLM = "LiteLLM",
|
||||
AZURE = "Azure",
|
||||
}
|
||||
|
||||
export interface CloudEmbeddingProvider {
|
||||
@ -49,6 +51,8 @@ export interface EmbeddingModelDescriptor {
|
||||
description: string;
|
||||
api_key: string | null;
|
||||
api_url: string | null;
|
||||
api_version?: string | null;
|
||||
deployment_name?: string | null;
|
||||
index_name: string | null;
|
||||
}
|
||||
|
||||
@ -161,6 +165,20 @@ export const LITELLM_CLOUD_PROVIDER: CloudEmbeddingProvider = {
|
||||
embedding_models: [], // No default embedding models
|
||||
};
|
||||
|
||||
export const AZURE_CLOUD_PROVIDER: CloudEmbeddingProvider = {
|
||||
provider_type: EmbeddingProvider.AZURE,
|
||||
website:
|
||||
"https://azure.microsoft.com/en-us/products/cognitive-services/openai/",
|
||||
icon: AzureIcon,
|
||||
description:
|
||||
"Azure OpenAI is a cloud-based AI service that provides access to OpenAI models.",
|
||||
apiLink:
|
||||
"https://docs.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource",
|
||||
costslink:
|
||||
"https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai/",
|
||||
embedding_models: [], // No default embedding models
|
||||
};
|
||||
|
||||
export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
|
||||
{
|
||||
provider_type: EmbeddingProvider.COHERE,
|
||||
|
Loading…
x
Reference in New Issue
Block a user