Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-03-17 21:32:36 +01:00)

Commit f5b3333df3: Add UI-based LLM selection
Parent: 4c740060aa
New file (Alembic migration):

@@ -0,0 +1,49 @@
+"""Add tables for UI-based LLM configuration
+
+Revision ID: 401c1ac29467
+Revises: 703313b75876
+Create Date: 2024-04-13 18:07:29.153817
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = "401c1ac29467"
+down_revision = "703313b75876"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        "llm_provider",
+        sa.Column("id", sa.Integer(), nullable=False),
+        sa.Column("name", sa.String(), nullable=False),
+        sa.Column("api_key", sa.String(), nullable=True),
+        sa.Column("api_base", sa.String(), nullable=True),
+        sa.Column("api_version", sa.String(), nullable=True),
+        sa.Column(
+            "custom_config",
+            postgresql.JSONB(astext_type=sa.Text()),
+            nullable=True,
+        ),
+        sa.Column("default_model_name", sa.String(), nullable=False),
+        sa.Column("fast_default_model_name", sa.String(), nullable=True),
+        sa.Column("is_default_provider", sa.Boolean(), unique=True, nullable=True),
+        sa.Column("model_names", postgresql.ARRAY(sa.String()), nullable=True),
+        sa.PrimaryKeyConstraint("id"),
+        sa.UniqueConstraint("name"),
+    )
+
+    op.add_column(
+        "persona",
+        sa.Column("llm_model_provider_override", sa.String(), nullable=True),
+    )
+
+
+def downgrade() -> None:
+    op.drop_column("persona", "llm_model_provider_override")
+
+    op.drop_table("llm_provider")
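For reference, a revision like this can be driven from Python as well as the Alembic CLI. A minimal sketch, assuming an `alembic.ini` at the backend root (the config path is an assumption, not part of this commit):

# Hypothetical sketch: apply or revert the revision above programmatically.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")  # assumed location of the Alembic config
command.upgrade(cfg, "401c1ac29467")      # runs upgrade() above
# command.downgrade(cfg, "703313b75876")  # reverts to the previous revision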
@@ -88,6 +88,7 @@ def load_personas_from_yaml(
             llm_relevance_filter=persona.get("llm_relevance_filter"),
             starter_messages=persona.get("starter_messages"),
             llm_filter_extraction=persona.get("llm_filter_extraction"),
+            llm_model_provider_override=None,
             llm_model_version_override=None,
             recency_bias=RecencyBiasSetting(persona["recency_bias"]),
             prompts=cast(list[PromptDBModel] | None, prompts),
@@ -34,11 +34,10 @@ from danswer.llm.answering.answer import Answer
 from danswer.llm.answering.models import AnswerStyleConfig
 from danswer.llm.answering.models import CitationConfig
 from danswer.llm.answering.models import DocumentPruningConfig
-from danswer.llm.answering.models import LLMConfig
 from danswer.llm.answering.models import PreviousMessage
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_llm_for_persona
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.search.models import OptionalSearchSetting
 from danswer.search.models import SearchRequest
@@ -135,8 +134,8 @@ def stream_chat_message_objects(
         )

     try:
-        llm = get_default_llm(
-            gen_ai_model_version_override=persona.llm_model_version_override
+        llm = get_llm_for_persona(
+            persona, new_msg_req.llm_override or chat_session.llm_override
         )
     except GenAIDisabledException:
         llm = None
@@ -373,9 +372,11 @@ def stream_chat_message_objects(
                     new_msg_req.prompt_override or chat_session.prompt_override
                 ),
             ),
-            llm_config=LLMConfig.from_persona(
-                persona,
-                llm_override=(new_msg_req.llm_override or chat_session.llm_override),
+            llm=(
+                llm
+                or get_llm_for_persona(
+                    persona, new_msg_req.llm_override or chat_session.llm_override
+                )
             ),
             doc_relevance_list=llm_relevance_list,
             message_history=[
@@ -24,6 +24,7 @@ MATCH_HIGHLIGHTS = "match_highlights"
 # not be used for QA. For example, Google Drive file types which can't be parsed
 # are still useful as a search result but not for QA.
 IGNORE_FOR_QA = "ignore_for_qa"
+# NOTE: deprecated, only used for porting key from old system
 GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
 PUBLIC_DOC_PAT = "PUBLIC"
 PUBLIC_DOCUMENT_SET = "__PUBLIC"
@@ -52,10 +53,6 @@ SECTION_SEPARATOR = "\n\n"
 INDEX_SEPARATOR = "==="


-# Key-Value store constants
-GEN_AI_DETECTED_MODEL = "gen_ai_detected_model"
-
-
 # Messages
 DISABLED_GEN_AI_MSG = (
     "Your System Admin has disabled the Generative AI functionalities of Danswer.\n"
@@ -36,13 +36,15 @@ from danswer.danswerbot.slack.utils import slack_usage_report
 from danswer.danswerbot.slack.utils import SlackRateLimiter
 from danswer.danswerbot.slack.utils import update_emote_react
 from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.models import Persona
 from danswer.db.models import SlackBotConfig
 from danswer.db.models import SlackBotResponseType
+from danswer.db.persona import fetch_persona_by_id
 from danswer.llm.answering.prompts.citations_prompt import (
     compute_max_document_tokens_for_persona,
 )
+from danswer.llm.factory import get_llm_for_persona
 from danswer.llm.utils import check_number_of_tokens
-from danswer.llm.utils import get_default_llm_version
 from danswer.llm.utils import get_max_input_tokens
 from danswer.one_shot_answer.answer_question import get_search_answer
 from danswer.one_shot_answer.models import DirectQARequest
@@ -237,32 +239,38 @@ def handle_message(

     max_document_tokens: int | None = None
     max_history_tokens: int | None = None
-    if len(new_message_request.messages) > 1:
-        llm_name = get_default_llm_version()[0]
-        if persona and persona.llm_model_version_override:
-            llm_name = persona.llm_model_version_override
-
-        # In cases of threads, split the available tokens between docs and thread context
-        input_tokens = get_max_input_tokens(model_name=llm_name)
-        max_history_tokens = int(input_tokens * thread_context_percent)
-
-        remaining_tokens = input_tokens - max_history_tokens
-
-        query_text = new_message_request.messages[0].message
-        if persona:
-            max_document_tokens = compute_max_document_tokens_for_persona(
-                persona=persona,
-                actual_user_input=query_text,
-                max_llm_token_override=remaining_tokens,
-            )
-        else:
-            max_document_tokens = (
-                remaining_tokens
-                - 512  # Needs to be more than any of the QA prompts
-                - check_number_of_tokens(query_text)
-            )
-
     with Session(get_sqlalchemy_engine()) as db_session:
+        if len(new_message_request.messages) > 1:
+            persona = cast(
+                Persona,
+                fetch_persona_by_id(db_session, new_message_request.persona_id),
+            )
+            llm = get_llm_for_persona(persona)
+
+            # In cases of threads, split the available tokens between docs and thread context
+            input_tokens = get_max_input_tokens(
+                model_name=llm.config.model_name,
+                model_provider=llm.config.model_provider,
+            )
+            max_history_tokens = int(input_tokens * thread_context_percent)
+
+            remaining_tokens = input_tokens - max_history_tokens
+
+            query_text = new_message_request.messages[0].message
+            if persona:
+                max_document_tokens = compute_max_document_tokens_for_persona(
+                    persona=persona,
+                    actual_user_input=query_text,
+                    max_llm_token_override=remaining_tokens,
+                )
+            else:
+                max_document_tokens = (
+                    remaining_tokens
+                    - 512  # Needs to be more than any of the QA prompts
+                    - check_number_of_tokens(query_text)
+                )
+
         # This also handles creating the query event in postgres
         answer = get_search_answer(
             query_req=new_message_request,
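The hunk above splits the model's input window between thread history and document context. A worked example with illustrative numbers (the window size, thread share, and query length are assumptions, not values from the codebase):

# Illustrative arithmetic only; 4096 tokens and a 50% thread share are assumed.
input_tokens = 4096
thread_context_percent = 0.5

max_history_tokens = int(input_tokens * thread_context_percent)  # 2048 for the thread
remaining_tokens = input_tokens - max_history_tokens             # 2048 left over

query_tokens = 30  # assumed token count of the user question
# no-persona branch: reserve 512 tokens for the QA prompt plus the question itself
max_document_tokens = remaining_tokens - 512 - query_tokens      # 1506 for documents
print(max_history_tokens, max_document_tokens)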
@@ -498,6 +498,7 @@ def upsert_persona(
     recency_bias: RecencyBiasSetting,
     prompts: list[Prompt] | None,
     document_sets: list[DBDocumentSet] | None,
+    llm_model_provider_override: str | None,
     llm_model_version_override: str | None,
     starter_messages: list[StarterMessage] | None,
     is_public: bool,
@@ -524,6 +525,7 @@ def upsert_persona(
         persona.llm_filter_extraction = llm_filter_extraction
         persona.recency_bias = recency_bias
         persona.default_persona = default_persona
+        persona.llm_model_provider_override = llm_model_provider_override
         persona.llm_model_version_override = llm_model_version_override
         persona.starter_messages = starter_messages
         persona.deleted = False  # Un-delete if previously deleted
@@ -553,6 +555,7 @@ def upsert_persona(
             default_persona=default_persona,
             prompts=prompts or [],
             document_sets=document_sets or [],
+            llm_model_provider_override=llm_model_provider_override,
             llm_model_version_override=llm_model_version_override,
             starter_messages=starter_messages,
         )
@@ -71,7 +71,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
     return _ASYNC_ENGINE


-def get_session_context_manager() -> ContextManager:
+def get_session_context_manager() -> ContextManager[Session]:
    return contextlib.contextmanager(get_session)()
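Parametrizing the return type as `ContextManager[Session]` lets type checkers infer the session type at the call site instead of falling back to `Any`. A small sketch, assuming a configured database engine:

# Sketch: `db_session` is now typed as sqlalchemy.orm.Session, not Any.
from sqlalchemy import text

from danswer.db.engine import get_session_context_manager

with get_session_context_manager() as db_session:
    db_session.execute(text("SELECT 1"))  # attribute access is type-checked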
backend/danswer/db/llm.py (new file, 96 lines)

@@ -0,0 +1,96 @@
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest


def upsert_llm_provider(
    db_session: Session, llm_provider: LLMProviderUpsertRequest
) -> FullLLMProvider:
    existing_llm_provider = db_session.scalar(
        select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
    )
    if existing_llm_provider:
        existing_llm_provider.api_key = llm_provider.api_key
        existing_llm_provider.api_base = llm_provider.api_base
        existing_llm_provider.api_version = llm_provider.api_version
        existing_llm_provider.custom_config = llm_provider.custom_config
        existing_llm_provider.default_model_name = llm_provider.default_model_name
        existing_llm_provider.fast_default_model_name = (
            llm_provider.fast_default_model_name
        )
        existing_llm_provider.model_names = llm_provider.model_names
        db_session.commit()
        return FullLLMProvider.from_model(existing_llm_provider)

    # if it does not exist, create a new entry
    llm_provider_model = LLMProviderModel(
        name=llm_provider.name,
        api_key=llm_provider.api_key,
        api_base=llm_provider.api_base,
        api_version=llm_provider.api_version,
        custom_config=llm_provider.custom_config,
        default_model_name=llm_provider.default_model_name,
        fast_default_model_name=llm_provider.fast_default_model_name,
        model_names=llm_provider.model_names,
        is_default_provider=None,
    )
    db_session.add(llm_provider_model)
    db_session.commit()

    return FullLLMProvider.from_model(llm_provider_model)


def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
    return list(db_session.scalars(select(LLMProviderModel)).all())


def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
    provider_model = db_session.scalar(
        select(LLMProviderModel).where(
            LLMProviderModel.is_default_provider == True  # noqa: E712
        )
    )
    if not provider_model:
        return None
    return FullLLMProvider.from_model(provider_model)


def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
    provider_model = db_session.scalar(
        select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
    )
    if not provider_model:
        return None
    return FullLLMProvider.from_model(provider_model)


def remove_llm_provider(db_session: Session, provider_id: int) -> None:
    db_session.execute(
        delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
    )
    db_session.commit()


def update_default_provider(db_session: Session, provider_id: int) -> None:
    new_default = db_session.scalar(
        select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
    )
    if not new_default:
        raise ValueError(f"LLM Provider with id {provider_id} does not exist")

    existing_default = db_session.scalar(
        select(LLMProviderModel).where(
            LLMProviderModel.is_default_provider == True  # noqa: E712
        )
    )
    if existing_default:
        existing_default.is_default_provider = None
        # required to ensure that the below does not cause a unique constraint violation
        db_session.flush()

    new_default.is_default_provider = True
    db_session.commit()
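A sketch of how these helpers compose when registering a provider from code; the field values are placeholders, not recommendations:

# Sketch: create (or update) a provider row, then mark it as the default.
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest

with get_session_context_manager() as db_session:
    provider = upsert_llm_provider(
        db_session,
        LLMProviderUpsertRequest(
            name="openai",  # placeholder values throughout
            api_key="sk-...",
            api_base=None,
            api_version=None,
            custom_config=None,
            default_model_name="gpt-4",
            fast_default_model_name="gpt-3.5-turbo",
            model_names=None,
        ),
    )
    # demotes any existing default first, flushing to dodge the unique constraint
    update_default_provider(db_session, provider.id)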
@@ -8,8 +8,8 @@ from typing import Optional
 from typing import TypedDict
 from uuid import UUID

-from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
-from fastapi_users.db import SQLAlchemyBaseUserTableUUID
+from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID
+from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID
 from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
 from sqlalchemy import Boolean
 from sqlalchemy import DateTime
@@ -695,6 +695,33 @@ Structures, Organizational, Configurations Tables
 """


+class LLMProvider(Base):
+    __tablename__ = "llm_provider"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True)
+    name: Mapped[str] = mapped_column(String, unique=True)
+    api_key: Mapped[str | None] = mapped_column(String, nullable=True)
+    api_base: Mapped[str | None] = mapped_column(String, nullable=True)
+    api_version: Mapped[str | None] = mapped_column(String, nullable=True)
+    # custom configs that should be passed to the LLM provider at inference time
+    # (e.g. `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, etc. for bedrock)
+    custom_config: Mapped[dict[str, str] | None] = mapped_column(
+        postgresql.JSONB(), nullable=True
+    )
+    default_model_name: Mapped[str] = mapped_column(String)
+    fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
+
+    # The LLMs that are available for this provider. Only required if not a default provider.
+    # If a default provider, then the LLM options are pulled from the `options.py` file.
+    # If needed, can be pulled out as a separate table in the future.
+    model_names: Mapped[list[str] | None] = mapped_column(
+        postgresql.ARRAY(String), nullable=True
+    )
+
+    # should only be set for a single provider
+    is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
+
+
 class DocumentSet(Base):
     __tablename__ = "document_set"

@@ -792,6 +819,9 @@ class Persona(Base):
     # globally via env variables. For flexibility, validity is not currently enforced
     # NOTE: only is applied on the actual response generation - is not used for things like
     # auto-detected time filters, relevance filters, etc.
+    llm_model_provider_override: Mapped[str | None] = mapped_column(
+        String, nullable=True
+    )
     llm_model_version_override: Mapped[str | None] = mapped_column(
         String, nullable=True
     )
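The `is_default_provider` column encodes "at most one default" purely in the schema: Postgres unique constraints treat NULLs as distinct, so any number of non-default (NULL) rows coexist while a second True row is rejected. This is also why `update_default_provider` in backend/danswer/db/llm.py above demotes the old default to None rather than False and flushes before promoting the new one. A condensed sketch of the same invariant:

# Sketch: mirrors update_default_provider in backend/danswer/db/llm.py.
from sqlalchemy import update
from sqlalchemy.orm import Session

from danswer.db.models import LLMProvider

def set_sole_default(db_session: Session, provider_id: int) -> None:
    # demote the current default to NULL; two non-NULL equal values (two Trues,
    # or two Falses if False were ever used) would violate the unique constraint
    db_session.execute(
        update(LLMProvider)
        .where(LLMProvider.is_default_provider == True)  # noqa: E712
        .values(is_default_provider=None)
    )
    db_session.flush()  # make the demotion visible before promoting the new row
    db_session.execute(
        update(LLMProvider)
        .where(LLMProvider.id == provider_id)
        .values(is_default_provider=True)
    )
    db_session.commit()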
@@ -1,11 +1,13 @@
 from uuid import UUID

 from fastapi import HTTPException
+from sqlalchemy import select
 from sqlalchemy.orm import Session

 from danswer.db.chat import get_prompts_by_ids
 from danswer.db.chat import upsert_persona
 from danswer.db.document_set import get_document_sets_by_ids
+from danswer.db.models import Persona
 from danswer.db.models import Persona__User
 from danswer.db.models import User
 from danswer.server.features.persona.models import CreatePersonaRequest
@@ -69,6 +71,7 @@ def create_update_persona(
         recency_bias=create_persona_request.recency_bias,
         prompts=prompts,
         document_sets=document_sets,
+        llm_model_provider_override=create_persona_request.llm_model_provider_override,
         llm_model_version_override=create_persona_request.llm_model_version_override,
         starter_messages=create_persona_request.starter_messages,
         is_public=create_persona_request.is_public,
@@ -91,3 +94,7 @@ def create_update_persona(
         logger.exception("Failed to create persona")
         raise HTTPException(status_code=400, detail=str(e))
     return PersonaSnapshot.from_model(persona)
+
+
+def fetch_persona_by_id(db_session: Session, persona_id: int) -> Persona | None:
+    return db_session.scalar(select(Persona).where(Persona.id == persona_id))
@@ -59,6 +59,7 @@ def create_slack_bot_persona(
         recency_bias=RecencyBiasSetting.AUTO,
         prompts=None,
         document_sets=document_sets,
+        llm_model_provider_override=None,
         llm_model_version_override=None,
         starter_messages=None,
         is_public=True,
@@ -1,9 +1,27 @@
 import json
 from pathlib import Path
+from typing import cast

 from danswer.configs.app_configs import DYNAMIC_CONFIG_DIR_PATH
+from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
+from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
+from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
+from danswer.configs.model_configs import GEN_AI_API_KEY
+from danswer.configs.model_configs import GEN_AI_API_VERSION
+from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.db.engine import get_session_context_manager
+from danswer.db.llm import fetch_existing_llm_providers
+from danswer.db.llm import update_default_provider
+from danswer.db.llm import upsert_llm_provider
 from danswer.dynamic_configs.factory import get_dynamic_config_store
 from danswer.dynamic_configs.factory import PostgresBackedDynamicConfigStore
+from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.server.manage.llm.models import LLMProviderUpsertRequest
+from danswer.utils.logger import setup_logger
+
+
+logger = setup_logger()


 def read_file_system_store(directory_path: str) -> dict:
@@ -38,3 +56,60 @@ def insert_into_postgres(store_data: dict) -> None:
 def port_filesystem_to_postgres(directory_path: str = DYNAMIC_CONFIG_DIR_PATH) -> None:
     store_data = read_file_system_store(directory_path)
     insert_into_postgres(store_data)
+
+
+def port_api_key_to_postgres() -> None:
+    # can't port over custom, no longer supported
+    if GEN_AI_MODEL_PROVIDER == "custom":
+        return
+
+    with get_session_context_manager() as db_session:
+        # if we already have ported things over / setup providers in the db, don't do anything
+        if len(fetch_existing_llm_providers(db_session)) > 0:
+            return
+
+        api_key = GEN_AI_API_KEY
+        try:
+            api_key = cast(
+                str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY)
+            )
+        except ConfigNotFoundError:
+            pass
+
+        # if no API key set, don't port anything over
+        if not api_key:
+            return
+
+        default_model_name = GEN_AI_MODEL_VERSION
+        if GEN_AI_MODEL_PROVIDER == "openai" and not default_model_name:
+            default_model_name = "gpt-4"
+
+        # if no default model name found, don't port anything over
+        if not default_model_name:
+            return
+
+        default_fast_model_name = FAST_GEN_AI_MODEL_VERSION
+        if GEN_AI_MODEL_PROVIDER == "openai" and not default_fast_model_name:
+            default_fast_model_name = "gpt-3.5-turbo"
+
+        llm_provider_upsert = LLMProviderUpsertRequest(
+            name=GEN_AI_MODEL_PROVIDER,
+            api_key=api_key,
+            api_base=GEN_AI_API_ENDPOINT,
+            api_version=GEN_AI_API_VERSION,
+            # can't port over any custom configs, since we don't know
+            # all the possible keys and values that could be in there
+            custom_config=None,
+            default_model_name=default_model_name,
+            fast_default_model_name=default_fast_model_name,
+            model_names=None,
+        )
+        llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
+        update_default_provider(db_session, llm_provider.id)
+        logger.info(f"Ported over LLM provider:\n\n{llm_provider}")
+
+        # delete the old API key
+        try:
+            get_dynamic_config_store().delete(GEN_AI_API_KEY_STORAGE_KEY)
+        except ConfigNotFoundError:
+            pass
@@ -12,7 +12,6 @@ from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
 from danswer.configs.chat_configs import QA_TIMEOUT
 from danswer.llm.answering.doc_pruning import prune_documents
 from danswer.llm.answering.models import AnswerStyleConfig
-from danswer.llm.answering.models import LLMConfig
 from danswer.llm.answering.models import PreviousMessage
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.answering.models import StreamProcessor
@@ -26,7 +25,7 @@ from danswer.llm.answering.stream_processing.citation_processing import (
 from danswer.llm.answering.stream_processing.quotes_processing import (
     build_quotes_processor,
 )
-from danswer.llm.factory import get_default_llm
+from danswer.llm.interfaces import LLM
 from danswer.llm.utils import get_default_llm_tokenizer


@@ -53,7 +52,7 @@ class Answer:
         question: str,
         docs: list[LlmDoc],
         answer_style_config: AnswerStyleConfig,
-        llm_config: LLMConfig,
+        llm: LLM,
         prompt_config: PromptConfig,
         # must be the same length as `docs`. If None, all docs are considered "relevant"
         doc_relevance_list: list[bool] | None = None,
@@ -74,15 +73,9 @@ class Answer:
         self.single_message_history = single_message_history

         self.answer_style_config = answer_style_config
-        self.llm_config = llm_config
         self.prompt_config = prompt_config

-        self.llm = get_default_llm(
-            gen_ai_model_provider=self.llm_config.model_provider,
-            gen_ai_model_version_override=self.llm_config.model_version,
-            timeout=timeout,
-            temperature=self.llm_config.temperature,
-        )
+        self.llm = llm
         self.llm_tokenizer = get_default_llm_tokenizer()

         self._final_prompt: list[BaseMessage] | None = None
@@ -101,7 +94,7 @@ class Answer:
             docs=self.docs,
             doc_relevance_list=self.doc_relevance_list,
             prompt_config=self.prompt_config,
-            llm_config=self.llm_config,
+            llm_config=self.llm.config,
             question=self.question,
             document_pruning_config=self.answer_style_config.document_pruning_config,
         )
@@ -116,7 +109,7 @@ class Answer:
         self._final_prompt = build_citations_prompt(
             question=self.question,
             message_history=self.message_history,
-            llm_config=self.llm_config,
+            llm_config=self.llm.config,
             prompt_config=self.prompt_config,
             context_docs=self.pruned_docs,
             all_doc_useful=self.answer_style_config.citation_config.all_docs_useful,
@@ -7,9 +7,9 @@ from danswer.chat.models import (
 from danswer.configs.constants import IGNORE_FOR_QA
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.llm.answering.models import DocumentPruningConfig
-from danswer.llm.answering.models import LLMConfig
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
+from danswer.llm.interfaces import LLMConfig
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.llm.utils import tokenizer_trim_content
 from danswer.prompts.prompt_utils import build_doc_context_str
@@ -9,15 +9,11 @@ from pydantic import root_validator

 from danswer.chat.models import AnswerQuestionStreamReturn
 from danswer.configs.constants import MessageType
-from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.llm.override_models import LLMOverride
 from danswer.llm.override_models import PromptOverride
-from danswer.llm.utils import get_default_llm_version

 if TYPE_CHECKING:
     from danswer.db.models import ChatMessage
     from danswer.db.models import Prompt
-    from danswer.db.models import Persona


 StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
@@ -86,36 +82,6 @@ class AnswerStyleConfig(BaseModel):
         return values


-class LLMConfig(BaseModel):
-    """Final representation of the LLM configuration passed into
-    the `Answer` object."""
-
-    model_provider: str
-    model_version: str
-    temperature: float
-
-    @classmethod
-    def from_persona(
-        cls, persona: "Persona", llm_override: LLMOverride | None = None
-    ) -> "LLMConfig":
-        model_provider_override = llm_override.model_provider if llm_override else None
-        model_version_override = llm_override.model_version if llm_override else None
-        temperature_override = llm_override.temperature if llm_override else None
-
-        return cls(
-            model_provider=model_provider_override or GEN_AI_MODEL_PROVIDER,
-            model_version=(
-                model_version_override
-                or persona.llm_model_version_override
-                or get_default_llm_version()[0]
-            ),
-            temperature=temperature_override or 0.0,
-        )
-
-    class Config:
-        frozen = True
-
-
 class PromptConfig(BaseModel):
     """Final representation of the Prompt configuration passed
     into the `Answer` object."""
@@ -11,9 +11,10 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
 from danswer.db.chat import get_default_prompt
 from danswer.db.models import Persona
-from danswer.llm.answering.models import LLMConfig
 from danswer.llm.answering.models import PreviousMessage
 from danswer.llm.answering.models import PromptConfig
+from danswer.llm.factory import get_llm_for_persona
+from danswer.llm.interfaces import LLMConfig
 from danswer.llm.utils import check_number_of_tokens
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.llm.utils import get_max_input_tokens
@@ -131,7 +132,9 @@ def compute_max_document_tokens(
     max_input_tokens = (
         max_llm_token_override
         if max_llm_token_override
-        else get_max_input_tokens(model_name=llm_config.model_version)
+        else get_max_input_tokens(
+            model_name=llm_config.model_name, model_provider=llm_config.model_provider
+        )
     )
     prompt_tokens = get_prompt_tokens(prompt_config)

@@ -152,7 +155,7 @@ def compute_max_document_tokens_for_persona(
     prompt = persona.prompts[0] if persona.prompts else get_default_prompt()
     return compute_max_document_tokens(
         prompt_config=PromptConfig.from_model(prompt),
-        llm_config=LLMConfig.from_persona(persona),
+        llm_config=get_llm_for_persona(persona).config,
         actual_user_input=actual_user_input,
         max_llm_token_override=max_llm_token_override,
     )
@@ -162,7 +165,7 @@ def compute_max_llm_input_tokens(llm_config: LLMConfig) -> int:
     """Maximum tokens allowed in the input to the LLM (of any type)."""

     input_tokens = get_max_input_tokens(
-        model_name=llm_config.model_version, model_provider=llm_config.model_provider
+        model_name=llm_config.model_name, model_provider=llm_config.model_provider
     )
     return input_tokens - _MISC_BUFFER
@@ -1,4 +1,5 @@
 import abc
+import os
 from collections.abc import Iterator

 import litellm  # type:ignore
@@ -12,11 +13,9 @@ from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_API_VERSION
 from danswer.configs.model_configs import GEN_AI_LLM_PROVIDER_TYPE
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
-from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.llm.interfaces import LLM
-from danswer.llm.utils import get_default_llm_version
+from danswer.llm.interfaces import LLMConfig
 from danswer.llm.utils import message_generator_to_string_generator
 from danswer.llm.utils import should_be_verbose
 from danswer.utils.logger import setup_logger
@@ -69,6 +68,7 @@ class LangChainChatLLM(LLM, abc.ABC):

     def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
         if LOG_ALL_MODEL_INTERACTIONS:
+            self.log_model_configs()
             self._log_prompt(prompt)

         if DISABLE_LITELLM_STREAMING:
@@ -85,23 +85,6 @@ class LangChainChatLLM(LLM, abc.ABC):
         logger.debug(f"Raw Model Output:\n{full_output}")


-def _get_model_str(
-    model_provider: str | None,
-    model_version: str | None,
-) -> str:
-    if model_provider and model_version:
-        return model_provider + "/" + model_version
-
-    if model_version:
-        # Litellm defaults to openai if no provider specified
-        # It's implicit so no need to specify here either
-        return model_version
-
-    # User specified something wrong, just use Danswer default
-    base, _ = get_default_llm_version()
-    return base
-
-
 class DefaultMultiLLM(LangChainChatLLM):
     """Uses Litellm library to allow easy configuration to use a multitude of LLMs
     See https://python.langchain.com/docs/integrations/chat/litellm"""
@@ -115,25 +98,35 @@ class DefaultMultiLLM(LangChainChatLLM):
         self,
         api_key: str | None,
         timeout: int,
-        model_provider: str = GEN_AI_MODEL_PROVIDER,
-        model_version: str | None = GEN_AI_MODEL_VERSION,
+        model_provider: str,
+        model_name: str,
         api_base: str | None = GEN_AI_API_ENDPOINT,
         api_version: str | None = GEN_AI_API_VERSION,
         custom_llm_provider: str | None = GEN_AI_LLM_PROVIDER_TYPE,
         max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
         temperature: float = GEN_AI_TEMPERATURE,
+        custom_config: dict[str, str] | None = None,
     ):
+        self._model_provider = model_provider
+        self._model_version = model_name
+        self._temperature = temperature
+
         # Litellm Langchain integration currently doesn't take in the api key param
         # Can place this in the call below once integration is in
         litellm.api_key = api_key or "dummy-key"
         litellm.api_version = api_version

-        model_version = model_version or get_default_llm_version()[0]
+        # NOTE: have to set these as environment variables for Litellm since
+        # not all of them can be passed in, but they always support being set
+        # as env variables
+        if custom_config:
+            for k, v in custom_config.items():
+                os.environ[k] = v
+
         self._llm = ChatLiteLLM(  # type: ignore
-            model=model_version
-            if custom_llm_provider
-            else _get_model_str(model_provider, model_version),
+            model=(
+                model_name if custom_llm_provider else f"{model_provider}/{model_name}"
+            ),
             api_base=api_base,
             custom_llm_provider=custom_llm_provider,
             max_tokens=max_output_tokens,
@@ -148,6 +141,14 @@ class DefaultMultiLLM(LangChainChatLLM):
             max_retries=0,  # retries are handled outside of langchain
         )

+    @property
+    def config(self) -> LLMConfig:
+        return LLMConfig(
+            model_provider=self._model_provider,
+            model_name=self._model_version,
+            temperature=self._temperature,
+        )
+
     @property
     def llm(self) -> ChatLiteLLM:
         return self._llm
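With this rework, `DefaultMultiLLM` no longer falls back to environment-derived defaults for the provider and model; both are passed explicitly and surfaced back through the new `config` property. A construction sketch with placeholder values:

# Sketch: direct construction under the new explicit signature.
from danswer.llm.chat_llm import DefaultMultiLLM

llm = DefaultMultiLLM(
    api_key="sk-...",         # placeholder
    timeout=60,
    model_provider="openai",  # now required, no env fallback
    model_name="gpt-4",       # now required, no env fallback
)
print(llm.config)  # LLMConfig with the provider, model name, and temperature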
@@ -1,48 +1,90 @@
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.chat_configs import QA_TIMEOUT
-from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
+from danswer.db.engine import get_session_context_manager
+from danswer.db.llm import fetch_default_provider
+from danswer.db.llm import fetch_provider
+from danswer.db.models import Persona
 from danswer.llm.chat_llm import DefaultMultiLLM
-from danswer.llm.custom_llm import CustomModelServer
 from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.gpt_4_all import DanswerGPT4All
 from danswer.llm.interfaces import LLM
-from danswer.llm.utils import get_default_llm_version
-from danswer.llm.utils import get_gen_ai_api_key
+from danswer.llm.override_models import LLMOverride


+def get_llm_for_persona(
+    persona: Persona, llm_override: LLMOverride | None = None
+) -> LLM:
+    model_provider_override = llm_override.model_provider if llm_override else None
+    model_version_override = llm_override.model_version if llm_override else None
+    temperature_override = llm_override.temperature if llm_override else None
+
+    return get_default_llm(
+        gen_ai_model_provider=model_provider_override
+        or persona.llm_model_provider_override,
+        gen_ai_model_version_override=(
+            model_version_override or persona.llm_model_version_override
+        ),
+        temperature=temperature_override or GEN_AI_TEMPERATURE,
+    )
+
+
 def get_default_llm(
-    gen_ai_model_provider: str = GEN_AI_MODEL_PROVIDER,
-    api_key: str | None = None,
     timeout: int = QA_TIMEOUT,
     temperature: float = GEN_AI_TEMPERATURE,
     use_fast_llm: bool = False,
+    gen_ai_model_provider: str | None = None,
     gen_ai_model_version_override: str | None = None,
 ) -> LLM:
     """A single place to fetch the configured LLM for Danswer
     Also allows overriding certain LLM defaults"""
     if DISABLE_GENERATIVE_AI:
         raise GenAIDisabledException()

-    if gen_ai_model_version_override:
-        model_version = gen_ai_model_version_override
-    else:
-        base, fast = get_default_llm_version()
-        model_version = fast if use_fast_llm else base
-    if api_key is None:
-        api_key = get_gen_ai_api_key()
+    # TODO: pass this in
+    with get_session_context_manager() as session:
+        if gen_ai_model_provider is None:
+            llm_provider = fetch_default_provider(session)
+        else:
+            llm_provider = fetch_provider(session, gen_ai_model_provider)

-    if gen_ai_model_provider.lower() == "custom":
-        return CustomModelServer(api_key=api_key, timeout=timeout)
+    if not llm_provider:
+        raise ValueError("No default LLM provider found")

-    if gen_ai_model_provider.lower() == "gpt4all":
-        return DanswerGPT4All(
-            model_version=model_version, timeout=timeout, temperature=temperature
-        )
+    model_name = gen_ai_model_version_override or (
+        (llm_provider.fast_default_model_name or llm_provider.default_model_name)
+        if use_fast_llm
+        else llm_provider.default_model_name
+    )
+    if not model_name:
+        raise ValueError("No default model name found")

-    return DefaultMultiLLM(
-        model_version=model_version,
-        api_key=api_key,
+    return get_llm(
+        provider=llm_provider.name,
+        model=model_name,
+        api_key=llm_provider.api_key,
+        api_base=llm_provider.api_base,
+        api_version=llm_provider.api_version,
+        custom_config=llm_provider.custom_config,
         timeout=timeout,
         temperature=temperature,
     )
+
+
+def get_llm(
+    provider: str,
+    model: str,
+    api_key: str | None = None,
+    api_base: str | None = None,
+    api_version: str | None = None,
+    custom_config: dict[str, str] | None = None,
+    temperature: float = GEN_AI_TEMPERATURE,
+    timeout: int = QA_TIMEOUT,
+) -> LLM:
+    return DefaultMultiLLM(
+        model_provider=provider,
+        model_name=model,
+        api_key=api_key,
+        api_base=api_base,
+        api_version=api_version,
+        timeout=timeout,
+        temperature=temperature,
+        custom_config=custom_config,
+    )
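The factory now exposes three entry points: `get_llm_for_persona` for persona-scoped requests (honoring per-persona and per-session overrides), `get_default_llm` for the DB-configured default provider, and `get_llm` when the provider details are already in hand. A usage sketch with placeholder values:

# Sketch: the three factory entry points.
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_llm
from danswer.llm.factory import get_llm_for_persona

default_llm = get_default_llm()                # default provider from the DB
fast_llm = get_default_llm(use_fast_llm=True)  # fast_default_model_name when set

# persona_llm = get_llm_for_persona(persona)   # persona/session overrides applied

# bypass the DB lookup when provider details are already known (placeholders):
manual_llm = get_llm(provider="openai", model="gpt-4", api_key="sk-...")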
@@ -4,11 +4,9 @@ from typing import Any
 from langchain.schema.language_model import LanguageModelInput

 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import convert_lm_input_to_basic_string
-from danswer.llm.utils import get_default_llm_version
 from danswer.utils.logger import setup_logger


@@ -38,7 +36,10 @@ except ImportError:

 class DanswerGPT4All(LLM):
     """Option to run an LLM locally, however this is significantly slower and
-    answers tend to be much worse"""
+    answers tend to be much worse
+
+    NOTE: currently unused, but kept for future reference / if we want to add this back.
+    """

     @property
     def requires_warm_up(self) -> bool:
@@ -53,14 +54,14 @@ class DanswerGPT4All(LLM):
     def __init__(
         self,
         timeout: int,
-        model_version: str | None = GEN_AI_MODEL_VERSION,
+        model_version: str,
         max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
         temperature: float = GEN_AI_TEMPERATURE,
     ):
         self.timeout = timeout
         self.max_output_tokens = max_output_tokens
         self.temperature = temperature
-        self.gpt4all_model = GPT4All(model_version or get_default_llm_version()[0])
+        self.gpt4all_model = GPT4All(model_version)

     def log_model_configs(self) -> None:
         logger.debug(
@@ -2,6 +2,7 @@ import abc
 from collections.abc import Iterator

 from langchain.schema.language_model import LanguageModelInput
+from pydantic import BaseModel

 from danswer.utils.logger import setup_logger

@@ -9,6 +10,12 @@ from danswer.utils.logger import setup_logger
 logger = setup_logger()


+class LLMConfig(BaseModel):
+    model_provider: str
+    model_name: str
+    temperature: float
+
+
 class LLM(abc.ABC):
     """Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
     to use these implementations to connect to a variety of LLM providers."""
@@ -22,6 +29,11 @@ class LLM(abc.ABC):
     def requires_api_key(self) -> bool:
         return True

+    @property
+    @abc.abstractmethod
+    def config(self) -> LLMConfig:
+        raise NotImplementedError
+
     @abc.abstractmethod
     def log_model_configs(self) -> None:
         raise NotImplementedError
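Every `LLM` implementation must now report its own `LLMConfig`. A minimal sketch of a conforming subclass; the echo model is hypothetical, and the `invoke`/`stream` members are assumed from the surrounding interface rather than shown in this hunk:

# Hypothetical sketch: the smallest shape that satisfies the new abstract property.
from collections.abc import Iterator

from langchain.schema.language_model import LanguageModelInput

from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig


class EchoLLM(LLM):
    @property
    def config(self) -> LLMConfig:
        return LLMConfig(model_provider="echo", model_name="echo-1", temperature=0.0)

    def log_model_configs(self) -> None:
        print(self.config)

    def invoke(self, prompt: LanguageModelInput) -> str:  # assumed abstract member
        return str(prompt)

    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:  # assumed abstract member
        yield str(prompt)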
backend/danswer/llm/options.py (new file, 109 lines)

@@ -0,0 +1,109 @@
import litellm  # type: ignore
from pydantic import BaseModel


class WellKnownLLMProviderDescriptor(BaseModel):
    name: str
    display_name: str | None = None
    api_key_required: bool
    api_base_required: bool
    api_version_required: bool
    custom_config_keys: list[str] | None = None

    llm_names: list[str]
    default_model: str | None = None
    default_fast_model: str | None = None


OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
    "gpt-4",
    "gpt-4-turbo-preview",
    "gpt-4-1106-preview",
    "gpt-4-32k",
    "gpt-4-0613",
    "gpt-4-32k-0613",
    "gpt-4-0314",
    "gpt-4-32k-0314",
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-0125",
    "gpt-3.5-turbo-1106",
    "gpt-3.5-turbo-16k",
    "gpt-3.5-turbo-0613",
    "gpt-3.5-turbo-16k-0613",
    "gpt-3.5-turbo-0301",
]

BEDROCK_PROVIDER_NAME = "bedrock"
# need to remove all the weird "bedrock/eu-central-1/anthropic.claude-v1" named
# models
BEDROCK_MODEL_NAMES = [model for model in litellm.bedrock_models if "/" not in model][
    ::-1
]

ANTHROPIC_PROVIDER_NAME = "anthropic"
ANTHROPIC_MODEL_NAMES = [model for model in litellm.anthropic_models][::-1]

AZURE_PROVIDER_NAME = "azure"


_PROVIDER_TO_MODELS_MAP = {
    OPENAI_PROVIDER_NAME: OPEN_AI_MODEL_NAMES,
    BEDROCK_PROVIDER_NAME: BEDROCK_MODEL_NAMES,
    ANTHROPIC_PROVIDER_NAME: ANTHROPIC_MODEL_NAMES,
}


def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
    return [
        WellKnownLLMProviderDescriptor(
            name="openai",
            display_name="OpenAI",
            api_key_required=True,
            api_base_required=False,
            api_version_required=False,
            custom_config_keys=[],
            llm_names=fetch_models_for_provider("openai"),
            default_model="gpt-4",
            default_fast_model="gpt-3.5-turbo",
        ),
        WellKnownLLMProviderDescriptor(
            name=ANTHROPIC_PROVIDER_NAME,
            display_name="Anthropic",
            api_key_required=True,
            api_base_required=False,
            api_version_required=False,
            custom_config_keys=[],
            llm_names=fetch_models_for_provider(ANTHROPIC_PROVIDER_NAME),
            default_model="claude-3-opus-20240229",
            default_fast_model="claude-3-sonnet-20240229",
        ),
        WellKnownLLMProviderDescriptor(
            name=AZURE_PROVIDER_NAME,
            display_name="Azure OpenAI",
            api_key_required=True,
            api_base_required=True,
            api_version_required=True,
            custom_config_keys=[],
            llm_names=fetch_models_for_provider(AZURE_PROVIDER_NAME),
        ),
        WellKnownLLMProviderDescriptor(
            name=BEDROCK_PROVIDER_NAME,
            display_name="AWS Bedrock",
            api_key_required=False,
            api_base_required=False,
            api_version_required=False,
            custom_config_keys=[
                "AWS_ACCESS_KEY_ID",
                "AWS_SECRET_ACCESS_KEY",
                "AWS_REGION_NAME",
            ],
            llm_names=fetch_models_for_provider(BEDROCK_PROVIDER_NAME),
            default_model="anthropic.claude-3-sonnet-20240229-v1:0",
            default_fast_model="anthropic.claude-3-haiku-20240307-v1:0",
        ),
    ]


def fetch_models_for_provider(provider_name: str) -> list[str]:
    return _PROVIDER_TO_MODELS_MAP.get(provider_name, [])
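The descriptor list is what a UI setup flow can render: which providers are known, what credentials each needs, and which models they expose. A consumption sketch:

# Sketch: enumerate the well-known providers and their models.
from danswer.llm.options import fetch_available_well_known_llms

for descriptor in fetch_available_well_known_llms():
    print(descriptor.display_name or descriptor.name)
    print("  api key required:", descriptor.api_key_required)
    print("  default model:", descriptor.default_model)
    print("  models:", ", ".join(descriptor.llm_names[:3]), "...")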
@@ -1,9 +1,7 @@
 from collections.abc import Callable
 from collections.abc import Iterator
-from copy import copy
-from functools import lru_cache
 from typing import Any
 from typing import cast
 from typing import TYPE_CHECKING
 from typing import Union

@@ -20,19 +18,12 @@ from langchain.schema.messages import HumanMessage
 from langchain.schema.messages import SystemMessage
 from tiktoken.core import Encoding

-from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
-from danswer.configs.constants import GEN_AI_DETECTED_MODEL
 from danswer.configs.constants import MessageType
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
-from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
-from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.db.models import ChatMessage
-from danswer.dynamic_configs.factory import get_dynamic_config_store
-from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.llm.interfaces import LLM
 from danswer.search.models import InferenceChunk
 from danswer.utils.logger import setup_logger
@@ -47,34 +38,6 @@ _LLM_TOKENIZER: Any = None
 _LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None


-@lru_cache()
-def get_default_llm_version() -> tuple[str, str]:
-    default_openai_model = "gpt-3.5-turbo-16k-0613"
-    if GEN_AI_MODEL_VERSION:
-        llm_version = GEN_AI_MODEL_VERSION
-    else:
-        if GEN_AI_MODEL_PROVIDER != "openai":
-            logger.warning("No LLM Model Version set")
-            # Either this value is unused or it will throw an error
-            llm_version = default_openai_model
-        else:
-            kv_store = get_dynamic_config_store()
-            try:
-                llm_version = cast(str, kv_store.load(GEN_AI_DETECTED_MODEL))
-            except ConfigNotFoundError:
-                llm_version = default_openai_model
-
-    if FAST_GEN_AI_MODEL_VERSION:
-        fast_llm_version = FAST_GEN_AI_MODEL_VERSION
-    else:
-        if GEN_AI_MODEL_PROVIDER == "openai":
-            fast_llm_version = default_openai_model
-        else:
-            fast_llm_version = llm_version
-
-    return llm_version, fast_llm_version
-
-
 def get_default_llm_tokenizer() -> Encoding:
     """Currently only supports the OpenAI default tokenizer: tiktoken"""
     global _LLM_TOKENIZER
@@ -219,17 +182,6 @@ def check_number_of_tokens(
     return len(encode_fn(text))


-def get_gen_ai_api_key() -> str | None:
-    # first check if the key has been provided by the UI
-    try:
-        return cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
-    except ConfigNotFoundError:
-        pass
-
-    # if not provided by the UI, fallback to the env variable
-    return GEN_AI_API_KEY
-
-
 def test_llm(llm: LLM) -> str | None:
     # try for up to 2 timeouts (e.g. 10 seconds in total)
     error_msg = None
@@ -246,7 +198,7 @@ def test_llm(llm: LLM) -> str | None:

 def get_llm_max_tokens(
     model_map: dict,
-    model_name: str | None = GEN_AI_MODEL_VERSION,
+    model_name: str,
     model_provider: str = GEN_AI_MODEL_PROVIDER,
 ) -> int:
     """Best effort attempt to get the max tokens for the LLM"""
@@ -254,13 +206,10 @@ def get_llm_max_tokens(
         # This is an override, so always return this
         return GEN_AI_MAX_TOKENS

-    model_name = model_name or get_default_llm_version()[0]
-
     try:
-        if model_provider == "openai":
         model_obj = model_map.get(f"{model_provider}/{model_name}")
         if not model_obj:
             model_obj = model_map[model_name]
-        else:
-            model_obj = model_map[f"{model_provider}/{model_name}"]

         if "max_input_tokens" in model_obj:
             return model_obj["max_input_tokens"]
@@ -277,8 +226,8 @@ def get_llm_max_tokens(

 def get_max_input_tokens(
-    model_name: str | None = GEN_AI_MODEL_VERSION,
-    model_provider: str = GEN_AI_MODEL_PROVIDER,
+    model_name: str,
+    model_provider: str,
     output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
 ) -> int:
     # NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually
@@ -289,8 +238,6 @@ def get_max_input_tokens(
     # model_map is litellm.model_cost
     litellm_model_map = litellm.model_cost

-    model_name = model_name or get_default_llm_version()[0]
-
     input_toks = (
         get_llm_max_tokens(
             model_name=model_name,
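With the environment-variable defaults gone, callers of the token helpers must always name both the model and the provider. A sketch with placeholder values:

# Sketch: both arguments are now required inputs.
from danswer.llm.utils import get_max_input_tokens

limit = get_max_input_tokens(model_name="gpt-4", model_provider="openai")
print(limit)  # the model's input window minus the reserved output-token budget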
@@ -33,8 +33,6 @@ from danswer.configs.app_configs import SECRET
 from danswer.configs.app_configs import WEB_DOMAIN
 from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.configs.constants import AuthType
-from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
-from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.db.chat import delete_old_default_personas
 from danswer.db.connector import create_initial_default_connector
 from danswer.db.connector_credential_pair import associate_default_cc_pair
@@ -48,9 +46,8 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model
 from danswer.db.index_attempt import expire_index_attempts
 from danswer.db.swap_index import check_index_swap
 from danswer.document_index.factory import get_default_document_index
+from danswer.dynamic_configs.port_configs import port_api_key_to_postgres
 from danswer.dynamic_configs.port_configs import port_filesystem_to_postgres
-from danswer.llm.factory import get_default_llm
-from danswer.llm.utils import get_default_llm_version
 from danswer.search.retrieval.search_runner import download_nltk_data
 from danswer.search.search_nlp_models import warm_up_encoders
 from danswer.server.auth_check import check_router_auth
@@ -67,6 +64,8 @@ from danswer.server.features.prompt.api import basic_router as prompt_router
 from danswer.server.gpts.api import router as gpts_router
 from danswer.server.manage.administrative import router as admin_router
 from danswer.server.manage.get_state import router as state_router
+from danswer.server.manage.llm.api import admin_router as llm_admin_router
+from danswer.server.manage.llm.api import basic_router as llm_router
 from danswer.server.manage.secondary_index import router as secondary_index_router
 from danswer.server.manage.slack_bot import router as slack_bot_management_router
 from danswer.server.manage.users import router as user_router
@@ -156,17 +155,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:

     if DISABLE_GENERATIVE_AI:
         logger.info("Generative AI Q&A disabled")
-    else:
-        logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
-        base, fast = get_default_llm_version()
-        logger.info(f"Using LLM Model Version: {base}")
-        if base != fast:
-            logger.info(f"Using Fast LLM Model Version: {fast}")
-        if GEN_AI_API_ENDPOINT:
-            logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
-
-        # Any additional model configs logged here
-        get_default_llm().log_model_configs()

     if MULTILINGUAL_QUERY_EXPANSION:
         logger.info(
@@ -180,6 +168,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
             "Skipping port of persistent volumes. Maybe these have already been removed?"
         )

+    try:
+        port_api_key_to_postgres()
+    except Exception as e:
+        logger.debug(f"Failed to port API keys. Exception: {e}. Continuing...")
+
     with Session(engine) as db_session:
         check_index_swap(db_session=db_session)
         db_embedding_model = get_current_db_embedding_model(db_session)
@@ -281,6 +274,8 @@ def get_application() -> FastAPI:
     include_router_with_global_prefix_prepended(application, gpts_router)
     include_router_with_global_prefix_prepended(application, settings_router)
     include_router_with_global_prefix_prepended(application, settings_admin_router)
+    include_router_with_global_prefix_prepended(application, llm_admin_router)
+    include_router_with_global_prefix_prepended(application, llm_router)

     if AUTH_TYPE == AuthType.DISABLED:
         # Server logs this during auth setup verification step
@@ -28,9 +28,9 @@ from danswer.llm.answering.answer import Answer
 from danswer.llm.answering.models import AnswerStyleConfig
 from danswer.llm.answering.models import CitationConfig
 from danswer.llm.answering.models import DocumentPruningConfig
-from danswer.llm.answering.models import LLMConfig
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.answering.models import QuotesConfig
+from danswer.llm.factory import get_llm_for_persona
 from danswer.llm.utils import get_default_llm_token_encode
 from danswer.one_shot_answer.models import DirectQARequest
 from danswer.one_shot_answer.models import OneShotQAResponse
@@ -212,7 +212,7 @@ def stream_answer_objects(
         docs=[llm_doc_from_inference_section(section) for section in top_sections],
         answer_style_config=answer_config,
         prompt_config=PromptConfig.from_model(prompt),
-        llm_config=LLMConfig.from_persona(chat_session.persona),
+        llm=get_llm_for_persona(persona=chat_session.persona),
         doc_relevance_list=search_pipeline.section_relevance_list,
         single_message_history=history_str,
         timeout=timeout,
@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session

 from danswer.auth.users import current_admin_user
 from danswer.auth.users import current_user
-from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.db.chat import get_persona_by_id
 from danswer.db.chat import get_personas
 from danswer.db.chat import mark_persona_as_deleted
@@ -15,7 +14,6 @@ from danswer.db.engine import get_session
 from danswer.db.models import User
 from danswer.db.persona import create_update_persona
 from danswer.llm.answering.prompts.utils import build_dummy_prompt
-from danswer.llm.utils import get_default_llm_version
 from danswer.server.features.persona.models import CreatePersonaRequest
 from danswer.server.features.persona.models import PersonaSnapshot
 from danswer.server.features.persona.models import PromptTemplateResponse
@@ -168,49 +166,3 @@ def build_final_template_prompt(
             retrieval_disabled=retrieval_disabled,
         )
     )
-
-
-"""Utility endpoints for selecting which model to use for a persona.
-Putting here for now, since we have no other flows which use this."""
-
-GPT_4_MODEL_VERSIONS = [
-    "gpt-4",
-    "gpt-4-turbo-preview",
-    "gpt-4-1106-preview",
-    "gpt-4-32k",
-    "gpt-4-0613",
-    "gpt-4-32k-0613",
-    "gpt-4-0314",
-    "gpt-4-32k-0314",
-]
-GPT_3_5_TURBO_MODEL_VERSIONS = [
-    "gpt-3.5-turbo",
-    "gpt-3.5-turbo-0125",
-    "gpt-3.5-turbo-1106",
-    "gpt-3.5-turbo-16k",
-    "gpt-3.5-turbo-0613",
-    "gpt-3.5-turbo-16k-0613",
-    "gpt-3.5-turbo-0301",
-]
-
-
-@basic_router.get("/utils/list-available-models")
-def list_available_model_versions(
-    _: User | None = Depends(current_user),
-) -> list[str]:
-    # currently only support selecting different models for OpenAI
-    if GEN_AI_MODEL_PROVIDER != "openai":
-        return []
-
-    return GPT_4_MODEL_VERSIONS + GPT_3_5_TURBO_MODEL_VERSIONS
-
-
-@basic_router.get("/utils/default-model")
-def get_default_model(
-    _: User | None = Depends(current_user),
-) -> str:
-    # currently only support selecting different models for OpenAI
-    if GEN_AI_MODEL_PROVIDER != "openai":
-        return ""
-
-    return get_default_llm_version()[0]
@@ -20,6 +20,7 @@ class CreatePersonaRequest(BaseModel):
     recency_bias: RecencyBiasSetting
     prompt_ids: list[int]
     document_set_ids: list[int]
+    llm_model_provider_override: str | None = None
     llm_model_version_override: str | None = None
     starter_messages: list[StarterMessage] | None = None
     # For Private Personas, who should be able to access these
@@ -38,6 +39,7 @@ class PersonaSnapshot(BaseModel):
     num_chunks: float | None
     llm_relevance_filter: bool
     llm_filter_extraction: bool
+    llm_model_provider_override: str | None
     llm_model_version_override: str | None
     starter_messages: list[StarterMessage] | None
     default_persona: bool
@@ -66,6 +68,7 @@ class PersonaSnapshot(BaseModel):
             num_chunks=persona.num_chunks,
             llm_relevance_filter=persona.llm_relevance_filter,
             llm_filter_extraction=persona.llm_filter_extraction,
+            llm_model_provider_override=persona.llm_model_provider_override,
             llm_model_version_override=persona.llm_model_version_override,
             starter_messages=persona.starter_messages,
             default_persona=persona.default_persona,
@ -1,5 +1,4 @@
import json
from collections.abc import Callable
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@ -16,13 +15,9 @@ from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.app_configs import TOKEN_BUDGET_GLOBALLY_ENABLED
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import ENABLE_TOKEN_BUDGET
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.constants import GEN_AI_DETECTED_MODEL
from danswer.configs.constants import TOKEN_BUDGET
from danswer.configs.constants import TOKEN_BUDGET_SETTINGS
from danswer.configs.constants import TOKEN_BUDGET_TIME_PERIOD
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.engine import get_session
@ -35,18 +30,13 @@ from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_default_llm_version
from danswer.llm.utils import get_gen_ai_api_key
from danswer.llm.utils import test_llm
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.manage.models import BoostDoc
from danswer.server.manage.models import BoostUpdateRequest
from danswer.server.manage.models import HiddenUpdateRequest
from danswer.server.models import ApiKey
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

router = APIRouter(prefix="/manage")
logger = setup_logger()
@ -123,52 +113,6 @@ def document_hidden_update(
        raise HTTPException(status_code=400, detail=str(e))


def _validate_llm_key(genai_api_key: str | None) -> None:
    # Checking new API key, may change things for this if kv-store is updated
    get_default_llm_version.cache_clear()
    kv_store = get_dynamic_config_store()
    try:
        kv_store.delete(GEN_AI_DETECTED_MODEL)
    except ConfigNotFoundError:
        pass

    gpt_4_version = "gpt-4"  # 32k is not available to most people
    gpt4_llm = None
    try:
        llm = get_default_llm(api_key=genai_api_key, timeout=10)
        if GEN_AI_MODEL_PROVIDER == "openai" and not GEN_AI_MODEL_VERSION:
            gpt4_llm = get_default_llm(
                gen_ai_model_version_override=gpt_4_version,
                api_key=genai_api_key,
                timeout=10,
            )
    except GenAIDisabledException:
        return

    functions_with_args: list[tuple[Callable, tuple]] = [(test_llm, (llm,))]
    if gpt4_llm:
        functions_with_args.append((test_llm, (gpt4_llm,)))

    parallel_results = run_functions_tuples_in_parallel(
        functions_with_args, allow_failures=False
    )

    error_msg = parallel_results[0]

    if error_msg:
        if genai_api_key is None:
            raise HTTPException(status_code=404, detail="Key not found")
        raise HTTPException(status_code=400, detail=error_msg)

    # Mark check as successful
    curr_time = datetime.now(tz=timezone.utc)
    kv_store.store(GEN_AI_KEY_CHECK_TIME, curr_time.timestamp())

    # None for no errors
    if gpt4_llm and parallel_results[1] is None:
        kv_store.store(GEN_AI_DETECTED_MODEL, gpt_4_version)


@router.get("/admin/genai-api-key/validate")
def validate_existing_genai_api_key(
    _: User = Depends(current_admin_user),
@ -187,46 +131,18 @@ def validate_existing_genai_api_key(
        # First time checking the key, nothing unusual
        pass

    genai_api_key = get_gen_ai_api_key()
    _validate_llm_key(genai_api_key)


@router.get("/admin/genai-api-key", response_model=ApiKey)
def get_gen_ai_api_key_from_dynamic_config_store(
    _: User = Depends(current_admin_user),
) -> ApiKey:
    """
    NOTE: Only gets value from dynamic config store as to not expose env variables.
    """
    try:
        # only get last 4 characters of key to not expose full key
        return ApiKey(
            api_key=cast(
                str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY)
            )[-4:]
        )
    except ConfigNotFoundError:
        raise HTTPException(status_code=404, detail="Key not found")
        llm = get_default_llm(timeout=10)
    except ValueError:
        raise HTTPException(status_code=404, detail="LLM not setup")

    error = test_llm(llm)
    if error:
        raise HTTPException(status_code=400, detail=error)

@router.put("/admin/genai-api-key")
def store_genai_api_key(
    request: ApiKey,
    _: User = Depends(current_admin_user),
) -> None:
    if not request.api_key:
        raise HTTPException(400, "No API key provided")

    _validate_llm_key(request.api_key)

    get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)


@router.delete("/admin/genai-api-key")
def delete_genai_api_key(
    _: User = Depends(current_admin_user),
) -> None:
    get_dynamic_config_store().delete(GEN_AI_API_KEY_STORAGE_KEY)
    # Mark check as successful
    curr_time = datetime.now(tz=timezone.utc)
    kv_store.store(GEN_AI_KEY_CHECK_TIME, curr_time.timestamp())


@router.post("/admin/deletion-attempt")
157 backend/danswer/server/manage/llm/api.py Normal file
@ -0,0 +1,157 @@
from collections.abc import Callable

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.llm import remove_llm_provider
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.models import User
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_llm
from danswer.llm.options import fetch_available_well_known_llms
from danswer.llm.options import WellKnownLLMProviderDescriptor
from danswer.llm.utils import test_llm
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderDescriptor
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.server.manage.llm.models import TestLLMRequest
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

logger = setup_logger()


admin_router = APIRouter(prefix="/admin/llm")
basic_router = APIRouter(prefix="/llm")


@admin_router.get("/built-in/options")
def fetch_llm_options(
    _: User | None = Depends(current_admin_user),
) -> list[WellKnownLLMProviderDescriptor]:
    return fetch_available_well_known_llms()


@admin_router.post("/test")
def test_llm_configuration(
    test_llm_request: TestLLMRequest,
    _: User | None = Depends(current_admin_user),
) -> None:
    llm = get_llm(
        provider=test_llm_request.provider,
        model=test_llm_request.default_model_name,
        api_key=test_llm_request.api_key,
        api_base=test_llm_request.api_base,
        api_version=test_llm_request.api_version,
        custom_config=test_llm_request.custom_config,
    )
    functions_with_args: list[tuple[Callable, tuple]] = [(test_llm, (llm,))]

    if (
        test_llm_request.default_fast_model_name
        and test_llm_request.default_fast_model_name
        != test_llm_request.default_model_name
    ):
        fast_llm = get_llm(
            provider=test_llm_request.provider,
            model=test_llm_request.default_fast_model_name,
            api_key=test_llm_request.api_key,
            api_base=test_llm_request.api_base,
            api_version=test_llm_request.api_version,
            custom_config=test_llm_request.custom_config,
        )
        functions_with_args.append((test_llm, (fast_llm,)))

    parallel_results = run_functions_tuples_in_parallel(
        functions_with_args, allow_failures=False
    )
    error = parallel_results[0] or (
        parallel_results[1] if len(parallel_results) > 1 else None
    )

    if error:
        raise HTTPException(status_code=400, detail=error)


@admin_router.post("/test/default")
def test_default_provider(
    _: User | None = Depends(current_admin_user),
) -> None:
    try:
        llm = get_default_llm()
        fast_llm = get_default_llm(use_fast_llm=True)
    except ValueError:
        logger.exception("Failed to fetch default LLM Provider")
        raise HTTPException(status_code=400, detail="No LLM Provider setup")

    functions_with_args: list[tuple[Callable, tuple]] = [
        (test_llm, (llm,)),
        (test_llm, (fast_llm,)),
    ]
    parallel_results = run_functions_tuples_in_parallel(
        functions_with_args, allow_failures=False
    )
    error = parallel_results[0] or (
        parallel_results[1] if len(parallel_results) > 1 else None
    )
    if error:
        raise HTTPException(status_code=400, detail=error)


@admin_router.get("/provider")
def list_llm_providers(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[FullLLMProvider]:
    return [
        FullLLMProvider.from_model(llm_provider_model)
        for llm_provider_model in fetch_existing_llm_providers(db_session)
    ]


@admin_router.put("/provider")
def put_llm_provider(
    llm_provider: LLMProviderUpsertRequest,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> FullLLMProvider:
    return upsert_llm_provider(db_session, llm_provider)


@admin_router.delete("/provider/{provider_id}")
def delete_llm_provider(
    provider_id: int,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    remove_llm_provider(db_session, provider_id)


@admin_router.post("/provider/{provider_id}/default")
def set_provider_as_default(
    provider_id: int,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    update_default_provider(db_session, provider_id)


"""Endpoints for all"""


@basic_router.get("/provider")
def list_llm_provider_basics(
    _: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
    return [
        LLMProviderDescriptor.from_model(llm_provider_model)
        for llm_provider_model in fetch_existing_llm_providers(db_session)
    ]
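Taken together, these routes define the new provider lifecycle: test a configuration, upsert it, then mark it as the instance-wide default. A hedged end-to-end sketch of an admin client driving them; the base URL, the `/api` prefix, and the authenticated session are assumptions, while the payload fields mirror `LLMProviderUpsertRequest` from the models file below.

```python
import requests

BASE_URL = "http://localhost:8080"  # assumption: local deployment, admin-authed session
session = requests.Session()

provider_payload = {
    "name": "openai",
    "api_key": "sk-...",  # placeholder, deliberately elided
    "api_base": None,
    "api_version": None,
    "custom_config": None,
    "default_model_name": "gpt-4",
    "fast_default_model_name": "gpt-3.5-turbo",
    "model_names": None,  # None means: use the provider's built-in model list
}

# 1. Verify the configuration works before persisting it
test_resp = session.post(
    f"{BASE_URL}/api/admin/llm/test",
    json={
        "provider": provider_payload["name"],
        "api_key": provider_payload["api_key"],
        "default_model_name": provider_payload["default_model_name"],
        "default_fast_model_name": provider_payload["fast_default_model_name"],
    },
)
test_resp.raise_for_status()

# 2. Upsert the provider, then mark it as the default for the instance
provider = session.put(
    f"{BASE_URL}/api/admin/llm/provider", json=provider_payload
).json()
session.post(
    f"{BASE_URL}/api/admin/llm/provider/{provider['id']}/default"
).raise_for_status()
```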
89 backend/danswer/server/manage/llm/models.py Normal file
@ -0,0 +1,89 @@
from typing import TYPE_CHECKING

from pydantic import BaseModel

from danswer.llm.options import fetch_models_for_provider

if TYPE_CHECKING:
    from danswer.db.models import LLMProvider as LLMProviderModel


class TestLLMRequest(BaseModel):
    # provider level
    provider: str
    api_key: str | None = None
    api_base: str | None = None
    api_version: str | None = None
    custom_config: dict[str, str] | None = None

    # model level
    default_model_name: str
    default_fast_model_name: str | None = None


class LLMProviderDescriptor(BaseModel):
    """A descriptor for an LLM provider that can be safely viewed by
    non-admin users. Used when giving a list of available LLMs."""

    name: str
    model_names: list[str]
    default_model_name: str
    fast_default_model_name: str | None
    is_default_provider: bool | None

    @classmethod
    def from_model(
        cls, llm_provider_model: "LLMProviderModel"
    ) -> "LLMProviderDescriptor":
        return cls(
            name=llm_provider_model.name,
            default_model_name=llm_provider_model.default_model_name,
            fast_default_model_name=llm_provider_model.fast_default_model_name,
            is_default_provider=llm_provider_model.is_default_provider,
            model_names=(
                llm_provider_model.model_names
                or fetch_models_for_provider(llm_provider_model.name)
                or [llm_provider_model.default_model_name]
            ),
        )


class LLMProvider(BaseModel):
    name: str
    api_key: str | None
    api_base: str | None
    api_version: str | None
    custom_config: dict[str, str] | None
    default_model_name: str
    fast_default_model_name: str | None


class LLMProviderUpsertRequest(LLMProvider):
    # should only be used for a "custom" provider
    # for default providers, the built-in model names are used
    model_names: list[str] | None


class FullLLMProvider(LLMProvider):
    id: int
    is_default_provider: bool | None
    model_names: list[str]

    @classmethod
    def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider":
        return cls(
            id=llm_provider_model.id,
            name=llm_provider_model.name,
            api_key=llm_provider_model.api_key,
            api_base=llm_provider_model.api_base,
            api_version=llm_provider_model.api_version,
            custom_config=llm_provider_model.custom_config,
            default_model_name=llm_provider_model.default_model_name,
            fast_default_model_name=llm_provider_model.fast_default_model_name,
            is_default_provider=llm_provider_model.is_default_provider,
            model_names=(
                llm_provider_model.model_names
                or fetch_models_for_provider(llm_provider_model.name)
                or [llm_provider_model.default_model_name]
            ),
        )
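Both `from_model` constructors share the same three-step fallback for `model_names`. A standalone restatement of that chain, with `fetch_models_for_provider` stubbed out since its body is not part of this diff:

```python
from typing import Optional

def _fetch_models_for_provider_stub(provider_name: str) -> list[str]:
    # Stand-in for danswer.llm.options.fetch_models_for_provider,
    # which is imported above but not defined in this commit.
    return []

def resolve_model_names(
    explicit_names: Optional[list[str]],
    provider_name: str,
    default_model_name: str,
) -> list[str]:
    """Explicit admin-entered names win, then the provider's known
    models, then (as a last resort) just the default model."""
    return (
        explicit_names
        or _fetch_models_for_provider_stub(provider_name)
        or [default_model_name]
    )

assert resolve_model_names(None, "openai", "gpt-4") == ["gpt-4"]
assert resolve_model_names(["a", "b"], "openai", "gpt-4") == ["a", "b"]
```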
@ -3,6 +3,7 @@ alembic==1.10.4
asyncpg==0.27.0
atlassian-python-api==3.37.0
beautifulsoup4==4.12.2
boto3==1.34.84
celery==5.3.4
chardet==5.2.0
dask==2023.8.1
7 web/package-lock.json generated
@ -15,12 +15,14 @@
"@radix-ui/react-popover": "^1.0.7",
|
||||
"@tremor/react": "^3.9.2",
|
||||
"@types/js-cookie": "^3.0.3",
|
||||
"@types/lodash": "^4.17.0",
|
||||
"@types/node": "18.15.11",
|
||||
"@types/react": "18.0.32",
|
||||
"@types/react-dom": "18.0.11",
|
||||
"autoprefixer": "^10.4.14",
|
||||
"formik": "^2.2.9",
|
||||
"js-cookie": "^3.0.5",
|
||||
"lodash": "^4.17.21",
|
||||
"mdast-util-find-and-replace": "^3.0.1",
|
||||
"next": "^14.1.0",
|
||||
"postcss": "^8.4.31",
|
||||
@ -1740,6 +1742,11 @@
|
||||
"integrity": "sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/@types/lodash": {
|
||||
"version": "4.17.0",
|
||||
"resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.17.0.tgz",
|
||||
"integrity": "sha512-t7dhREVv6dbNj0q17X12j7yDG4bD/DHYX7o5/DbDxobP0HnGPgpRz2Ej77aL7TZT3DSw13fqUTj8J4mMnqa7WA=="
|
||||
},
|
||||
"node_modules/@types/mdast": {
|
||||
"version": "4.0.3",
|
||||
"resolved": "https://registry.npmjs.org/@types/mdast/-/mdast-4.0.3.tgz",
|
||||
|
@ -16,12 +16,14 @@
"@radix-ui/react-popover": "^1.0.7",
"@tremor/react": "^3.9.2",
"@types/js-cookie": "^3.0.3",
"@types/lodash": "^4.17.0",
"@types/node": "18.15.11",
"@types/react": "18.0.32",
"@types/react-dom": "18.0.11",
"autoprefixer": "^10.4.14",
"formik": "^2.2.9",
"js-cookie": "^3.0.5",
"lodash": "^4.17.21",
"mdast-util-find-and-replace": "^3.0.1",
"next": "^14.1.0",
"postcss": "^8.4.31",
@ -31,6 +31,15 @@ import { Bubble } from "@/components/Bubble";
import { GroupsIcon } from "@/components/icons/icons";
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
import { FullLLMProvider } from "../models/llm/interfaces";
import { Option } from "@/components/Dropdown";

const DEFAULT_LLM_PROVIDER_TO_DISPLAY_NAME: Record<string, string> = {
  openai: "OpenAI",
  azure: "Azure OpenAI",
  anthropic: "Anthropic",
  bedrock: "AWS Bedrock",
};

function Label({ children }: { children: string | JSX.Element }) {
  return (
@ -46,20 +55,18 @@ export function AssistantEditor({
  existingPersona,
  ccPairs,
  documentSets,
  llmOverrideOptions,
  defaultLLM,
  user,
  defaultPublic,
  redirectType,
  llmProviders,
}: {
  existingPersona?: Persona | null;
  ccPairs: CCPairBasicInfo[];
  documentSets: DocumentSet[];
  llmOverrideOptions: string[];
  defaultLLM: string;
  user: User | null;
  defaultPublic: boolean;
  redirectType: SuccessfulPersonaUpdateRedirectType;
  llmProviders: FullLLMProvider[];
}) {
  const router = useRouter();
  const { popup, setPopup } = usePopup();
@ -98,6 +105,21 @@ export function AssistantEditor({
    }
  }, []);

  const defaultLLM = llmProviders.find(
    (llmProvider) => llmProvider.is_default_provider
  )?.default_model_name;

  const modelOptionsByProvider = new Map<string, Option<string>[]>();
  llmProviders.forEach((llmProvider) => {
    const providerOptions = llmProvider.model_names.map((modelName) => {
      return {
        name: modelName,
        value: modelName,
      };
    });
    modelOptionsByProvider.set(llmProvider.name, providerOptions);
  });

  return (
    <div>
      {popup}
@ -118,6 +140,8 @@ export function AssistantEditor({
          include_citations:
            existingPersona?.prompts[0]?.include_citations ?? true,
          llm_relevance_filter: existingPersona?.llm_relevance_filter ?? false,
          llm_model_provider_override:
            existingPersona?.llm_model_provider_override ?? null,
          llm_model_version_override:
            existingPersona?.llm_model_version_override ?? null,
          starter_messages: existingPersona?.starter_messages ?? [],
@ -139,6 +163,7 @@ export function AssistantEditor({
          include_citations: Yup.boolean().required(),
          llm_relevance_filter: Yup.boolean().required(),
          llm_model_version_override: Yup.string().nullable(),
          llm_model_provider_override: Yup.string().nullable(),
          starter_messages: Yup.array().of(
            Yup.object().shape({
              name: Yup.string().required(),
@ -178,6 +203,18 @@ export function AssistantEditor({
            return;
          }

          if (
            values.llm_model_provider_override &&
            !values.llm_model_version_override
          ) {
            setPopup({
              type: "error",
              message:
                "Must select a model if a non-default LLM provider is chosen.",
            });
            return;
          }

          formikHelpers.setSubmitting(true);

          // if disable_retrieval is set, set num_chunks to 0
@ -428,7 +465,7 @@ export function AssistantEditor({
              </>
            )}

            {llmOverrideOptions.length > 0 && defaultLLM && (
            {llmProviders.length > 0 && (
              <>
                <HidableSection
                  sectionTitle="[Advanced] Model Selection"
@ -452,17 +489,50 @@ export function AssistantEditor({
                    .
                  </Text>

                  <div className="w-96">
                    <SelectorFormField
                      name="llm_model_version_override"
                      options={llmOverrideOptions.map((llmOption) => {
                        return {
                          name: llmOption,
                          value: llmOption,
                        };
                      })}
                      includeDefault={true}
                    />
                  <div className="flex mt-6">
                    <div className="w-96">
                      <SubLabel>LLM Provider</SubLabel>
                      <SelectorFormField
                        name="llm_model_provider_override"
                        options={llmProviders.map((llmProvider) => ({
                          name:
                            DEFAULT_LLM_PROVIDER_TO_DISPLAY_NAME[
                              llmProvider.name
                            ] || llmProvider.name,
                          value: llmProvider.name,
                        }))}
                        includeDefault={true}
                        onSelect={(selected) => {
                          if (
                            selected !== values.llm_model_provider_override
                          ) {
                            setFieldValue(
                              "llm_model_version_override",
                              null
                            );
                          }
                          setFieldValue(
                            "llm_model_provider_override",
                            selected
                          );
                        }}
                      />
                    </div>

                    {values.llm_model_provider_override && (
                      <div className="w-96 ml-4">
                        <SubLabel>Model</SubLabel>
                        <SelectorFormField
                          name="llm_model_version_override"
                          options={
                            modelOptionsByProvider.get(
                              values.llm_model_provider_override
                            ) || []
                          }
                          maxHeight="max-h-72"
                        />
                      </div>
                    )}
                  </div>
                </>
              </HidableSection>
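The submit-time guard above rejects a provider override that arrives without a model override. The same invariant, restated as a standalone check so the rule is easy to see outside the form code (a sketch, not the actual validation logic):

```python
from typing import Optional

def validate_override_pair(
    provider_override: Optional[str], model_override: Optional[str]
) -> Optional[str]:
    """Return an error message when the pair is inconsistent, else None."""
    if provider_override and not model_override:
        return "Must select a model if a non-default LLM provider is chosen."
    return None

assert validate_override_pair("openai", None) is not None
assert validate_override_pair("openai", "gpt-4") is None
assert validate_override_pair(None, None) is None
```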
@ -30,6 +30,7 @@ export interface Persona {
  num_chunks?: number;
  llm_relevance_filter?: boolean;
  llm_filter_extraction?: boolean;
  llm_model_provider_override?: string;
  llm_model_version_override?: string;
  starter_messages: StarterMessage[] | null;
  default_persona: boolean;
@ -10,6 +10,7 @@ interface PersonaCreationRequest {
  include_citations: boolean;
  is_public: boolean;
  llm_relevance_filter: boolean | null;
  llm_model_provider_override: string | null;
  llm_model_version_override: string | null;
  starter_messages: StarterMessage[] | null;
  users?: string[];
@ -28,6 +29,7 @@ interface PersonaUpdateRequest {
  include_citations: boolean;
  is_public: boolean;
  llm_relevance_filter: boolean | null;
  llm_model_provider_override: string | null;
  llm_model_version_override: string | null;
  starter_messages: StarterMessage[] | null;
  users?: string[];
@ -117,6 +119,7 @@ function buildPersonaAPIBody(
    recency_bias: "base_decay",
    prompt_ids: [promptId],
    document_set_ids,
    llm_model_provider_override: creationRequest.llm_model_provider_override,
    llm_model_version_override: creationRequest.llm_model_version_override,
    starter_messages: creationRequest.starter_messages,
    users,
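With `buildPersonaAPIBody` now forwarding both overrides, a persona create or update request carries a provider/model pair. A hedged sketch of the resulting JSON body; the surrounding fields of `CreatePersonaRequest` are abridged to the ones shown in this diff, and the example values are illustrative only.

```python
import json

persona_body = {
    "recency_bias": "base_decay",
    "prompt_ids": [1],
    "document_set_ids": [],
    "llm_model_provider_override": "anthropic",  # new in this commit
    "llm_model_version_override": "claude-3-opus-20240229",
    "starter_messages": None,
}
print(json.dumps(persona_body, indent=2))
```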
@ -1,13 +1,10 @@
"use client";

import { LoadingAnimation, ThreeDotsLoader } from "@/components/Loading";
import { ThreeDotsLoader } from "@/components/Loading";
import { AdminPageTitle } from "@/components/admin/Title";
import { KeyIcon, TrashIcon } from "@/components/icons/icons";
import { ApiKeyForm } from "@/components/openai/ApiKeyForm";
import { GEN_AI_API_KEY_URL } from "@/components/openai/constants";
import { errorHandlingFetcher, fetcher } from "@/lib/fetcher";
import { Button, Card, Divider, Text, Title } from "@tremor/react";
import { FiCpu, FiPackage } from "react-icons/fi";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { Button, Card, Text, Title } from "@tremor/react";
import { FiPackage } from "react-icons/fi";
import useSWR, { mutate } from "swr";
import { ModelOption, ModelSelector } from "./ModelSelector";
import { useState } from "react";
450 web/src/app/admin/models/llm/CustomLLMProviderUpdateForm.tsx Normal file
@ -0,0 +1,450 @@
import { LoadingAnimation } from "@/components/Loading";
import { Button, Divider, Text } from "@tremor/react";
import {
  ArrayHelpers,
  ErrorMessage,
  Field,
  FieldArray,
  Form,
  Formik,
} from "formik";
import { FiPlus, FiTrash, FiX } from "react-icons/fi";
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import {
  Label,
  SubLabel,
  TextArrayField,
  TextFormField,
} from "@/components/admin/connectors/Field";
import { useState } from "react";
import { useSWRConfig } from "swr";
import { FullLLMProvider } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";

function customConfigProcessing(customConfigsList: [string, string][]) {
  const customConfig: { [key: string]: string } = {};
  customConfigsList.forEach(([key, value]) => {
    customConfig[key] = value;
  });
  return customConfig;
}

export function CustomLLMProviderUpdateForm({
  onClose,
  existingLlmProvider,
  shouldMarkAsDefault,
  setPopup,
}: {
  onClose: () => void;
  existingLlmProvider?: FullLLMProvider;
  shouldMarkAsDefault?: boolean;
  setPopup?: (popup: PopupSpec) => void;
}) {
  const { mutate } = useSWRConfig();

  const [isTesting, setIsTesting] = useState(false);
  const [testError, setTestError] = useState<string>("");
  const [isTestSuccessful, setTestSuccessful] = useState(
    existingLlmProvider ? true : false
  );

  // Define the initial values based on the provider's requirements
  const initialValues = {
    name: existingLlmProvider?.name ?? "",
    api_key: existingLlmProvider?.api_key ?? "",
    api_base: existingLlmProvider?.api_base ?? "",
    api_version: existingLlmProvider?.api_version ?? "",
    default_model_name: existingLlmProvider?.default_model_name ?? null,
    default_fast_model_name:
      existingLlmProvider?.fast_default_model_name ?? null,
    model_names: existingLlmProvider?.model_names ?? [],
    custom_config_list: existingLlmProvider?.custom_config
      ? Object.entries(existingLlmProvider.custom_config)
      : [],
  };

  const [validatedConfig, setValidatedConfig] = useState(
    existingLlmProvider ? initialValues : null
  );

  // Setup validation schema if required
  const validationSchema = Yup.object({
    name: Yup.string().required("Name is required"),
    api_key: Yup.string(),
    api_base: Yup.string(),
    api_version: Yup.string(),
    model_names: Yup.array(Yup.string().required("Model name is required")),
    default_model_name: Yup.string().required("Model name is required"),
    default_fast_model_name: Yup.string().nullable(),
    custom_config_list: Yup.array(),
  });

  return (
    <Formik
      initialValues={initialValues}
      validationSchema={validationSchema}
      // hijack this to re-enable testing on any change
      validate={(values) => {
        if (!isEqual(values, validatedConfig)) {
          setTestSuccessful(false);
        }
      }}
      onSubmit={async (values, { setSubmitting }) => {
        setSubmitting(true);

        if (!isTestSuccessful) {
          setSubmitting(false);
          return;
        }

        if (values.model_names.length === 0) {
          const fullErrorMsg = "At least one model name is required";
          if (setPopup) {
            setPopup({
              type: "error",
              message: fullErrorMsg,
            });
          } else {
            alert(fullErrorMsg);
          }
          setSubmitting(false);
          return;
        }

        const response = await fetch(LLM_PROVIDERS_ADMIN_URL, {
          method: "PUT",
          headers: {
            "Content-Type": "application/json",
          },
          body: JSON.stringify({
            ...values,
            custom_config: customConfigProcessing(values.custom_config_list),
          }),
        });

        if (!response.ok) {
          const errorMsg = (await response.json()).detail;
          const fullErrorMsg = existingLlmProvider
            ? `Failed to update provider: ${errorMsg}`
            : `Failed to enable provider: ${errorMsg}`;
          if (setPopup) {
            setPopup({
              type: "error",
              message: fullErrorMsg,
            });
          } else {
            alert(fullErrorMsg);
          }
          return;
        }

        if (shouldMarkAsDefault) {
          const newLlmProvider = (await response.json()) as FullLLMProvider;
          const setDefaultResponse = await fetch(
            `${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
            {
              method: "POST",
            }
          );
          if (!setDefaultResponse.ok) {
            const errorMsg = (await setDefaultResponse.json()).detail;
            const fullErrorMsg = `Failed to set provider as default: ${errorMsg}`;
            if (setPopup) {
              setPopup({
                type: "error",
                message: fullErrorMsg,
              });
            } else {
              alert(fullErrorMsg);
            }
            return;
          }
        }

        mutate(LLM_PROVIDERS_ADMIN_URL);
        onClose();

        const successMsg = existingLlmProvider
          ? "Provider updated successfully!"
          : "Provider enabled successfully!";
        if (setPopup) {
          setPopup({
            type: "success",
            message: successMsg,
          });
        } else {
          alert(successMsg);
        }

        setSubmitting(false);
      }}
    >
      {({ values }) => (
        <Form>
          <TextFormField
            name="name"
            label="Provider Name"
            subtext={
              <>
                Should be one of the providers listed at{" "}
                <a
                  target="_blank"
                  href="https://docs.litellm.ai/docs/providers"
                  className="text-link"
                >
                  https://docs.litellm.ai/docs/providers
                </a>
                .
              </>
            }
            placeholder="Name of the custom provider"
          />

          <Divider />

          <SubLabel>
            Fill in the following as needed. Refer to the LiteLLM
            documentation for the model provider name specified above in order
            to determine which fields are required.
          </SubLabel>

          <TextFormField
            name="api_key"
            label="[Optional] API Key"
            placeholder="API Key"
            type="password"
          />

          <TextFormField
            name="api_base"
            label="[Optional] API Base"
            placeholder="API Base"
          />

          <TextFormField
            name="api_version"
            label="[Optional] API Version"
            placeholder="API Version"
          />

          <Label>[Optional] Custom Configs</Label>
          <SubLabel>
            <>
              <div>
                Additional configurations needed by the model provider. These
                are passed to litellm via environment variables.
              </div>

              <div className="mt-2">
                For example, when configuring the Cloudflare provider, you would
                need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your
                Cloudflare account ID as the value.
              </div>
            </>
          </SubLabel>

          <FieldArray
            name="custom_config_list"
            render={(arrayHelpers: ArrayHelpers<any[]>) => (
              <div>
                {values.custom_config_list.map((_, index) => {
                  return (
                    <div key={index} className={index === 0 ? "mt-2" : "mt-6"}>
                      <div className="flex">
                        <div className="w-full mr-6 border border-border p-3 rounded">
                          <div>
                            <Label>Key</Label>
                            <Field
                              name={`custom_config_list[${index}][0]`}
                              className={`
                                border
                                border-border
                                bg-background
                                rounded
                                w-full
                                py-2
                                px-3
                                mr-4
                              `}
                              autoComplete="off"
                            />
                            <ErrorMessage
                              name={`custom_config_list[${index}][0]`}
                              component="div"
                              className="text-error text-sm mt-1"
                            />
                          </div>

                          <div className="mt-3">
                            <Label>Value</Label>
                            <Field
                              name={`custom_config_list[${index}][1]`}
                              className={`
                                border
                                border-border
                                bg-background
                                rounded
                                w-full
                                py-2
                                px-3
                                mr-4
                              `}
                              autoComplete="off"
                            />
                            <ErrorMessage
                              name={`custom_config_list[${index}][1]`}
                              component="div"
                              className="text-error text-sm mt-1"
                            />
                          </div>
                        </div>
                        <div className="my-auto">
                          <FiX
                            className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
                            onClick={() => arrayHelpers.remove(index)}
                          />
                        </div>
                      </div>
                    </div>
                  );
                })}

                <Button
                  onClick={() => {
                    arrayHelpers.push(["", ""]);
                  }}
                  className="mt-3"
                  color="green"
                  size="xs"
                  type="button"
                  icon={FiPlus}
                >
                  Add New
                </Button>
              </div>
            )}
          />

          <Divider />

          <TextArrayField
            name="model_names"
            label="Model Names"
            values={values}
            subtext={`List the individual models that you want to make
            available as a part of this provider. At least one must be specified.
            As an example, for OpenAI one model might be "gpt-4".`}
          />

          <Divider />

          <TextFormField
            name="default_model_name"
            subtext={`
              The model to use by default for this provider unless
              otherwise specified. Must be one of the models listed
              above.`}
            label="Default Model"
            placeholder="E.g. gpt-4"
          />

          <TextFormField
            name="default_fast_model_name"
            subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
              for this provider. If not set, will use
              the Default Model configured above.`}
            label="[Optional] Fast Model"
            placeholder="E.g. gpt-4"
          />

          <Divider />

          <div>
            {/* NOTE: this is above the test button to make sure it's visible */}
            {!isTestSuccessful && testError && (
              <Text className="text-error mt-2">{testError}</Text>
            )}
            {isTestSuccessful && (
              <Text className="text-success mt-2">
                Test successful! LLM provider is ready to go.
              </Text>
            )}

            <div className="flex w-full mt-4">
              {isTestSuccessful ? (
                <Button type="submit" size="xs">
                  {existingLlmProvider ? "Update" : "Enable"}
                </Button>
              ) : (
                <Button
                  type="button"
                  size="xs"
                  disabled={isTesting}
                  onClick={async () => {
                    setIsTesting(true);

                    const response = await fetch("/api/admin/llm/test", {
                      method: "POST",
                      headers: {
                        "Content-Type": "application/json",
                      },
                      body: JSON.stringify({
                        provider: values.name,
                        custom_config: customConfigProcessing(
                          values.custom_config_list
                        ),
                        ...values,
                      }),
                    });
                    setIsTesting(false);

                    if (!response.ok) {
                      const errorMsg = (await response.json()).detail;
                      setTestError(errorMsg);
                      return;
                    }

                    setTestSuccessful(true);
                    setValidatedConfig(values);
                  }}
                >
                  {isTesting ? <LoadingAnimation text="Testing" /> : "Test"}
                </Button>
              )}
              {existingLlmProvider && (
                <Button
                  type="button"
                  color="red"
                  className="ml-3"
                  size="xs"
                  icon={FiTrash}
                  onClick={async () => {
                    const response = await fetch(
                      `${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
                      {
                        method: "DELETE",
                      }
                    );
                    if (!response.ok) {
                      const errorMsg = (await response.json()).detail;
                      alert(`Failed to delete provider: ${errorMsg}`);
                      return;
                    }

                    mutate(LLM_PROVIDERS_ADMIN_URL);
                    onClose();
                  }}
                >
                  Delete
                </Button>
              )}
            </div>
          </div>
        </Form>
      )}
    </Formik>
  );
}
234 web/src/app/admin/models/llm/LLMConfiguration.tsx Normal file
@ -0,0 +1,234 @@
"use client";
|
||||
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { useState } from "react";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import { Badge, Button, Text, Title } from "@tremor/react";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
|
||||
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm";
|
||||
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
|
||||
import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm";
|
||||
|
||||
function LLMProviderUpdateModal({
|
||||
llmProviderDescriptor,
|
||||
onClose,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
setPopup,
|
||||
}: {
|
||||
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null;
|
||||
onClose: () => void;
|
||||
existingLlmProvider?: FullLLMProvider;
|
||||
shouldMarkAsDefault?: boolean;
|
||||
setPopup?: (popup: PopupSpec) => void;
|
||||
}) {
|
||||
const providerName =
|
||||
llmProviderDescriptor?.display_name ||
|
||||
llmProviderDescriptor?.name ||
|
||||
existingLlmProvider?.name ||
|
||||
"Custom LLM Provider";
|
||||
return (
|
||||
<Modal title={`Setup ${providerName}`} onOutsideClick={() => onClose()}>
|
||||
<div className="max-h-[70vh] overflow-y-auto px-4">
|
||||
{llmProviderDescriptor ? (
|
||||
<LLMProviderUpdateForm
|
||||
llmProviderDescriptor={llmProviderDescriptor}
|
||||
onClose={onClose}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
shouldMarkAsDefault={shouldMarkAsDefault}
|
||||
setPopup={setPopup}
|
||||
/>
|
||||
) : (
|
||||
<CustomLLMProviderUpdateForm
|
||||
onClose={onClose}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
shouldMarkAsDefault={shouldMarkAsDefault}
|
||||
setPopup={setPopup}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
|
||||
function LLMProviderDisplay({
|
||||
llmProviderDescriptor,
|
||||
existingLlmProvider,
|
||||
shouldMarkAsDefault,
|
||||
}: {
|
||||
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null;
|
||||
existingLlmProvider?: FullLLMProvider;
|
||||
shouldMarkAsDefault?: boolean;
|
||||
}) {
|
||||
const [formIsVisible, setFormIsVisible] = useState(false);
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
const providerName =
|
||||
llmProviderDescriptor?.display_name ||
|
||||
llmProviderDescriptor?.name ||
|
||||
existingLlmProvider?.name;
|
||||
return (
|
||||
<div>
|
||||
{popup}
|
||||
<div className="border border-border p-3 rounded w-96 flex shadow-md">
|
||||
<div className="my-auto">
|
||||
<div className="font-bold">{providerName} </div>
|
||||
{existingLlmProvider && !existingLlmProvider.is_default_provider && (
|
||||
<div
|
||||
className="text-xs text-link cursor-pointer"
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}/default`,
|
||||
{
|
||||
method: "POST",
|
||||
}
|
||||
);
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
setPopup({
|
||||
type: "error",
|
||||
message: `Failed to set provider as default: ${errorMsg}`,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
setPopup({
|
||||
type: "success",
|
||||
message: "Provider set as default successfully!",
|
||||
});
|
||||
}}
|
||||
>
|
||||
Set as default
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{existingLlmProvider && (
|
||||
<div className="my-auto">
|
||||
{existingLlmProvider.is_default_provider ? (
|
||||
<Badge color="orange" className="ml-2" size="xs">
|
||||
Default
|
||||
</Badge>
|
||||
) : (
|
||||
<Badge color="green" className="ml-2" size="xs">
|
||||
Enabled
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="ml-auto">
|
||||
<Button
|
||||
color={existingLlmProvider ? "green" : "blue"}
|
||||
size="xs"
|
||||
onClick={() => setFormIsVisible(true)}
|
||||
>
|
||||
{existingLlmProvider ? "Edit" : "Set up"}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
{formIsVisible && (
|
||||
<LLMProviderUpdateModal
|
||||
llmProviderDescriptor={llmProviderDescriptor}
|
||||
onClose={() => setFormIsVisible(false)}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
shouldMarkAsDefault={shouldMarkAsDefault}
|
||||
setPopup={setPopup}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function AddCustomLLMProvider({}) {
|
||||
const [formIsVisible, setFormIsVisible] = useState(false);
|
||||
|
||||
if (formIsVisible) {
|
||||
return (
|
||||
<Modal
|
||||
title={`Setup Custom LLM Provider`}
|
||||
onOutsideClick={() => setFormIsVisible(false)}
|
||||
>
|
||||
<div className="max-h-[70vh] overflow-y-auto px-4">
|
||||
<CustomLLMProviderUpdateForm
|
||||
onClose={() => setFormIsVisible(false)}
|
||||
/>
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Button size="xs" onClick={() => setFormIsVisible(true)}>
|
||||
Add Custom LLM Provider
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
|
||||
export function LLMConfiguration() {
|
||||
const { data: llmProviderDescriptors } = useSWR<
|
||||
WellKnownLLMProviderDescriptor[]
|
||||
>("/api/admin/llm/built-in/options", errorHandlingFetcher);
|
||||
const { data: existingLlmProviders } = useSWR<FullLLMProvider[]>(
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
if (!llmProviderDescriptors || !existingLlmProviders) {
|
||||
return <ThreeDotsLoader />;
|
||||
}
|
||||
|
||||
const wellKnownLLMProviderNames = llmProviderDescriptors.map(
|
||||
(llmProviderDescriptor) => llmProviderDescriptor.name
|
||||
);
|
||||
const customLLMProviders = existingLlmProviders.filter(
|
||||
(llmProvider) => !wellKnownLLMProviderNames.includes(llmProvider.name)
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Text className="mb-4">
|
||||
If multiple LLM providers are enabled, the default provider will be used
|
||||
for all "Default" Personas. For user-created Personas, you can
|
||||
select the LLM provider/model that best fits the use case!
|
||||
</Text>
|
||||
|
||||
<Title className="mb-2">Default Providers</Title>
|
||||
<div className="gap-y-4 flex flex-col">
|
||||
{llmProviderDescriptors.map((llmProviderDescriptor) => {
|
||||
const existingLlmProvider = existingLlmProviders.find(
|
||||
(llmProvider) => llmProvider.name === llmProviderDescriptor.name
|
||||
);
|
||||
|
||||
return (
|
||||
<LLMProviderDisplay
|
||||
key={llmProviderDescriptor.name}
|
||||
llmProviderDescriptor={llmProviderDescriptor}
|
||||
existingLlmProvider={existingLlmProvider}
|
||||
shouldMarkAsDefault={existingLlmProviders.length === 0}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
<Title className="mb-2 mt-4">Custom Providers</Title>
|
||||
{customLLMProviders.length > 0 && (
|
||||
<div className="gap-y-4 flex flex-col mb-4">
|
||||
{customLLMProviders.map((llmProvider) => (
|
||||
<LLMProviderDisplay
|
||||
key={llmProvider.id}
|
||||
llmProviderDescriptor={null}
|
||||
existingLlmProvider={llmProvider}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<AddCustomLLMProvider />
|
||||
</>
|
||||
);
|
||||
}
|
347 web/src/app/admin/models/llm/LLMProviderUpdateForm.tsx Normal file
@ -0,0 +1,347 @@
import { LoadingAnimation } from "@/components/Loading";
import { Button, Divider, Text } from "@tremor/react";
import { Form, Formik } from "formik";
import { FiTrash } from "react-icons/fi";
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import {
  SelectorFormField,
  TextFormField,
} from "@/components/admin/connectors/Field";
import { useState } from "react";
import { useSWRConfig } from "swr";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";

export function LLMProviderUpdateForm({
  llmProviderDescriptor,
  onClose,
  existingLlmProvider,
  shouldMarkAsDefault,
  setPopup,
}: {
  llmProviderDescriptor: WellKnownLLMProviderDescriptor;
  onClose: () => void;
  existingLlmProvider?: FullLLMProvider;
  shouldMarkAsDefault?: boolean;
  setPopup?: (popup: PopupSpec) => void;
}) {
  const { mutate } = useSWRConfig();

  const [isTesting, setIsTesting] = useState(false);
  const [testError, setTestError] = useState<string>("");
  const [isTestSuccessful, setTestSuccessful] = useState(
    existingLlmProvider ? true : false
  );

  // Define the initial values based on the provider's requirements
  const initialValues = {
    api_key: existingLlmProvider?.api_key ?? "",
    api_base: existingLlmProvider?.api_base ?? "",
    api_version: existingLlmProvider?.api_version ?? "",
    default_model_name:
      existingLlmProvider?.default_model_name ??
      (llmProviderDescriptor.default_model ||
        llmProviderDescriptor.llm_names[0]),
    default_fast_model_name:
      existingLlmProvider?.fast_default_model_name ??
      (llmProviderDescriptor.default_fast_model || null),
    custom_config:
      existingLlmProvider?.custom_config ??
      llmProviderDescriptor.custom_config_keys?.reduce(
        (acc, key) => {
          acc[key] = "";
          return acc;
        },
        {} as { [key: string]: string }
      ),
  };

  const [validatedConfig, setValidatedConfig] = useState(
    existingLlmProvider ? initialValues : null
  );

  // Setup validation schema if required
  const validationSchema = Yup.object({
    api_key: llmProviderDescriptor.api_key_required
      ? Yup.string().required("API Key is required")
      : Yup.string(),
    api_base: llmProviderDescriptor.api_base_required
      ? Yup.string().required("API Base is required")
      : Yup.string(),
    api_version: llmProviderDescriptor.api_version_required
      ? Yup.string().required("API Version is required")
      : Yup.string(),
    ...(llmProviderDescriptor.custom_config_keys
      ? {
          custom_config: Yup.object(
            llmProviderDescriptor.custom_config_keys.reduce(
              (acc, key) => {
                acc[key] = Yup.string().required(`${key} is required`);
                return acc;
              },
              {} as { [key: string]: Yup.StringSchema }
            )
          ),
        }
      : {}),
    default_model_name: Yup.string().required("Model name is required"),
    default_fast_model_name: Yup.string().nullable(),
  });

  return (
    <Formik
      initialValues={initialValues}
      validationSchema={validationSchema}
      // hijack this to re-enable testing on any change
      validate={(values) => {
        if (!isEqual(values, validatedConfig)) {
          setTestSuccessful(false);
        }
      }}
      onSubmit={async (values, { setSubmitting }) => {
        setSubmitting(true);

        if (!isTestSuccessful) {
          setSubmitting(false);
          return;
        }

        const response = await fetch(LLM_PROVIDERS_ADMIN_URL, {
          method: "PUT",
          headers: {
            "Content-Type": "application/json",
          },
          body: JSON.stringify({
            name: llmProviderDescriptor.name,
            ...values,
            fast_default_model_name:
              values.default_fast_model_name || values.default_model_name,
          }),
        });

        if (!response.ok) {
          const errorMsg = (await response.json()).detail;
          const fullErrorMsg = existingLlmProvider
            ? `Failed to update provider: ${errorMsg}`
            : `Failed to enable provider: ${errorMsg}`;
          if (setPopup) {
            setPopup({
              type: "error",
              message: fullErrorMsg,
            });
          } else {
            alert(fullErrorMsg);
          }
          return;
        }

        if (shouldMarkAsDefault) {
          const newLlmProvider = (await response.json()) as FullLLMProvider;
          const setDefaultResponse = await fetch(
            `${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
            {
              method: "POST",
            }
          );
          if (!setDefaultResponse.ok) {
            const errorMsg = (await setDefaultResponse.json()).detail;
            const fullErrorMsg = `Failed to set provider as default: ${errorMsg}`;
            if (setPopup) {
              setPopup({
                type: "error",
                message: fullErrorMsg,
              });
            } else {
              alert(fullErrorMsg);
            }
            return;
          }
        }

        mutate(LLM_PROVIDERS_ADMIN_URL);
        onClose();

        const successMsg = existingLlmProvider
          ? "Provider updated successfully!"
          : "Provider enabled successfully!";
        if (setPopup) {
          setPopup({
            type: "success",
            message: successMsg,
          });
        } else {
          alert(successMsg);
        }

        setSubmitting(false);
      }}
    >
      {({ values }) => (
        <Form>
          {llmProviderDescriptor.api_key_required && (
            <TextFormField
              name="api_key"
              label="API Key"
              placeholder="API Key"
              type="password"
            />
          )}

          {llmProviderDescriptor.api_base_required && (
            <TextFormField
              name="api_base"
              label="API Base"
              placeholder="API Base"
            />
          )}

          {llmProviderDescriptor.api_version_required && (
            <TextFormField
              name="api_version"
              label="API Version"
              placeholder="API Version"
            />
          )}

          {llmProviderDescriptor.custom_config_keys?.map((key) => (
            <div key={key}>
              <TextFormField name={`custom_config.${key}`} label={key} />
            </div>
          ))}

          <Divider />

          {llmProviderDescriptor.llm_names.length > 0 ? (
            <SelectorFormField
              name="default_model_name"
              subtext="The model to use by default for this provider unless otherwise specified."
              label="Default Model"
              options={llmProviderDescriptor.llm_names.map((name) => ({
                name,
                value: name,
              }))}
              direction="up"
              maxHeight="max-h-56"
            />
          ) : (
            <TextFormField
              name="default_model_name"
              subtext="The model to use by default for this provider unless otherwise specified."
              label="Default Model"
              placeholder="E.g. gpt-4"
            />
          )}

          {llmProviderDescriptor.llm_names.length > 0 ? (
            <SelectorFormField
              name="default_fast_model_name"
              subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
                for this provider. If \`Default\` is specified, will use
                the Default Model configured above.`}
              label="[Optional] Fast Model"
              options={llmProviderDescriptor.llm_names.map((name) => ({
                name,
                value: name,
              }))}
              includeDefault
              direction="up"
              maxHeight="max-h-56"
            />
          ) : (
            <TextFormField
              name="default_fast_model_name"
              subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
                for this provider. If \`Default\` is specified, will use
                the Default Model configured above.`}
              label="[Optional] Fast Model"
              placeholder="E.g. gpt-4"
            />
          )}

          <Divider />

          <div>
            {/* NOTE: this is above the test button to make sure it's visible */}
            {!isTestSuccessful && testError && (
              <Text className="text-error mt-2">{testError}</Text>
            )}
            {isTestSuccessful && (
              <Text className="text-success mt-2">
                Test successful! LLM provider is ready to go.
              </Text>
            )}

            <div className="flex w-full mt-4">
              {isTestSuccessful ? (
                <Button type="submit" size="xs">
                  {existingLlmProvider ? "Update" : "Enable"}
                </Button>
              ) : (
                <Button
                  type="button"
                  size="xs"
                  disabled={isTesting}
                  onClick={async () => {
                    setIsTesting(true);

                    const response = await fetch("/api/admin/llm/test", {
                      method: "POST",
                      headers: {
                        "Content-Type": "application/json",
                      },
                      body: JSON.stringify({
                        provider: llmProviderDescriptor.name,
                        ...values,
                      }),
                    });
                    setIsTesting(false);

                    if (!response.ok) {
                      const errorMsg = (await response.json()).detail;
                      setTestError(errorMsg);
                      return;
                    }

                    setTestSuccessful(true);
                    setValidatedConfig(values);
                  }}
                >
                  {isTesting ? <LoadingAnimation text="Testing" /> : "Test"}
                </Button>
              )}
              {existingLlmProvider && (
                <Button
                  type="button"
                  color="red"
                  className="ml-3"
                  size="xs"
                  icon={FiTrash}
                  onClick={async () => {
                    const response = await fetch(
                      `${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
                      {
                        method: "DELETE",
                      }
                    );
                    if (!response.ok) {
                      const errorMsg = (await response.json()).detail;
                      alert(`Failed to delete provider: ${errorMsg}`);
                      return;
                    }

                    mutate(LLM_PROVIDERS_ADMIN_URL);
                    onClose();
                  }}
                >
                  Delete
                </Button>
              )}
            </div>
          </div>
        </Form>
      )}
    </Formik>
  );
}
1 web/src/app/admin/models/llm/constants.ts Normal file
@ -0,0 +1 @@
export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
29 web/src/app/admin/models/llm/interfaces.ts Normal file
@ -0,0 +1,29 @@
export interface WellKnownLLMProviderDescriptor {
|
||||
name: string;
|
||||
display_name: string | null;
|
||||
|
||||
api_key_required: boolean;
|
||||
api_base_required: boolean;
|
||||
api_version_required: boolean;
|
||||
custom_config_keys: string[] | null;
|
||||
|
||||
llm_names: string[];
|
||||
default_model: string | null;
|
||||
default_fast_model: string | null;
|
||||
}
|
||||
|
||||
export interface LLMProvider {
|
||||
name: string;
|
||||
api_key: string | null;
|
||||
api_base: string | null;
|
||||
api_version: string | null;
|
||||
custom_config: { [key: string]: string } | null;
|
||||
default_model_name: string;
|
||||
fast_default_model_name: string | null;
|
||||
}
|
||||
|
||||
export interface FullLLMProvider extends LLMProvider {
|
||||
id: number;
|
||||
is_default_provider: boolean | null;
|
||||
model_names: string[];
|
||||
}
|
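Note: for reference, a FullLLMProvider returned by the backend would be shaped roughly like the following (values are illustrative, not taken from a real deployment):

  import { FullLLMProvider } from "@/app/admin/models/llm/interfaces";

  // Illustrative only; field values are made up.
  const exampleProvider: FullLLMProvider = {
    id: 1,
    name: "openai",
    api_key: "sk-...", // nullable for providers that need no key
    api_base: null,
    api_version: null,
    custom_config: null,
    default_model_name: "gpt-4",
    fast_default_model_name: "gpt-3.5-turbo",
    is_default_provider: true,
    model_names: ["gpt-4", "gpt-3.5-turbo"],
  };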
@ -2,7 +2,6 @@

import { Form, Formik } from "formik";
import { useEffect, useState } from "react";
import { LoadingAnimation } from "@/components/Loading";
import { AdminPageTitle } from "@/components/admin/Title";
import {
  BooleanFormField,
@ -10,52 +9,9 @@ import {
  TextFormField,
} from "@/components/admin/connectors/Field";
import { Popup } from "@/components/admin/connectors/Popup";
import { TrashIcon } from "@/components/icons/icons";
import { ApiKeyForm } from "@/components/openai/ApiKeyForm";
import { GEN_AI_API_KEY_URL } from "@/components/openai/constants";
import { fetcher } from "@/lib/fetcher";
import { Button, Divider, Text, Title } from "@tremor/react";
import { Button, Divider, Text } from "@tremor/react";
import { FiCpu } from "react-icons/fi";
import useSWR, { mutate } from "swr";

const ExistingKeys = () => {
  const { data, isLoading, error } = useSWR<{ api_key: string }>(
    GEN_AI_API_KEY_URL,
    fetcher
  );

  if (isLoading) {
    return <LoadingAnimation text="Loading" />;
  }

  if (error) {
    return <div className="text-error">Error loading existing keys</div>;
  }

  if (!data?.api_key) {
    return null;
  }

  return (
    <div>
      <Title className="mb-2">Existing Key</Title>
      <div className="flex mb-1">
        <p className="text-sm italic my-auto">sk- ****...**{data?.api_key}</p>
        <button
          className="ml-1 my-auto hover:bg-hover rounded p-1"
          onClick={async () => {
            await fetch(GEN_AI_API_KEY_URL, {
              method: "DELETE",
            });
            window.location.reload();
          }}
        >
          <TrashIcon />
        </button>
      </div>
    </div>
  );
};
import { LLMConfiguration } from "./LLMConfiguration";

const LLMOptions = () => {
  const [popup, setPopup] = useState<{
@ -219,27 +175,12 @@ const Page = () => {
  return (
    <div className="mx-auto container">
      <AdminPageTitle
        title="LLM Options"
        title="LLM Setup"
        icon={<FiCpu size={32} className="my-auto" />}
      />

      <SectionHeader>LLM Keys</SectionHeader>
      <LLMConfiguration />

      <ExistingKeys />

      <Title className="mb-2 mt-6">Update Key</Title>
      <Text className="mb-2">
        Specify an OpenAI API key and click the "Submit" button.
      </Text>
      <div className="border rounded-md border-border p-3">
        <ApiKeyForm
          handleResponse={(response) => {
            if (response.ok) {
              mutate(GEN_AI_API_KEY_URL);
            }
          }}
        />
      </div>
      <LLMOptions />
    </div>
  );
@ -20,7 +20,7 @@ import {
  WelcomeModal,
  hasCompletedWelcomeFlowSS,
} from "@/components/initialSetup/welcome/WelcomeModalWrapper";
import { ApiKeyModal } from "@/components/openai/ApiKeyModal";
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
import { cookies } from "next/headers";
import { DOCUMENT_SIDEBAR_WIDTH_COOKIE_NAME } from "@/components/resizable/contants";
import { personaComparator } from "../admin/assistants/lib";
@ -169,9 +169,9 @@ export default async function Page({
  <>
    <InstantSSRAutoRefresh />

    {shouldShowWelcomeModal && <WelcomeModal />}
    {shouldShowWelcomeModal && <WelcomeModal user={user} />}
    {!shouldShowWelcomeModal && !shouldDisplaySourcesIncompleteModal && (
      <ApiKeyModal />
      <ApiKeyModal user={user} />
    )}
    {shouldDisplaySourcesIncompleteModal && (
      <NoCompleteSourcesModal ccPairs={ccPairs} />
@ -7,7 +7,7 @@ import {
} from "@/lib/userSS";
import { redirect } from "next/navigation";
import { HealthCheckBanner } from "@/components/health/healthcheck";
import { ApiKeyModal } from "@/components/openai/ApiKeyModal";
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
import { fetchSS } from "@/lib/utilsSS";
import { CCPairBasicInfo, DocumentSet, Tag, User } from "@/lib/types";
import { cookies } from "next/headers";
@ -147,10 +147,10 @@ export default async function Home() {
      <div className="m-3">
        <HealthCheckBanner />
      </div>
      {shouldShowWelcomeModal && <WelcomeModal />}
      {shouldShowWelcomeModal && <WelcomeModal user={user} />}
      {!shouldShowWelcomeModal &&
        !shouldDisplayNoSourcesModal &&
        !shouldDisplaySourcesIncompleteModal && <ApiKeyModal />}
        !shouldDisplaySourcesIncompleteModal && <ApiKeyModal user={user} />}
      {shouldDisplayNoSourcesModal && <NoSourcesModal />}
      {shouldDisplaySourcesIncompleteModal && (
        <NoCompleteSourcesModal ccPairs={ccPairs} />
@ -1,6 +1,7 @@
import { ChangeEvent, FC, useEffect, useRef, useState } from "react";
import { ChevronDownIcon } from "./icons/icons";
import { FiCheck, FiChevronDown } from "react-icons/fi";
import { Popover } from "./popover/Popover";

export interface Option<T> {
  name: string;
@ -181,9 +182,11 @@ export function SearchMultiSelectDropdown({
export const CustomDropdown = ({
  children,
  dropdown,
  direction = "down", // Default to 'down' if not specified
}: {
  children: JSX.Element | string;
  dropdown: JSX.Element | string;
  direction?: "up" | "down";
}) => {
  const [isOpen, setIsOpen] = useState(false);
  const dropdownRef = useRef<HTMLDivElement>(null);
@ -211,7 +214,9 @@ export const CustomDropdown = ({
  {isOpen && (
    <div
      onClick={() => setIsOpen(!isOpen)}
      className="pt-2 absolute bottom w-full z-30 box-shadow"
      className={`absolute ${
        direction === "up" ? "bottom-full pb-2" : "pt-2 "
      } w-full z-30 box-shadow`}
    >
      {dropdown}
    </div>
@ -281,16 +286,21 @@ export function DefaultDropdown({
  selected,
  onSelect,
  includeDefault = false,
  direction = "down",
  maxHeight,
}: {
  options: StringOrNumberOption[];
  selected: string | null;
  onSelect: (value: string | number | null) => void;
  includeDefault?: boolean;
  direction?: "up" | "down";
  maxHeight?: string;
}) {
  const selectedOption = options.find((option) => option.value === selected);

  return (
    <CustomDropdown
      direction={direction}
      dropdown={
        <div
          className={`
@ -300,7 +310,7 @@ export function DefaultDropdown({
            flex
            flex-col
            bg-background
            max-h-96
            ${maxHeight || "max-h-96"}
            overflow-y-auto
            overscroll-contain`}
        >
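Note: the new direction and maxHeight props let a dropdown open upward when it sits near the bottom of a form, which is how the Fast Model selector earlier in this diff uses it. A hypothetical usage sketch (options and selection handling invented for illustration):

  // Hypothetical usage of the updated DefaultDropdown.
  <DefaultDropdown
    options={[{ name: "gpt-4", value: "gpt-4" }]}
    selected={null}
    onSelect={(value) => console.log(value)}
    includeDefault
    direction="up" // render the menu above the trigger
    maxHeight="max-h-56" // cap menu height with a Tailwind class
  />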
@ -154,7 +154,7 @@ export async function Layout({ children }: { children: React.ReactNode }) {
          <div className="ml-1">LLM</div>
        </div>
      ),
      link: "/admin/keys/openai",
      link: "/admin/models/llm",
    },
    {
      name: (
@ -233,6 +233,9 @@ interface SelectorFormFieldProps {
  options: StringOrNumberOption[];
  subtext?: string | JSX.Element;
  includeDefault?: boolean;
  direction?: "up" | "down";
  maxHeight?: string;
  onSelect?: (selected: string | number | null) => void;
}

export function SelectorFormField({
@ -241,6 +244,9 @@ export function SelectorFormField({
  options,
  subtext,
  includeDefault = false,
  direction = "down",
  maxHeight,
  onSelect,
}: SelectorFormFieldProps) {
  const [field] = useField<string>(name);
  const { setFieldValue } = useFormikContext();
@ -254,8 +260,10 @@
      <DefaultDropdown
        options={options}
        selected={field.value}
        onSelect={(selected) => setFieldValue(name, selected)}
        onSelect={onSelect || ((selected) => setFieldValue(name, selected))}
        includeDefault={includeDefault}
        direction={direction}
        maxHeight={maxHeight}
      />
    </div>
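Note: the optional onSelect prop now lets callers bypass the default Formik setFieldValue wiring, for example to clear a dependent field when the selection changes. A hedged sketch (field names invented, assuming a surrounding Formik context):

  // Hypothetical: reset a dependent field whenever the provider changes.
  <SelectorFormField
    name="provider"
    options={providerOptions}
    onSelect={(selected) => {
      setFieldValue("provider", selected);
      setFieldValue("default_model_name", null);
    }}
    direction="up"
    maxHeight="max-h-56"
  />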
@ -37,7 +37,7 @@ export function NoCompleteSourcesModal({
      title="⏳ None of your connectors have finished a full sync yet"
      onOutsideClick={() => setIsHidden(true)}
    >
      <div className="text-base">
      <div className="text-sm">
        <div>
          <div>
            You've connected some sources, but none of them have finished
@ -1,6 +1,6 @@
"use client";

import { Button, Divider } from "@tremor/react";
import { Button, Divider, Text } from "@tremor/react";
import { Modal } from "../../Modal";
import Link from "next/link";
import { FiMessageSquare, FiShare2 } from "react-icons/fi";
@ -21,11 +21,11 @@ export function NoSourcesModal() {
    >
      <div className="text-base">
        <div>
          <p>
          <Text>
            Before using Search you'll need to connect at least one source.
            Without any connected knowledge sources, there isn't anything
            to search over.
          </p>
          </Text>
          <Link href="/admin/add-connector">
            <Button className="mt-3" size="xs" icon={FiShare2}>
              Connect a Source!
@ -33,11 +33,11 @@ export function NoSourcesModal() {
          </Link>
          <Divider />
          <div>
            <p>
            <Text>
              Or, if you're looking for a pure ChatGPT-like experience
              without any organization specific knowledge, then you can head
              over to the Chat page and start chatting with Danswer right away!
            </p>
            </Text>
            <Link href="/chat">
              <Button className="mt-3" size="xs" icon={FiMessageSquare}>
                Start Chatting!
@ -9,8 +9,10 @@ import { COMPLETED_WELCOME_FLOW_COOKIE } from "./constants";
import { FiCheckCircle, FiMessageSquare, FiShare2 } from "react-icons/fi";
import { useEffect, useState } from "react";
import { BackButton } from "@/components/BackButton";
import { ApiKeyForm } from "@/components/openai/ApiKeyForm";
import { checkApiKey } from "@/components/openai/ApiKeyModal";
import { ApiKeyForm } from "@/components/llm/ApiKeyForm";
import { WellKnownLLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";
import { checkLlmProvider } from "./lib";
import { User } from "@/lib/types";

function setWelcomeFlowComplete() {
  Cookies.set(COMPLETED_WELCOME_FLOW_COOKIE, "true", { expires: 365 });
@ -52,20 +54,26 @@ function UsageTypeSection({
  );
}

export function _WelcomeModal() {
export function _WelcomeModal({ user }: { user: User | null }) {
  const router = useRouter();
  const [selectedFlow, setSelectedFlow] = useState<null | "search" | "chat">(
    null
  );
  const [isHidden, setIsHidden] = useState(false);
  const [apiKeyVerified, setApiKeyVerified] = useState<boolean>(false);
  const [providerOptions, setProviderOptions] = useState<
    WellKnownLLMProviderDescriptor[]
  >([]);

  useEffect(() => {
    checkApiKey().then((error) => {
      if (!error) {
        setApiKeyVerified(true);
      }
    });
    async function fetchProviderInfo() {
      const { providers, options, defaultCheckSuccessful } =
        await checkLlmProvider(user);
      setApiKeyVerified(providers.length > 0 && defaultCheckSuccessful);
      setProviderOptions(options);
    }

    fetchProviderInfo();
  }, []);

  if (isHidden) {
@ -78,30 +86,27 @@ export function _WelcomeModal() {
    case "search":
      title = undefined;
      body = (
        <>
        <div className="max-h-[85vh] overflow-y-auto px-4 pb-4">
          <BackButton behaviorOverride={() => setSelectedFlow(null)} />
          <div className="mt-3">
            <Text className="font-bold mt-6 mb-2 flex">
            <Text className="font-bold flex">
              {apiKeyVerified && (
                <FiCheckCircle className="my-auto mr-2 text-success" />
              )}
              Step 1: Provide OpenAI API Key
              Step 1: Setup an LLM
            </Text>
            <div>
              {apiKeyVerified ? (
                <div>
                  API Key setup complete!
                <Text className="mt-2">
                  LLM setup complete!
                  <br /> <br />
                  If you want to change the key later, you'll be able to
                  easily to do so in the Admin Panel.
                </div>
                </Text>
              ) : (
                <ApiKeyForm
                  handleResponse={async (response) => {
                    if (response.ok) {
                      setApiKeyVerified(true);
                    }
                  }}
                  onSuccess={() => setApiKeyVerified(true)}
                  providerOptions={providerOptions}
                />
              )}
            </div>
@ -109,12 +114,12 @@
            Step 2: Connect Data Sources
          </Text>
          <div>
            <p>
            <Text>
              Connectors are the way that Danswer gets data from your
              organization's various data sources. Once setup, we'll
              automatically sync data from your apps and docs into Danswer, so
              you can search through all of them in one place.
            </p>
            </Text>

            <div className="flex mt-3">
              <Link
@ -133,59 +138,37 @@ export function _WelcomeModal() {
              </div>
            </div>
          </div>
        </>
        </div>
      );
      break;
    case "chat":
      title = undefined;
      body = (
        <>
        <div className="mt-3 max-h-[85vh] overflow-y-auto px-4 pb-4">
          <BackButton behaviorOverride={() => setSelectedFlow(null)} />

          <div className="mt-3">
            <div>
              To start using Danswer as a secure ChatGPT, we just need to
              configure our LLM!
              <br />
              <br />
              Danswer supports connections with a wide range of LLMs, including
              self-hosted open-source LLMs. For more details, check out the{" "}
              <a
                className="text-link"
                href="https://docs.danswer.dev/gen_ai_configs/overview"
              >
                documentation
              </a>
              .
              <br />
              <br />
              If you haven't done anything special with the Gen AI configs,
              then we default to use OpenAI.
            </div>

            <Text className="font-bold mt-6 mb-2 flex">
            <Text className="font-bold flex">
              {apiKeyVerified && (
                <FiCheckCircle className="my-auto mr-2 text-success" />
              )}
              Step 1: Provide LLM API Key
              Step 1: Setup an LLM
            </Text>
            <div>
              {apiKeyVerified ? (
                <div>
                <Text className="mt-2">
                  LLM setup complete!
                  <br /> <br />
                  If you want to change the key later or choose a different LLM,
                  you'll be able to easily to do so in the Admin Panel / by
                  changing some environment variables.
                </div>
                  you'll be able to easily to do so in the Admin Panel.
                </Text>
              ) : (
                <ApiKeyForm
                  handleResponse={async (response) => {
                    if (response.ok) {
                      setApiKeyVerified(true);
                    }
                  }}
                />
                <div>
                  <ApiKeyForm
                    onSuccess={() => setApiKeyVerified(true)}
                    providerOptions={providerOptions}
                  />
                </div>
              )}
            </div>

@ -193,7 +176,7 @@ export function _WelcomeModal() {
            Step 2: Start Chatting!
          </Text>

          <div>
          <Text>
            Click the button below to start chatting with the LLM setup above!
            Don't worry, if you do decide later on you want to connect
            your organization's knowledge, you can always do that in the{" "}
@ -209,7 +192,7 @@ export function _WelcomeModal() {
            Admin Panel
          </Link>
          .
          </div>
          </Text>

          <div className="flex mt-3">
            <Link
@ -228,7 +211,7 @@ export function _WelcomeModal() {
            </Link>
          </div>
        </div>
        </>
        </div>
      );
      break;
    default:
@ -236,17 +219,17 @@ export function _WelcomeModal() {
      body = (
        <>
          <div>
            <p>How are you planning on using Danswer?</p>
            <Text>How are you planning on using Danswer?</Text>
          </div>
          <Divider />
          <UsageTypeSection
            title="Search / Chat with Knowledge"
            description={
              <div>
              <Text>
                If you're looking to search through, chat with, or ask
                direct questions of your organization's knowledge, then
                this is the option for you!
              </div>
              </Text>
            }
            callToAction="Get Started"
            onClick={() => setSelectedFlow("search")}
@ -255,10 +238,10 @@ export function _WelcomeModal() {
          <UsageTypeSection
            title="Secure ChatGPT"
            description={
              <>
              <Text>
                If you're looking for a pure ChatGPT-like experience, then
                this is the option for you!
              </>
              </Text>
            }
            icon={FiMessageSquare}
            callToAction="Get Started"
|
@ -4,6 +4,7 @@ import {
|
||||
_WelcomeModal,
|
||||
} from "./WelcomeModal";
|
||||
import { COMPLETED_WELCOME_FLOW_COOKIE } from "./constants";
|
||||
import { User } from "@/lib/types";
|
||||
|
||||
export function hasCompletedWelcomeFlowSS() {
|
||||
const cookieStore = cookies();
|
||||
@ -13,11 +14,11 @@ export function hasCompletedWelcomeFlowSS() {
|
||||
);
|
||||
}
|
||||
|
||||
export function WelcomeModal() {
|
||||
export function WelcomeModal({ user }: { user: User | null }) {
|
||||
const hasCompletedWelcomeFlow = hasCompletedWelcomeFlowSS();
|
||||
if (hasCompletedWelcomeFlow) {
|
||||
return <_CompletedWelcomeFlowDummyComponent />;
|
||||
}
|
||||
|
||||
return <_WelcomeModal />;
|
||||
return <_WelcomeModal user={user} />;
|
||||
}
|
||||
|
56 web/src/components/initialSetup/welcome/lib.ts Normal file
@ -0,0 +1,56 @@
import {
  FullLLMProvider,
  WellKnownLLMProviderDescriptor,
} from "@/app/admin/models/llm/interfaces";
import { User } from "@/lib/types";

const DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY = "defaultLlmProviderTestComplete";

function checkDefaultLLMProviderTestComplete() {
  return (
    localStorage.getItem(DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY) === "true"
  );
}

function setDefaultLLMProviderTestComplete() {
  localStorage.setItem(DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY, "true");
}

function shouldCheckDefaultLLMProvider(user: User | null) {
  return (
    !checkDefaultLLMProviderTestComplete() && (!user || user.role === "admin")
  );
}

export async function checkLlmProvider(user: User | null) {
  /* NOTE: should only be called on the client side, after initial render */
  const checkDefault = shouldCheckDefaultLLMProvider(user);

  const tasks = [
    fetch("/api/llm/provider"),
    fetch("/api/admin/llm/built-in/options"),
    checkDefault
      ? fetch("/api/admin/llm/test/default", { method: "POST" })
      : (async () => null)(),
  ];
  const [providerResponse, optionsResponse, defaultCheckResponse] =
    await Promise.all(tasks);

  let providers: FullLLMProvider[] = [];
  if (providerResponse?.ok) {
    providers = await providerResponse.json();
  }

  let options: WellKnownLLMProviderDescriptor[] = [];
  if (optionsResponse?.ok) {
    options = await optionsResponse.json();
  }

  let defaultCheckSuccessful =
    !checkDefault || defaultCheckResponse?.ok || false;
  if (defaultCheckSuccessful) {
    setDefaultLLMProviderTestComplete();
  }

  return { providers, options, defaultCheckSuccessful };
}
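Note: checkLlmProvider is the shared client-side probe used by both the welcome flow and the new ApiKeyModal below; it reports whether any provider is configured and passes its default-provider test. A minimal consumption sketch, assuming it runs in a client component after mount:

  import { checkLlmProvider } from "@/components/initialSetup/welcome/lib";
  import { User } from "@/lib/types";

  // Sketch: decide whether to surface the LLM setup modal.
  async function llmIsReady(user: User | null): Promise<boolean> {
    const { providers, options, defaultCheckSuccessful } =
      await checkLlmProvider(user);
    // `options` would feed the provider tabs; the boolean gates the modal.
    void options;
    return providers.length > 0 && defaultCheckSuccessful;
  }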
77 web/src/components/llm/ApiKeyForm.tsx Normal file
@ -0,0 +1,77 @@
import { Popup } from "../admin/connectors/Popup";
import { useState } from "react";
import { TabGroup, TabList, Tab, TabPanels, TabPanel } from "@tremor/react";
import { WellKnownLLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";
import { LLMProviderUpdateForm } from "@/app/admin/models/llm/LLMProviderUpdateForm";
import { CustomLLMProviderUpdateForm } from "@/app/admin/models/llm/CustomLLMProviderUpdateForm";

export const ApiKeyForm = ({
  onSuccess,
  providerOptions,
}: {
  onSuccess: () => void;
  providerOptions: WellKnownLLMProviderDescriptor[];
}) => {
  const [popup, setPopup] = useState<{
    message: string;
    type: "success" | "error";
  } | null>(null);

  const defaultProvider = providerOptions[0]?.name;
  const providerNameToIndexMap = new Map<string, number>();
  providerOptions.forEach((provider, index) => {
    providerNameToIndexMap.set(provider.name, index);
  });
  providerNameToIndexMap.set("custom", providerOptions.length);

  const providerIndexToNameMap = new Map<number, string>();
  Array.from(providerNameToIndexMap.keys()).forEach((key) => {
    providerIndexToNameMap.set(providerNameToIndexMap.get(key)!, key);
  });

  const [providerName, setProviderName] = useState<string>(defaultProvider);

  return (
    <div>
      {popup && <Popup message={popup.message} type={popup.type} />}
      <TabGroup
        index={providerNameToIndexMap.get(providerName) || 0}
        onIndexChange={(index) =>
          setProviderName(providerIndexToNameMap.get(index) || defaultProvider)
        }
      >
        <TabList className="mt-3 mb-4">
          <>
            {providerOptions.map((provider) => (
              <Tab key={provider.name}>
                {provider.display_name || provider.name}
              </Tab>
            ))}
            <Tab key="custom">Custom</Tab>
          </>
        </TabList>
        <TabPanels>
          {providerOptions.map((provider) => {
            return (
              <TabPanel key={provider.name}>
                <LLMProviderUpdateForm
                  llmProviderDescriptor={provider}
                  onClose={() => onSuccess()}
                  shouldMarkAsDefault
                  setPopup={setPopup}
                />
              </TabPanel>
            );
          })}
          <TabPanel key="custom">
            <CustomLLMProviderUpdateForm
              onClose={() => onSuccess()}
              shouldMarkAsDefault
              setPopup={setPopup}
            />
          </TabPanel>
        </TabPanels>
      </TabGroup>
    </div>
  );
};
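Note on the design: Tremor's TabGroup is index-based, so the form keeps two small maps (name to index and index to name) to translate between the controlled providerName string and the tab position, with the synthetic "custom" entry pinned to the last slot. A minor observation: `providerNameToIndexMap.get(providerName) || 0` conflates index 0 with a missing entry, which happens to be harmless here since 0 is also the fallback. A sketch of the same translation written with nullish coalescing instead:

  // Sketch only: equivalent lookup that distinguishes 0 from undefined.
  const index = providerNameToIndexMap.get(providerName) ?? 0;
  const name = providerIndexToNameMap.get(index) ?? defaultProvider;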
74 web/src/components/llm/ApiKeyModal.tsx Normal file
@ -0,0 +1,74 @@
"use client";

import { useState, useEffect } from "react";
import { ApiKeyForm } from "./ApiKeyForm";
import { Modal } from "../Modal";
import { Divider } from "@tremor/react";
import {
  FullLLMProvider,
  WellKnownLLMProviderDescriptor,
} from "@/app/admin/models/llm/interfaces";
import { checkLlmProvider } from "../initialSetup/welcome/lib";
import { User } from "@/lib/types";

export const ApiKeyModal = ({ user }: { user: User | null }) => {
  const [forceHidden, setForceHidden] = useState<boolean>(false);
  const [validProviderExists, setValidProviderExists] = useState<boolean>(true);
  const [providerOptions, setProviderOptions] = useState<
    WellKnownLLMProviderDescriptor[]
  >([]);

  useEffect(() => {
    async function fetchProviderInfo() {
      const { providers, options, defaultCheckSuccessful } =
        await checkLlmProvider(user);
      setValidProviderExists(providers.length > 0 && defaultCheckSuccessful);
      setProviderOptions(options);
    }

    fetchProviderInfo();
  }, []);

  // don't show if
  // (1) a valid provider has been setup or
  // (2) there are no provider options (e.g. user isn't an admin)
  // (3) user explicitly hides the modal
  if (validProviderExists || !providerOptions.length || forceHidden) {
    return null;
  }

  return (
    <Modal
      title="LLM Key Setup"
      className="max-w-4xl"
      onOutsideClick={() => setForceHidden(true)}
    >
      <div className="max-h-[75vh] overflow-y-auto flex flex-col px-4">
        <div>
          <div className="mb-5 text-sm">
            Please setup an LLM below in order to start using Danswer Search or
            Danswer Chat. Don't worry, you can always change this later in
            the Admin Panel.
            <br />
            <br />
            Or if you'd rather look around first,{" "}
            <strong
              onClick={() => setForceHidden(true)}
              className="text-link cursor-pointer"
            >
              skip this step
            </strong>
            .
          </div>

          <ApiKeyForm
            onSuccess={() => {
              setForceHidden(true);
            }}
            providerOptions={providerOptions}
          />
        </div>
      </div>
    </Modal>
  );
};
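Note: unlike the old OpenAI-only modal it replaces, this modal is self-gating; callers just render it with the current user and it decides whether to show itself. A hypothetical usage mirroring the chat and search pages earlier in this diff:

  // Sketch: the page passes the resolved user; the modal handles the rest.
  {!shouldShowWelcomeModal && <ApiKeyModal user={user} />}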
@ -1,82 +0,0 @@
import { Form, Formik } from "formik";
import { Popup } from "../admin/connectors/Popup";
import { useState } from "react";
import { TextFormField } from "../admin/connectors/Field";
import { GEN_AI_API_KEY_URL } from "./constants";
import { LoadingAnimation } from "../Loading";
import { Button } from "@tremor/react";

interface Props {
  handleResponse?: (response: Response) => void;
}

export const ApiKeyForm = ({ handleResponse }: Props) => {
  const [popup, setPopup] = useState<{
    message: string;
    type: "success" | "error";
  } | null>(null);

  return (
    <div>
      {popup && <Popup message={popup.message} type={popup.type} />}
      <Formik
        initialValues={{ apiKey: "" }}
        onSubmit={async ({ apiKey }, formikHelpers) => {
          const response = await fetch(GEN_AI_API_KEY_URL, {
            method: "PUT",
            headers: {
              "Content-Type": "application/json",
            },
            body: JSON.stringify({ api_key: apiKey }),
          });
          if (handleResponse) {
            handleResponse(response);
          }
          if (response.ok) {
            setPopup({
              message: "Updated API key!",
              type: "success",
            });
            formikHelpers.resetForm();
          } else {
            const body = await response.json();
            if (body.detail) {
              setPopup({ message: body.detail, type: "error" });
            } else {
              setPopup({
                message:
                  "Unable to set API key. Check if the provided key is valid.",
                type: "error",
              });
            }
            setTimeout(() => {
              setPopup(null);
            }, 10000);
          }
        }}
      >
        {({ isSubmitting }) =>
          isSubmitting ? (
            <div className="text-base">
              <LoadingAnimation text="Validating API key" />
            </div>
          ) : (
            <Form>
              <TextFormField name="apiKey" type="password" label="API Key:" />
              <div className="flex">
                <Button
                  size="xs"
                  type="submit"
                  disabled={isSubmitting}
                  className="w-48 mx-auto"
                >
                  Submit
                </Button>
              </div>
            </Form>
          )
        }
      </Formik>
    </div>
  );
};
@ -1,68 +0,0 @@
"use client";

import { useState, useEffect } from "react";
import { ApiKeyForm } from "./ApiKeyForm";
import { Modal } from "../Modal";
import { Divider, Text } from "@tremor/react";

export async function checkApiKey() {
  const response = await fetch("/api/manage/admin/genai-api-key/validate");
  if (!response.ok && (response.status === 404 || response.status === 400)) {
    const jsonResponse = await response.json();
    return jsonResponse.detail;
  }
  return null;
}

export const ApiKeyModal = () => {
  const [errorMsg, setErrorMsg] = useState<string | null>(null);

  useEffect(() => {
    checkApiKey().then((error) => {
      if (error) {
        setErrorMsg(error);
      }
    });
  }, []);

  if (!errorMsg) {
    return null;
  }

  return (
    <Modal
      title="LLM Key Setup"
      className="max-w-4xl"
      onOutsideClick={() => setErrorMsg(null)}
    >
      <div>
        <div>
          <div className="mb-2.5 text-base">
            Please provide a valid OpenAI API key below in order to start using
            Danswer Search or Danswer Chat.
            <br />
            <br />
            Or if you'd rather look around first,{" "}
            <strong
              onClick={() => setErrorMsg(null)}
              className="text-link cursor-pointer"
            >
              skip this step
            </strong>
            .
          </div>

          <Divider />

          <ApiKeyForm
            handleResponse={(response) => {
              if (response.ok) {
                setErrorMsg(null);
              }
            }}
          />
        </div>
      </div>
    </Modal>
  );
};
@ -1 +0,0 @@
export const GEN_AI_API_KEY_URL = "/api/manage/admin/genai-api-key";
@ -2,6 +2,7 @@ import { Persona } from "@/app/admin/assistants/interfaces";
import { CCPairBasicInfo, DocumentSet, User } from "../types";
import { getCurrentUserSS } from "../userSS";
import { fetchSS } from "../utilsSS";
import { FullLLMProvider } from "@/app/admin/models/llm/interfaces";

export async function fetchAssistantEditorInfoSS(
  personaId?: number | string
@ -10,8 +11,7 @@ export async function fetchAssistantEditorInfoSS(
  {
    ccPairs: CCPairBasicInfo[];
    documentSets: DocumentSet[];
    llmOverrideOptions: string[];
    defaultLLM: string;
    llmProviders: FullLLMProvider[];
    user: User | null;
    existingPersona: Persona | null;
  },
@ -22,8 +22,7 @@ export async function fetchAssistantEditorInfoSS(
  const tasks = [
    fetchSS("/manage/indexing-status"),
    fetchSS("/manage/document-set"),
    fetchSS("/persona/utils/list-available-models"),
    fetchSS("/persona/utils/default-model"),
    fetchSS("/llm/provider"),
    // duplicate fetch, but shouldn't be too big of a deal
    // this page is not a high traffic page
    getCurrentUserSS(),
@ -37,15 +36,13 @@ export async function fetchAssistantEditorInfoSS(
  const [
    ccPairsInfoResponse,
    documentSetsResponse,
    llmOverridesResponse,
    defaultLLMResponse,
    llmProvidersResponse,
    user,
    personaResponse,
  ] = (await Promise.all(tasks)) as [
    Response,
    Response,
    Response,
    Response,
    User | null,
    Response | null,
  ];
@ -66,21 +63,13 @@ export async function fetchAssistantEditorInfoSS(
  }
  const documentSets = (await documentSetsResponse.json()) as DocumentSet[];

  if (!llmOverridesResponse.ok) {
  if (!llmProvidersResponse.ok) {
    return [
      null,
      `Failed to fetch LLM override options - ${await llmOverridesResponse.text()}`,
      `Failed to fetch LLM providers - ${await llmProvidersResponse.text()}`,
    ];
  }
  const llmOverrideOptions = (await llmOverridesResponse.json()) as string[];

  if (!defaultLLMResponse.ok) {
    return [
      null,
      `Failed to fetch default LLM - ${await defaultLLMResponse.text()}`,
    ];
  }
  const defaultLLM = (await defaultLLMResponse.json()) as string;
  const llmProviders = (await llmProvidersResponse.json()) as FullLLMProvider[];

  if (personaId && personaResponse && !personaResponse.ok) {
    return [null, `Failed to fetch Persona - ${await personaResponse.text()}`];
@ -93,8 +82,7 @@ export async function fetchAssistantEditorInfoSS(
  {
    ccPairs,
    documentSets,
    llmOverrideOptions,
    defaultLLM,
    llmProviders,
    user,
    existingPersona,
  },
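Note: because `tasks` mixes fetch-style promises with getCurrentUserSS(), TypeScript infers a widened element type for Promise.all(tasks), so the function asserts the tuple shape with `as [...]`. A hedged sketch of the same pattern in isolation (helper name hypothetical):

  // Sketch: heterogeneous parallel fetches with a tuple assertion.
  const [aRes, bRes, user] = (await Promise.all([
    fetch("/api/a"),
    fetch("/api/b"),
    getCurrentUser(), // hypothetical helper returning Promise<User | null>
  ])) as [Response, Response, User | null];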