Add UI-based LLM selection

This commit is contained in:
Weves 2024-04-13 18:07:13 -07:00 committed by Chris Weaver
parent 4c740060aa
commit f5b3333df3
58 changed files with 2284 additions and 701 deletions

View File

@ -0,0 +1,49 @@
"""Add tables for UI-based LLM configuration
Revision ID: 401c1ac29467
Revises: 703313b75876
Create Date: 2024-04-13 18:07:29.153817
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "401c1ac29467"
down_revision = "703313b75876"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the `llm_provider` table and add the persona-level provider override column."""
    op.create_table(
        "llm_provider",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("api_key", sa.String(), nullable=True),
        sa.Column("api_base", sa.String(), nullable=True),
        sa.Column("api_version", sa.String(), nullable=True),
        # free-form provider-specific settings (e.g. AWS credentials for bedrock)
        sa.Column(
            "custom_config",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=True,
        ),
        sa.Column("default_model_name", sa.String(), nullable=False),
        sa.Column("fast_default_model_name", sa.String(), nullable=True),
        # unique so at most one row can be flagged as the default
        # (rows with NULL do not collide under the unique constraint)
        sa.Column("is_default_provider", sa.Boolean(), unique=True, nullable=True),
        sa.Column("model_names", postgresql.ARRAY(sa.String()), nullable=True),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("name"),
    )
    op.add_column(
        "persona",
        sa.Column("llm_model_provider_override", sa.String(), nullable=True),
    )
def downgrade() -> None:
    """Revert `upgrade`: drop the persona override column, then the table."""
    op.drop_column("persona", "llm_model_provider_override")
    op.drop_table("llm_provider")

View File

@ -88,6 +88,7 @@ def load_personas_from_yaml(
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
llm_model_provider_override=None,
llm_model_version_override=None,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompts=cast(list[PromptDBModel] | None, prompts),

View File

@ -34,11 +34,10 @@ from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_llm_for_persona
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import SearchRequest
@ -135,8 +134,8 @@ def stream_chat_message_objects(
)
try:
llm = get_default_llm(
gen_ai_model_version_override=persona.llm_model_version_override
llm = get_llm_for_persona(
persona, new_msg_req.llm_override or chat_session.llm_override
)
except GenAIDisabledException:
llm = None
@ -373,9 +372,11 @@ def stream_chat_message_objects(
new_msg_req.prompt_override or chat_session.prompt_override
),
),
llm_config=LLMConfig.from_persona(
persona,
llm_override=(new_msg_req.llm_override or chat_session.llm_override),
llm=(
llm
or get_llm_for_persona(
persona, new_msg_req.llm_override or chat_session.llm_override
)
),
doc_relevance_list=llm_relevance_list,
message_history=[

View File

@ -24,6 +24,7 @@ MATCH_HIGHLIGHTS = "match_highlights"
# not be used for QA. For example, Google Drive file types which can't be parsed
# are still useful as a search result but not for QA.
IGNORE_FOR_QA = "ignore_for_qa"
# NOTE: deprecated, only used for porting key from old system
GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
PUBLIC_DOC_PAT = "PUBLIC"
PUBLIC_DOCUMENT_SET = "__PUBLIC"
@ -52,10 +53,6 @@ SECTION_SEPARATOR = "\n\n"
INDEX_SEPARATOR = "==="
# Key-Value store constants
GEN_AI_DETECTED_MODEL = "gen_ai_detected_model"
# Messages
DISABLED_GEN_AI_MSG = (
"Your System Admin has disabled the Generative AI functionalities of Danswer.\n"

View File

@ -36,13 +36,15 @@ from danswer.danswerbot.slack.utils import slack_usage_report
from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import Persona
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.persona import fetch_persona_by_id
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.factory import get_llm_for_persona
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
@ -237,32 +239,38 @@ def handle_message(
max_document_tokens: int | None = None
max_history_tokens: int | None = None
if len(new_message_request.messages) > 1:
llm_name = get_default_llm_version()[0]
if persona and persona.llm_model_version_override:
llm_name = persona.llm_model_version_override
# In cases of threads, split the available tokens between docs and thread context
input_tokens = get_max_input_tokens(model_name=llm_name)
max_history_tokens = int(input_tokens * thread_context_percent)
remaining_tokens = input_tokens - max_history_tokens
query_text = new_message_request.messages[0].message
if persona:
max_document_tokens = compute_max_document_tokens_for_persona(
persona=persona,
actual_user_input=query_text,
max_llm_token_override=remaining_tokens,
)
else:
max_document_tokens = (
remaining_tokens
- 512 # Needs to be more than any of the QA prompts
- check_number_of_tokens(query_text)
)
with Session(get_sqlalchemy_engine()) as db_session:
if len(new_message_request.messages) > 1:
persona = cast(
Persona,
fetch_persona_by_id(db_session, new_message_request.persona_id),
)
llm = get_llm_for_persona(persona)
# In cases of threads, split the available tokens between docs and thread context
input_tokens = get_max_input_tokens(
model_name=llm.config.model_name,
model_provider=llm.config.model_provider,
)
max_history_tokens = int(input_tokens * thread_context_percent)
remaining_tokens = input_tokens - max_history_tokens
query_text = new_message_request.messages[0].message
if persona:
max_document_tokens = compute_max_document_tokens_for_persona(
persona=persona,
actual_user_input=query_text,
max_llm_token_override=remaining_tokens,
)
else:
max_document_tokens = (
remaining_tokens
- 512 # Needs to be more than any of the QA prompts
- check_number_of_tokens(query_text)
)
# This also handles creating the query event in postgres
answer = get_search_answer(
query_req=new_message_request,

View File

@ -498,6 +498,7 @@ def upsert_persona(
recency_bias: RecencyBiasSetting,
prompts: list[Prompt] | None,
document_sets: list[DBDocumentSet] | None,
llm_model_provider_override: str | None,
llm_model_version_override: str | None,
starter_messages: list[StarterMessage] | None,
is_public: bool,
@ -524,6 +525,7 @@ def upsert_persona(
persona.llm_filter_extraction = llm_filter_extraction
persona.recency_bias = recency_bias
persona.default_persona = default_persona
persona.llm_model_provider_override = llm_model_provider_override
persona.llm_model_version_override = llm_model_version_override
persona.starter_messages = starter_messages
persona.deleted = False # Un-delete if previously deleted
@ -553,6 +555,7 @@ def upsert_persona(
default_persona=default_persona,
prompts=prompts or [],
document_sets=document_sets or [],
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
starter_messages=starter_messages,
)

View File

@ -71,7 +71,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
return _ASYNC_ENGINE
def get_session_context_manager() -> ContextManager:
def get_session_context_manager() -> ContextManager[Session]:
return contextlib.contextmanager(get_session)()

96
backend/danswer/db/llm.py Normal file
View File

@ -0,0 +1,96 @@
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
def upsert_llm_provider(
    db_session: Session, llm_provider: LLMProviderUpsertRequest
) -> FullLLMProvider:
    """Create or update an LLM provider, keyed on its unique `name`.

    On the update path, `name` (the lookup key) and `is_default_provider`
    are left untouched; every other configurable field is overwritten from
    the request. Commits the session on both paths and returns the
    API-facing representation of the row.
    """
    existing_llm_provider = db_session.scalar(
        select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
    )
    if existing_llm_provider:
        existing_llm_provider.api_key = llm_provider.api_key
        existing_llm_provider.api_base = llm_provider.api_base
        existing_llm_provider.api_version = llm_provider.api_version
        existing_llm_provider.custom_config = llm_provider.custom_config
        existing_llm_provider.default_model_name = llm_provider.default_model_name
        existing_llm_provider.fast_default_model_name = (
            llm_provider.fast_default_model_name
        )
        existing_llm_provider.model_names = llm_provider.model_names
        db_session.commit()
        return FullLLMProvider.from_model(existing_llm_provider)

    # if it does not exist, create a new entry
    llm_provider_model = LLMProviderModel(
        name=llm_provider.name,
        api_key=llm_provider.api_key,
        api_base=llm_provider.api_base,
        api_version=llm_provider.api_version,
        custom_config=llm_provider.custom_config,
        default_model_name=llm_provider.default_model_name,
        fast_default_model_name=llm_provider.fast_default_model_name,
        model_names=llm_provider.model_names,
        # a freshly created provider is never the default;
        # `update_default_provider` is the only place that flips this flag
        is_default_provider=None,
    )
    db_session.add(llm_provider_model)
    db_session.commit()
    return FullLLMProvider.from_model(llm_provider_model)
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
    """Return every configured LLM provider row."""
    return [provider for provider in db_session.scalars(select(LLMProviderModel))]
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
    """Look up the provider flagged as the system-wide default, if any."""
    default_row = db_session.scalar(
        select(LLMProviderModel).where(
            LLMProviderModel.is_default_provider == True  # noqa: E712
        )
    )
    return FullLLMProvider.from_model(default_row) if default_row else None
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
    """Look up a provider by its unique name; None when no such row exists."""
    row = db_session.scalar(
        select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
    )
    return FullLLMProvider.from_model(row) if row else None
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
    """Delete the provider row with the given id (silently a no-op if absent)."""
    delete_stmt = delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
    db_session.execute(delete_stmt)
    db_session.commit()
def update_default_provider(db_session: Session, provider_id: int) -> None:
    """Mark the provider with `provider_id` as the single default provider.

    Clears the flag on the current default first; `is_default_provider` has a
    unique constraint (True for the default, NULL for all others), so ordering
    and the intermediate flush matter here.

    Raises:
        ValueError: if no provider with the given id exists.
    """
    new_default = db_session.scalar(
        select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
    )
    if not new_default:
        raise ValueError(f"LLM Provider with id {provider_id} does not exist")

    existing_default = db_session.scalar(
        select(LLMProviderModel).where(
            LLMProviderModel.is_default_provider == True  # noqa: E712
        )
    )
    if existing_default:
        existing_default.is_default_provider = None
        # required to ensure that the below does not cause a unique constraint violation
        db_session.flush()

    new_default.is_default_provider = True
    db_session.commit()

View File

@ -8,8 +8,8 @@ from typing import Optional
from typing import TypedDict
from uuid import UUID
from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
from fastapi_users.db import SQLAlchemyBaseUserTableUUID
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
from sqlalchemy import Boolean
from sqlalchemy import DateTime
@ -695,6 +695,33 @@ Structures, Organizational, Configurations Tables
"""
class LLMProvider(Base):
    """A user-configured LLM provider (e.g. "openai") plus its credentials
    and default model choices. Managed via the admin UI rather than env vars."""

    __tablename__ = "llm_provider"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # display/lookup name; unique across all providers
    name: Mapped[str] = mapped_column(String, unique=True)
    api_key: Mapped[str | None] = mapped_column(String, nullable=True)
    api_base: Mapped[str | None] = mapped_column(String, nullable=True)
    api_version: Mapped[str | None] = mapped_column(String, nullable=True)
    # custom configs that should be passed to the LLM provider at inference time
    # (e.g. `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, etc. for bedrock)
    custom_config: Mapped[dict[str, str] | None] = mapped_column(
        postgresql.JSONB(), nullable=True
    )
    default_model_name: Mapped[str] = mapped_column(String)
    fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
    # The LLMs that are available for this provider. Only required if not a default provider.
    # If a default provider, then the LLM options are pulled from the `options.py` file.
    # If needed, can be pulled out as a separate table in the future.
    model_names: Mapped[list[str] | None] = mapped_column(
        postgresql.ARRAY(String), nullable=True
    )
    # should only be set for a single provider — the unique constraint enforces
    # this (True for the default, NULL for everything else)
    is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
class DocumentSet(Base):
__tablename__ = "document_set"
@ -792,6 +819,9 @@ class Persona(Base):
# globally via env variables. For flexibility, validity is not currently enforced
# NOTE: only is applied on the actual response generation - is not used for things like
# auto-detected time filters, relevance filters, etc.
llm_model_provider_override: Mapped[str | None] = mapped_column(
String, nullable=True
)
llm_model_version_override: Mapped[str | None] = mapped_column(
String, nullable=True
)

View File

@ -1,11 +1,13 @@
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.chat import get_prompts_by_ids
from danswer.db.chat import upsert_persona
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.models import Persona
from danswer.db.models import Persona__User
from danswer.db.models import User
from danswer.server.features.persona.models import CreatePersonaRequest
@ -69,6 +71,7 @@ def create_update_persona(
recency_bias=create_persona_request.recency_bias,
prompts=prompts,
document_sets=document_sets,
llm_model_provider_override=create_persona_request.llm_model_provider_override,
llm_model_version_override=create_persona_request.llm_model_version_override,
starter_messages=create_persona_request.starter_messages,
is_public=create_persona_request.is_public,
@ -91,3 +94,7 @@ def create_update_persona(
logger.exception("Failed to create persona")
raise HTTPException(status_code=400, detail=str(e))
return PersonaSnapshot.from_model(persona)
def fetch_persona_by_id(db_session: Session, persona_id: int) -> Persona | None:
    """Load a single Persona by primary key, or None if it does not exist."""
    stmt = select(Persona).where(Persona.id == persona_id)
    return db_session.scalars(stmt).first()

View File

@ -59,6 +59,7 @@ def create_slack_bot_persona(
recency_bias=RecencyBiasSetting.AUTO,
prompts=None,
document_sets=document_sets,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
is_public=True,

View File

@ -1,9 +1,27 @@
import json
from pathlib import Path
from typing import cast
from danswer.configs.app_configs import DYNAMIC_CONFIG_DIR_PATH
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_API_VERSION
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.factory import PostgresBackedDynamicConfigStore
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.utils.logger import setup_logger
logger = setup_logger()
def read_file_system_store(directory_path: str) -> dict:
@ -38,3 +56,60 @@ def insert_into_postgres(store_data: dict) -> None:
def port_filesystem_to_postgres(directory_path: str = DYNAMIC_CONFIG_DIR_PATH) -> None:
store_data = read_file_system_store(directory_path)
insert_into_postgres(store_data)
def port_api_key_to_postgres() -> None:
    """One-time migration of env-var / dynamic-config LLM settings into the
    `llm_provider` table.

    No-ops when: the old provider was "custom" (no longer supported), any
    provider already exists in the DB, no API key can be found, or no default
    model name can be determined. On success the new provider is marked as the
    system default and the old dynamic-config API key is deleted.
    """
    # can't port over custom, no longer supported
    if GEN_AI_MODEL_PROVIDER == "custom":
        return

    with get_session_context_manager() as db_session:
        # if we already have ported things over / setup providers in the db, don't do anything
        if len(fetch_existing_llm_providers(db_session)) > 0:
            return

        # prefer the key stored in the dynamic config store over the env var
        api_key = GEN_AI_API_KEY
        try:
            api_key = cast(
                str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY)
            )
        except ConfigNotFoundError:
            pass

        # if no API key set, don't port anything over
        if not api_key:
            return

        default_model_name = GEN_AI_MODEL_VERSION
        if GEN_AI_MODEL_PROVIDER == "openai" and not default_model_name:
            default_model_name = "gpt-4"

        # if no default model name found, don't port anything over
        if not default_model_name:
            return

        default_fast_model_name = FAST_GEN_AI_MODEL_VERSION
        if GEN_AI_MODEL_PROVIDER == "openai" and not default_fast_model_name:
            default_fast_model_name = "gpt-3.5-turbo"

        llm_provider_upsert = LLMProviderUpsertRequest(
            name=GEN_AI_MODEL_PROVIDER,
            api_key=api_key,
            api_base=GEN_AI_API_ENDPOINT,
            api_version=GEN_AI_API_VERSION,
            # can't port over any custom configs, since we don't know
            # all the possible keys and values that could be in there
            custom_config=None,
            default_model_name=default_model_name,
            fast_default_model_name=default_fast_model_name,
            model_names=None,
        )
        llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
        update_default_provider(db_session, llm_provider.id)
        # NOTE: log only the provider name — the full provider object contains
        # the API key and must never be written to logs
        logger.info(f"Ported over LLM provider: {llm_provider.name}")

        # delete the old API key
        try:
            get_dynamic_config_store().delete(GEN_AI_API_KEY_STORAGE_KEY)
        except ConfigNotFoundError:
            pass

View File

@ -12,7 +12,6 @@ from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.llm.answering.doc_pruning import prune_documents
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import StreamProcessor
@ -26,7 +25,7 @@ from danswer.llm.answering.stream_processing.citation_processing import (
from danswer.llm.answering.stream_processing.quotes_processing import (
build_quotes_processor,
)
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
@ -53,7 +52,7 @@ class Answer:
question: str,
docs: list[LlmDoc],
answer_style_config: AnswerStyleConfig,
llm_config: LLMConfig,
llm: LLM,
prompt_config: PromptConfig,
# must be the same length as `docs`. If None, all docs are considered "relevant"
doc_relevance_list: list[bool] | None = None,
@ -74,15 +73,9 @@ class Answer:
self.single_message_history = single_message_history
self.answer_style_config = answer_style_config
self.llm_config = llm_config
self.prompt_config = prompt_config
self.llm = get_default_llm(
gen_ai_model_provider=self.llm_config.model_provider,
gen_ai_model_version_override=self.llm_config.model_version,
timeout=timeout,
temperature=self.llm_config.temperature,
)
self.llm = llm
self.llm_tokenizer = get_default_llm_tokenizer()
self._final_prompt: list[BaseMessage] | None = None
@ -101,7 +94,7 @@ class Answer:
docs=self.docs,
doc_relevance_list=self.doc_relevance_list,
prompt_config=self.prompt_config,
llm_config=self.llm_config,
llm_config=self.llm.config,
question=self.question,
document_pruning_config=self.answer_style_config.document_pruning_config,
)
@ -116,7 +109,7 @@ class Answer:
self._final_prompt = build_citations_prompt(
question=self.question,
message_history=self.message_history,
llm_config=self.llm_config,
llm_config=self.llm.config,
prompt_config=self.prompt_config,
context_docs=self.pruned_docs,
all_doc_useful=self.answer_style_config.citation_config.all_docs_useful,

View File

@ -7,9 +7,9 @@ from danswer.chat.models import (
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str

View File

@ -9,15 +9,11 @@ from pydantic import root_validator
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.llm.utils import get_default_llm_version
if TYPE_CHECKING:
from danswer.db.models import ChatMessage
from danswer.db.models import Prompt
from danswer.db.models import Persona
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
@ -86,36 +82,6 @@ class AnswerStyleConfig(BaseModel):
return values
class LLMConfig(BaseModel):
"""Final representation of the LLM configuration passed into
the `Answer` object."""
model_provider: str
model_version: str
temperature: float
@classmethod
def from_persona(
cls, persona: "Persona", llm_override: LLMOverride | None = None
) -> "LLMConfig":
model_provider_override = llm_override.model_provider if llm_override else None
model_version_override = llm_override.model_version if llm_override else None
temperature_override = llm_override.temperature if llm_override else None
return cls(
model_provider=model_provider_override or GEN_AI_MODEL_PROVIDER,
model_version=(
model_version_override
or persona.llm_model_version_override
or get_default_llm_version()[0]
),
temperature=temperature_override or 0.0,
)
class Config:
frozen = True
class PromptConfig(BaseModel):
"""Final representation of the Prompt configuration passed
into the `Answer` object."""

View File

@ -11,9 +11,10 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_default_prompt
from danswer.db.models import Persona
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.factory import get_llm_for_persona
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import get_max_input_tokens
@ -131,7 +132,9 @@ def compute_max_document_tokens(
max_input_tokens = (
max_llm_token_override
if max_llm_token_override
else get_max_input_tokens(model_name=llm_config.model_version)
else get_max_input_tokens(
model_name=llm_config.model_name, model_provider=llm_config.model_provider
)
)
prompt_tokens = get_prompt_tokens(prompt_config)
@ -152,7 +155,7 @@ def compute_max_document_tokens_for_persona(
prompt = persona.prompts[0] if persona.prompts else get_default_prompt()
return compute_max_document_tokens(
prompt_config=PromptConfig.from_model(prompt),
llm_config=LLMConfig.from_persona(persona),
llm_config=get_llm_for_persona(persona).config,
actual_user_input=actual_user_input,
max_llm_token_override=max_llm_token_override,
)
@ -162,7 +165,7 @@ def compute_max_llm_input_tokens(llm_config: LLMConfig) -> int:
"""Maximum tokens allows in the input to the LLM (of any type)."""
input_tokens = get_max_input_tokens(
model_name=llm_config.model_version, model_provider=llm_config.model_provider
model_name=llm_config.model_name, model_provider=llm_config.model_provider
)
return input_tokens - _MISC_BUFFER

View File

@ -1,4 +1,5 @@
import abc
import os
from collections.abc import Iterator
import litellm # type:ignore
@ -12,11 +13,9 @@ from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_API_VERSION
from danswer.configs.model_configs import GEN_AI_LLM_PROVIDER_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_version
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import message_generator_to_string_generator
from danswer.llm.utils import should_be_verbose
from danswer.utils.logger import setup_logger
@ -69,6 +68,7 @@ class LangChainChatLLM(LLM, abc.ABC):
def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
if LOG_ALL_MODEL_INTERACTIONS:
self.log_model_configs()
self._log_prompt(prompt)
if DISABLE_LITELLM_STREAMING:
@ -85,23 +85,6 @@ class LangChainChatLLM(LLM, abc.ABC):
logger.debug(f"Raw Model Output:\n{full_output}")
def _get_model_str(
model_provider: str | None,
model_version: str | None,
) -> str:
if model_provider and model_version:
return model_provider + "/" + model_version
if model_version:
# Litellm defaults to openai if no provider specified
# It's implicit so no need to specify here either
return model_version
# User specified something wrong, just use Danswer default
base, _ = get_default_llm_version()
return base
class DefaultMultiLLM(LangChainChatLLM):
"""Uses Litellm library to allow easy configuration to use a multitude of LLMs
See https://python.langchain.com/docs/integrations/chat/litellm"""
@ -115,25 +98,35 @@ class DefaultMultiLLM(LangChainChatLLM):
self,
api_key: str | None,
timeout: int,
model_provider: str = GEN_AI_MODEL_PROVIDER,
model_version: str | None = GEN_AI_MODEL_VERSION,
model_provider: str,
model_name: str,
api_base: str | None = GEN_AI_API_ENDPOINT,
api_version: str | None = GEN_AI_API_VERSION,
custom_llm_provider: str | None = GEN_AI_LLM_PROVIDER_TYPE,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
temperature: float = GEN_AI_TEMPERATURE,
custom_config: dict[str, str] | None = None,
):
self._model_provider = model_provider
self._model_version = model_name
self._temperature = temperature
# Litellm Langchain integration currently doesn't take in the api key param
# Can place this in the call below once integration is in
litellm.api_key = api_key or "dummy-key"
litellm.api_version = api_version
model_version = model_version or get_default_llm_version()[0]
# NOTE: have to set these as environment variables for Litellm since
# not all are able to passed in but they always support them set as env
# variables
if custom_config:
for k, v in custom_config.items():
os.environ[k] = v
self._llm = ChatLiteLLM( # type: ignore
model=model_version
if custom_llm_provider
else _get_model_str(model_provider, model_version),
model=(
model_name if custom_llm_provider else f"{model_provider}/{model_name}"
),
api_base=api_base,
custom_llm_provider=custom_llm_provider,
max_tokens=max_output_tokens,
@ -148,6 +141,14 @@ class DefaultMultiLLM(LangChainChatLLM):
max_retries=0, # retries are handled outside of langchain
)
@property
def config(self) -> LLMConfig:
return LLMConfig(
model_provider=self._model_provider,
model_name=self._model_version,
temperature=self._temperature,
)
@property
def llm(self) -> ChatLiteLLM:
return self._llm

View File

@ -1,48 +1,90 @@
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_default_provider
from danswer.db.llm import fetch_provider
from danswer.db.models import Persona
from danswer.llm.chat_llm import DefaultMultiLLM
from danswer.llm.custom_llm import CustomModelServer
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.gpt_4_all import DanswerGPT4All
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_gen_ai_api_key
from danswer.llm.override_models import LLMOverride
def get_llm_for_persona(
    persona: Persona, llm_override: LLMOverride | None = None
) -> LLM:
    """Build the LLM for a persona: per-request overrides win, then the
    persona's own overrides, then the system defaults."""
    provider_override: str | None = None
    version_override: str | None = None
    temperature_override: float | None = None
    if llm_override:
        provider_override = llm_override.model_provider
        version_override = llm_override.model_version
        temperature_override = llm_override.temperature

    return get_default_llm(
        gen_ai_model_provider=(
            provider_override or persona.llm_model_provider_override
        ),
        gen_ai_model_version_override=(
            version_override or persona.llm_model_version_override
        ),
        temperature=temperature_override or GEN_AI_TEMPERATURE,
    )
def get_default_llm(
    api_key: str | None = None,
    timeout: int = QA_TIMEOUT,
    temperature: float = GEN_AI_TEMPERATURE,
    use_fast_llm: bool = False,
    gen_ai_model_provider: str | None = None,
    gen_ai_model_version_override: str | None = None,
) -> LLM:
    """A single place to fetch the configured LLM for Danswer.

    Also allows overriding certain LLM defaults: a specific provider (by name),
    a specific model version, and whether to use the provider's "fast" model.

    Raises:
        GenAIDisabledException: if generative AI is disabled system-wide.
        ValueError: if the requested (or default) provider / model name
            cannot be resolved.
    """
    if DISABLE_GENERATIVE_AI:
        raise GenAIDisabledException()

    # TODO: pass this in
    # NOTE: fetch_* return detached pydantic models, so the provider remains
    # usable after the session is closed
    with get_session_context_manager() as session:
        if gen_ai_model_provider is None:
            llm_provider = fetch_default_provider(session)
        else:
            llm_provider = fetch_provider(session, gen_ai_model_provider)

    if not llm_provider:
        # distinguish "nothing configured as default" from "the provider you
        # asked for doesn't exist" — the old single message was misleading
        if gen_ai_model_provider is None:
            raise ValueError("No default LLM provider found")
        raise ValueError(f"LLM provider '{gen_ai_model_provider}' not found")

    # explicit override beats the provider's defaults; fall back from the fast
    # model to the regular default when no fast model is configured
    model_name = gen_ai_model_version_override or (
        (llm_provider.fast_default_model_name or llm_provider.default_model_name)
        if use_fast_llm
        else llm_provider.default_model_name
    )
    if not model_name:
        raise ValueError("No default model name found")

    return get_llm(
        provider=llm_provider.name,
        model=model_name,
        api_key=llm_provider.api_key,
        api_base=llm_provider.api_base,
        api_version=llm_provider.api_version,
        custom_config=llm_provider.custom_config,
        timeout=timeout,
        temperature=temperature,
    )
def get_llm(
    provider: str,
    model: str,
    api_key: str | None = None,
    api_base: str | None = None,
    api_version: str | None = None,
    custom_config: dict[str, str] | None = None,
    temperature: float = GEN_AI_TEMPERATURE,
    timeout: int = QA_TIMEOUT,
) -> LLM:
    """Construct an LLM directly from explicit provider/model settings.

    Thin pass-through to `DefaultMultiLLM`; use `get_default_llm` /
    `get_llm_for_persona` when the settings should come from the DB instead.
    """
    return DefaultMultiLLM(
        model_provider=provider,
        model_name=model,
        api_key=api_key,
        api_base=api_base,
        api_version=api_version,
        timeout=timeout,
        temperature=temperature,
        custom_config=custom_config,
    )

View File

@ -4,11 +4,9 @@ from typing import Any
from langchain.schema.language_model import LanguageModelInput
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LLM
from danswer.llm.utils import convert_lm_input_to_basic_string
from danswer.llm.utils import get_default_llm_version
from danswer.utils.logger import setup_logger
@ -38,7 +36,10 @@ except ImportError:
class DanswerGPT4All(LLM):
"""Option to run an LLM locally, however this is significantly slower and
answers tend to be much worse"""
answers tend to be much worse
NOTE: currently unused, but kept for future reference / if we want to add this back.
"""
@property
def requires_warm_up(self) -> bool:
@ -53,14 +54,14 @@ class DanswerGPT4All(LLM):
def __init__(
self,
timeout: int,
model_version: str | None = GEN_AI_MODEL_VERSION,
model_version: str,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
temperature: float = GEN_AI_TEMPERATURE,
):
self.timeout = timeout
self.max_output_tokens = max_output_tokens
self.temperature = temperature
self.gpt4all_model = GPT4All(model_version or get_default_llm_version()[0])
self.gpt4all_model = GPT4All(model_version)
def log_model_configs(self) -> None:
logger.debug(

View File

@ -2,6 +2,7 @@ import abc
from collections.abc import Iterator
from langchain.schema.language_model import LanguageModelInput
from pydantic import BaseModel
from danswer.utils.logger import setup_logger
@ -9,6 +10,12 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
class LLMConfig(BaseModel):
model_provider: str
model_name: str
temperature: float
class LLM(abc.ABC):
"""Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
to use these implementations to connect to a variety of LLM providers."""
@ -22,6 +29,11 @@ class LLM(abc.ABC):
def requires_api_key(self) -> bool:
return True
@property
@abc.abstractmethod
def config(self) -> LLMConfig:
raise NotImplementedError
@abc.abstractmethod
def log_model_configs(self) -> None:
raise NotImplementedError

View File

@ -0,0 +1,109 @@
import litellm # type: ignore
from pydantic import BaseModel
class WellKnownLLMProviderDescriptor(BaseModel):
    """Static description of an LLM provider with built-in support.

    Consumed by the admin setup UI: tells it which credential fields are
    required for the provider and which model names can be selected.
    """

    # internal provider identifier, e.g. "openai" / "bedrock"
    name: str
    # human-readable name shown in the UI; falls back to `name` when None
    display_name: str | None = None
    api_key_required: bool
    api_base_required: bool
    api_version_required: bool
    # extra provider-specific config keys the user must supply
    # (e.g. AWS credentials for Bedrock)
    custom_config_keys: list[str] | None = None
    # model names selectable for this provider
    llm_names: list[str]
    default_model: str | None = None
    default_fast_model: str | None = None
OPENAI_PROVIDER_NAME = "openai"
# Model names offered in the admin UI for the OpenAI provider.
OPEN_AI_MODEL_NAMES = [
    "gpt-4",
    "gpt-4-turbo-preview",
    "gpt-4-1106-preview",
    "gpt-4-32k",
    "gpt-4-0613",
    "gpt-4-32k-0613",
    "gpt-4-0314",
    "gpt-4-32k-0314",
    "gpt-3.5-turbo",
    "gpt-3.5-turbo-0125",
    "gpt-3.5-turbo-1106",
    "gpt-3.5-turbo-16k",
    "gpt-3.5-turbo-0613",
    "gpt-3.5-turbo-16k-0613",
    "gpt-3.5-turbo-0301",
]

BEDROCK_PROVIDER_NAME = "bedrock"
# need to remove all the weird "bedrock/eu-central-1/anthropic.claude-v1" named
# models
BEDROCK_MODEL_NAMES = [model for model in litellm.bedrock_models if "/" not in model][
    ::-1  # NOTE(review): reversed — presumably to surface newer models first; confirm litellm's ordering
]

ANTHROPIC_PROVIDER_NAME = "anthropic"
# reversed for the same (presumed) reason as the Bedrock list above
ANTHROPIC_MODEL_NAMES = [model for model in litellm.anthropic_models][::-1]

AZURE_PROVIDER_NAME = "azure"

# provider name -> selectable model names. Providers absent from this map
# (e.g. azure) resolve to an empty list via fetch_models_for_provider.
_PROVIDER_TO_MODELS_MAP = {
    OPENAI_PROVIDER_NAME: OPEN_AI_MODEL_NAMES,
    BEDROCK_PROVIDER_NAME: BEDROCK_MODEL_NAMES,
    ANTHROPIC_PROVIDER_NAME: ANTHROPIC_MODEL_NAMES,
}
def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
    """Return descriptors for every LLM provider with built-in support.

    Each descriptor drives the admin setup UI: which credential fields are
    required for the provider and which model names can be selected.

    Consistency fix: the OpenAI entry now uses OPENAI_PROVIDER_NAME like the
    other entries instead of repeating the "openai" string literal.
    """
    return [
        WellKnownLLMProviderDescriptor(
            name=OPENAI_PROVIDER_NAME,
            display_name="OpenAI",
            api_key_required=True,
            api_base_required=False,
            api_version_required=False,
            custom_config_keys=[],
            llm_names=fetch_models_for_provider(OPENAI_PROVIDER_NAME),
            default_model="gpt-4",
            default_fast_model="gpt-3.5-turbo",
        ),
        WellKnownLLMProviderDescriptor(
            name=ANTHROPIC_PROVIDER_NAME,
            display_name="Anthropic",
            api_key_required=True,
            api_base_required=False,
            api_version_required=False,
            custom_config_keys=[],
            llm_names=fetch_models_for_provider(ANTHROPIC_PROVIDER_NAME),
            default_model="claude-3-opus-20240229",
            default_fast_model="claude-3-sonnet-20240229",
        ),
        WellKnownLLMProviderDescriptor(
            name=AZURE_PROVIDER_NAME,
            display_name="Azure OpenAI",
            api_key_required=True,
            api_base_required=True,
            api_version_required=True,
            custom_config_keys=[],
            # azure is not in _PROVIDER_TO_MODELS_MAP, so this resolves to []
            llm_names=fetch_models_for_provider(AZURE_PROVIDER_NAME),
        ),
        WellKnownLLMProviderDescriptor(
            name=BEDROCK_PROVIDER_NAME,
            display_name="AWS Bedrock",
            api_key_required=False,
            api_base_required=False,
            api_version_required=False,
            custom_config_keys=[
                "AWS_ACCESS_KEY_ID",
                "AWS_SECRET_ACCESS_KEY",
                "AWS_REGION_NAME",
            ],
            llm_names=fetch_models_for_provider(BEDROCK_PROVIDER_NAME),
            default_model="anthropic.claude-3-sonnet-20240229-v1:0",
            default_fast_model="anthropic.claude-3-haiku-20240307-v1:0",
        ),
    ]
def fetch_models_for_provider(provider_name: str) -> list[str]:
    """Return the known model names for a provider, or [] if none are mapped."""
    if provider_name in _PROVIDER_TO_MODELS_MAP:
        return _PROVIDER_TO_MODELS_MAP[provider_name]
    return []

View File

@ -1,9 +1,7 @@
from collections.abc import Callable
from collections.abc import Iterator
from copy import copy
from functools import lru_cache
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from typing import Union
@ -20,19 +18,12 @@ from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from tiktoken.core import Encoding
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.constants import GEN_AI_DETECTED_MODEL
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.models import ChatMessage
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.interfaces import LLM
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
@ -47,34 +38,6 @@ _LLM_TOKENIZER: Any = None
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None
@lru_cache()
def get_default_llm_version() -> tuple[str, str]:
default_openai_model = "gpt-3.5-turbo-16k-0613"
if GEN_AI_MODEL_VERSION:
llm_version = GEN_AI_MODEL_VERSION
else:
if GEN_AI_MODEL_PROVIDER != "openai":
logger.warning("No LLM Model Version set")
# Either this value is unused or it will throw an error
llm_version = default_openai_model
else:
kv_store = get_dynamic_config_store()
try:
llm_version = cast(str, kv_store.load(GEN_AI_DETECTED_MODEL))
except ConfigNotFoundError:
llm_version = default_openai_model
if FAST_GEN_AI_MODEL_VERSION:
fast_llm_version = FAST_GEN_AI_MODEL_VERSION
else:
if GEN_AI_MODEL_PROVIDER == "openai":
fast_llm_version = default_openai_model
else:
fast_llm_version = llm_version
return llm_version, fast_llm_version
def get_default_llm_tokenizer() -> Encoding:
"""Currently only supports the OpenAI default tokenizer: tiktoken"""
global _LLM_TOKENIZER
@ -219,17 +182,6 @@ def check_number_of_tokens(
return len(encode_fn(text))
def get_gen_ai_api_key() -> str | None:
# first check if the key has been provided by the UI
try:
return cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
except ConfigNotFoundError:
pass
# if not provided by the UI, fallback to the env variable
return GEN_AI_API_KEY
def test_llm(llm: LLM) -> str | None:
# try for up to 2 timeouts (e.g. 10 seconds in total)
error_msg = None
@ -246,7 +198,7 @@ def test_llm(llm: LLM) -> str | None:
def get_llm_max_tokens(
model_map: dict,
model_name: str | None = GEN_AI_MODEL_VERSION,
model_name: str,
model_provider: str = GEN_AI_MODEL_PROVIDER,
) -> int:
"""Best effort attempt to get the max tokens for the LLM"""
@ -254,13 +206,10 @@ def get_llm_max_tokens(
# This is an override, so always return this
return GEN_AI_MAX_TOKENS
model_name = model_name or get_default_llm_version()[0]
try:
if model_provider == "openai":
model_obj = model_map.get(f"{model_provider}/{model_name}")
if not model_obj:
model_obj = model_map[model_name]
else:
model_obj = model_map[f"{model_provider}/{model_name}"]
if "max_input_tokens" in model_obj:
return model_obj["max_input_tokens"]
@ -277,8 +226,8 @@ def get_llm_max_tokens(
def get_max_input_tokens(
model_name: str | None = GEN_AI_MODEL_VERSION,
model_provider: str = GEN_AI_MODEL_PROVIDER,
model_name: str,
model_provider: str,
output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
) -> int:
# NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually
@ -289,8 +238,6 @@ def get_max_input_tokens(
# model_map is litellm.model_cost
litellm_model_map = litellm.model_cost
model_name = model_name or get_default_llm_version()[0]
input_toks = (
get_llm_max_tokens(
model_name=model_name,

View File

@ -33,8 +33,6 @@ from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.constants import AuthType
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.db.chat import delete_old_default_personas
from danswer.db.connector import create_initial_default_connector
from danswer.db.connector_credential_pair import associate_default_cc_pair
@ -48,9 +46,8 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.dynamic_configs.port_configs import port_api_key_to_postgres
from danswer.dynamic_configs.port_configs import port_filesystem_to_postgres
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import get_default_llm_version
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.auth_check import check_router_auth
@ -67,6 +64,8 @@ from danswer.server.features.prompt.api import basic_router as prompt_router
from danswer.server.gpts.api import router as gpts_router
from danswer.server.manage.administrative import router as admin_router
from danswer.server.manage.get_state import router as state_router
from danswer.server.manage.llm.api import admin_router as llm_admin_router
from danswer.server.manage.llm.api import basic_router as llm_router
from danswer.server.manage.secondary_index import router as secondary_index_router
from danswer.server.manage.slack_bot import router as slack_bot_management_router
from danswer.server.manage.users import router as user_router
@ -156,17 +155,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
if DISABLE_GENERATIVE_AI:
logger.info("Generative AI Q&A disabled")
else:
logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
base, fast = get_default_llm_version()
logger.info(f"Using LLM Model Version: {base}")
if base != fast:
logger.info(f"Using Fast LLM Model Version: {fast}")
if GEN_AI_API_ENDPOINT:
logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
# Any additional model configs logged here
get_default_llm().log_model_configs()
if MULTILINGUAL_QUERY_EXPANSION:
logger.info(
@ -180,6 +168,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
"Skipping port of persistent volumes. Maybe these have already been removed?"
)
try:
port_api_key_to_postgres()
except Exception as e:
logger.debug(f"Failed to port API keys. Exception: {e}. Continuing...")
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
db_embedding_model = get_current_db_embedding_model(db_session)
@ -281,6 +274,8 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, gpts_router)
include_router_with_global_prefix_prepended(application, settings_router)
include_router_with_global_prefix_prepended(application, settings_admin_router)
include_router_with_global_prefix_prepended(application, llm_admin_router)
include_router_with_global_prefix_prepended(application, llm_router)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step

View File

@ -28,9 +28,9 @@ from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import QuotesConfig
from danswer.llm.factory import get_llm_for_persona
from danswer.llm.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
@ -212,7 +212,7 @@ def stream_answer_objects(
docs=[llm_doc_from_inference_section(section) for section in top_sections],
answer_style_config=answer_config,
prompt_config=PromptConfig.from_model(prompt),
llm_config=LLMConfig.from_persona(chat_session.persona),
llm=get_llm_for_persona(persona=chat_session.persona),
doc_relevance_list=search_pipeline.section_relevance_list,
single_message_history=history_str,
timeout=timeout,

View File

@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.db.chat import get_persona_by_id
from danswer.db.chat import get_personas
from danswer.db.chat import mark_persona_as_deleted
@ -15,7 +14,6 @@ from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.persona import create_update_persona
from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.llm.utils import get_default_llm_version
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.features.persona.models import PromptTemplateResponse
@ -168,49 +166,3 @@ def build_final_template_prompt(
retrieval_disabled=retrieval_disabled,
)
)
"""Utility endpoints for selecting which model to use for a persona.
Putting here for now, since we have no other flows which use this."""
GPT_4_MODEL_VERSIONS = [
"gpt-4",
"gpt-4-turbo-preview",
"gpt-4-1106-preview",
"gpt-4-32k",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
]
GPT_3_5_TURBO_MODEL_VERSIONS = [
"gpt-3.5-turbo",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-0301",
]
@basic_router.get("/utils/list-available-models")
def list_available_model_versions(
_: User | None = Depends(current_user),
) -> list[str]:
# currently only support selecting different models for OpenAI
if GEN_AI_MODEL_PROVIDER != "openai":
return []
return GPT_4_MODEL_VERSIONS + GPT_3_5_TURBO_MODEL_VERSIONS
@basic_router.get("/utils/default-model")
def get_default_model(
_: User | None = Depends(current_user),
) -> str:
# currently only support selecting different models for OpenAI
if GEN_AI_MODEL_PROVIDER != "openai":
return ""
return get_default_llm_version()[0]

View File

@ -20,6 +20,7 @@ class CreatePersonaRequest(BaseModel):
recency_bias: RecencyBiasSetting
prompt_ids: list[int]
document_set_ids: list[int]
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
starter_messages: list[StarterMessage] | None = None
# For Private Personas, who should be able to access these
@ -38,6 +39,7 @@ class PersonaSnapshot(BaseModel):
num_chunks: float | None
llm_relevance_filter: bool
llm_filter_extraction: bool
llm_model_provider_override: str | None
llm_model_version_override: str | None
starter_messages: list[StarterMessage] | None
default_persona: bool
@ -66,6 +68,7 @@ class PersonaSnapshot(BaseModel):
num_chunks=persona.num_chunks,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
starter_messages=persona.starter_messages,
default_persona=persona.default_persona,

View File

@ -1,5 +1,4 @@
import json
from collections.abc import Callable
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@ -16,13 +15,9 @@ from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.app_configs import TOKEN_BUDGET_GLOBALLY_ENABLED
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import ENABLE_TOKEN_BUDGET
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.constants import GEN_AI_DETECTED_MODEL
from danswer.configs.constants import TOKEN_BUDGET
from danswer.configs.constants import TOKEN_BUDGET_SETTINGS
from danswer.configs.constants import TOKEN_BUDGET_TIME_PERIOD
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.engine import get_session
@ -35,18 +30,13 @@ from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_default_llm_version
from danswer.llm.utils import get_gen_ai_api_key
from danswer.llm.utils import test_llm
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.manage.models import BoostDoc
from danswer.server.manage.models import BoostUpdateRequest
from danswer.server.manage.models import HiddenUpdateRequest
from danswer.server.models import ApiKey
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
router = APIRouter(prefix="/manage")
logger = setup_logger()
@ -123,52 +113,6 @@ def document_hidden_update(
raise HTTPException(status_code=400, detail=str(e))
def _validate_llm_key(genai_api_key: str | None) -> None:
# Checking new API key, may change things for this if kv-store is updated
get_default_llm_version.cache_clear()
kv_store = get_dynamic_config_store()
try:
kv_store.delete(GEN_AI_DETECTED_MODEL)
except ConfigNotFoundError:
pass
gpt_4_version = "gpt-4" # 32k is not available to most people
gpt4_llm = None
try:
llm = get_default_llm(api_key=genai_api_key, timeout=10)
if GEN_AI_MODEL_PROVIDER == "openai" and not GEN_AI_MODEL_VERSION:
gpt4_llm = get_default_llm(
gen_ai_model_version_override=gpt_4_version,
api_key=genai_api_key,
timeout=10,
)
except GenAIDisabledException:
return
functions_with_args: list[tuple[Callable, tuple]] = [(test_llm, (llm,))]
if gpt4_llm:
functions_with_args.append((test_llm, (gpt4_llm,)))
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
error_msg = parallel_results[0]
if error_msg:
if genai_api_key is None:
raise HTTPException(status_code=404, detail="Key not found")
raise HTTPException(status_code=400, detail=error_msg)
# Mark check as successful
curr_time = datetime.now(tz=timezone.utc)
kv_store.store(GEN_AI_KEY_CHECK_TIME, curr_time.timestamp())
# None for no errors
if gpt4_llm and parallel_results[1] is None:
kv_store.store(GEN_AI_DETECTED_MODEL, gpt_4_version)
@router.get("/admin/genai-api-key/validate")
def validate_existing_genai_api_key(
_: User = Depends(current_admin_user),
@ -187,46 +131,18 @@ def validate_existing_genai_api_key(
# First time checking the key, nothing unusual
pass
genai_api_key = get_gen_ai_api_key()
_validate_llm_key(genai_api_key)
@router.get("/admin/genai-api-key", response_model=ApiKey)
def get_gen_ai_api_key_from_dynamic_config_store(
_: User = Depends(current_admin_user),
) -> ApiKey:
"""
NOTE: Only gets value from dynamic config store as to not expose env variables.
"""
try:
# only get last 4 characters of key to not expose full key
return ApiKey(
api_key=cast(
str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY)
)[-4:]
)
except ConfigNotFoundError:
raise HTTPException(status_code=404, detail="Key not found")
llm = get_default_llm(timeout=10)
except ValueError:
raise HTTPException(status_code=404, detail="LLM not setup")
error = test_llm(llm)
if error:
raise HTTPException(status_code=400, detail=error)
@router.put("/admin/genai-api-key")
def store_genai_api_key(
request: ApiKey,
_: User = Depends(current_admin_user),
) -> None:
if not request.api_key:
raise HTTPException(400, "No API key provided")
_validate_llm_key(request.api_key)
get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
@router.delete("/admin/genai-api-key")
def delete_genai_api_key(
_: User = Depends(current_admin_user),
) -> None:
get_dynamic_config_store().delete(GEN_AI_API_KEY_STORAGE_KEY)
# Mark check as successful
curr_time = datetime.now(tz=timezone.utc)
kv_store.store(GEN_AI_KEY_CHECK_TIME, curr_time.timestamp())
@router.post("/admin/deletion-attempt")

View File

@ -0,0 +1,157 @@
from collections.abc import Callable
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.llm import remove_llm_provider
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.models import User
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_llm
from danswer.llm.options import fetch_available_well_known_llms
from danswer.llm.options import WellKnownLLMProviderDescriptor
from danswer.llm.utils import test_llm
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderDescriptor
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.server.manage.llm.models import TestLLMRequest
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()

# admin-only endpoints for managing LLM provider configurations
admin_router = APIRouter(prefix="/admin/llm")
# endpoints available to any authenticated user (uses `current_user` below)
basic_router = APIRouter(prefix="/llm")
@admin_router.get("/built-in/options")
def fetch_llm_options(
    _: User | None = Depends(current_admin_user),
) -> list[WellKnownLLMProviderDescriptor]:
    """List the providers with built-in support and their config requirements."""
    well_known_llms = fetch_available_well_known_llms()
    return well_known_llms
@admin_router.post("/test")
def test_llm_configuration(
    test_llm_request: TestLLMRequest,
    _: User | None = Depends(current_admin_user),
) -> None:
    """Instantiate the submitted provider config and verify it can respond.

    The fast model is tested as well when one is given and differs from the
    primary model. Raises a 400 with the first error message on failure.
    """
    primary_model = test_llm_request.default_model_name
    fast_model = test_llm_request.default_fast_model_name

    models_to_test = [primary_model]
    if fast_model and fast_model != primary_model:
        models_to_test.append(fast_model)

    # build one LLM per model, sharing the provider-level settings
    llms = [
        get_llm(
            provider=test_llm_request.provider,
            model=model_name,
            api_key=test_llm_request.api_key,
            api_base=test_llm_request.api_base,
            api_version=test_llm_request.api_version,
            custom_config=test_llm_request.custom_config,
        )
        for model_name in models_to_test
    ]

    functions_with_args: list[tuple[Callable, tuple]] = [
        (test_llm, (llm,)) for llm in llms
    ]
    parallel_results = run_functions_tuples_in_parallel(
        functions_with_args, allow_failures=False
    )

    # surface the first non-empty error message, if any
    error = next((result for result in parallel_results if result), None)
    if error:
        raise HTTPException(status_code=400, detail=error)
@admin_router.post("/test/default")
def test_default_provider(
    _: User | None = Depends(current_admin_user),
) -> None:
    """Verify that the currently-configured default LLM provider works.

    Tests both the primary and fast default LLMs. Raises a 400 if no default
    provider is set up or if either LLM fails the test.
    """
    try:
        default_llm = get_default_llm()
        default_fast_llm = get_default_llm(use_fast_llm=True)
    except ValueError:
        logger.exception("Failed to fetch default LLM Provider")
        raise HTTPException(status_code=400, detail="No LLM Provider setup")

    functions_with_args: list[tuple[Callable, tuple]] = [
        (test_llm, (default_llm,)),
        (test_llm, (default_fast_llm,)),
    ]
    parallel_results = run_functions_tuples_in_parallel(
        functions_with_args, allow_failures=False
    )

    # surface the first non-empty error message, if any
    error = next((result for result in parallel_results if result), None)
    if error:
        raise HTTPException(status_code=400, detail=error)
@admin_router.get("/provider")
def list_llm_providers(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[FullLLMProvider]:
    """List every configured LLM provider, including credentials (admin only)."""
    provider_models = fetch_existing_llm_providers(db_session)
    return [FullLLMProvider.from_model(provider) for provider in provider_models]
@admin_router.put("/provider")
def put_llm_provider(
    llm_provider: LLMProviderUpsertRequest,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> FullLLMProvider:
    """Create or update an LLM provider config and return the stored record."""
    upserted_provider = upsert_llm_provider(db_session, llm_provider)
    return upserted_provider
@admin_router.delete("/provider/{provider_id}")
def delete_llm_provider(
    provider_id: int,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Remove the LLM provider with the given id."""
    remove_llm_provider(db_session, provider_id)
@admin_router.post("/provider/{provider_id}/default")
def set_provider_as_default(
    provider_id: int,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Mark the given provider as the default used when none is specified."""
    update_default_provider(db_session, provider_id)
"""Endpoints for all"""
@basic_router.get("/provider")
def list_llm_provider_basics(
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
return [
LLMProviderDescriptor.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session)
]

View File

@ -0,0 +1,89 @@
from typing import TYPE_CHECKING
from pydantic import BaseModel
from danswer.llm.options import fetch_models_for_provider
if TYPE_CHECKING:
from danswer.db.models import LLMProvider as LLMProviderModel
class TestLLMRequest(BaseModel):
    """Request body for the admin "test this LLM configuration" endpoint."""

    # provider level
    provider: str
    api_key: str | None = None
    api_base: str | None = None
    api_version: str | None = None
    custom_config: dict[str, str] | None = None

    # model level
    default_model_name: str
    default_fast_model_name: str | None = None
class LLMProviderDescriptor(BaseModel):
    """A descriptor for an LLM provider that can be safely viewed by
    non-admin users. Used when giving a list of available LLMs."""

    name: str
    model_names: list[str]
    default_model_name: str
    fast_default_model_name: str | None
    is_default_provider: bool | None

    @classmethod
    def from_model(
        cls, llm_provider_model: "LLMProviderModel"
    ) -> "LLMProviderDescriptor":
        """Build a descriptor from the DB row, resolving the model-name list."""
        # prefer explicitly stored model names, then the well-known list for
        # this provider, then fall back to just the default model
        resolved_model_names = (
            llm_provider_model.model_names
            or fetch_models_for_provider(llm_provider_model.name)
            or [llm_provider_model.default_model_name]
        )
        return cls(
            name=llm_provider_model.name,
            model_names=resolved_model_names,
            default_model_name=llm_provider_model.default_model_name,
            fast_default_model_name=llm_provider_model.fast_default_model_name,
            is_default_provider=llm_provider_model.is_default_provider,
        )
class LLMProvider(BaseModel):
    """Admin-facing view of an LLM provider config, including credentials."""

    name: str
    api_key: str | None
    api_base: str | None
    api_version: str | None
    # provider-specific settings, e.g. AWS credentials for Bedrock
    custom_config: dict[str, str] | None
    default_model_name: str
    fast_default_model_name: str | None
class LLMProviderUpsertRequest(LLMProvider):
    """Request body for creating or updating an LLM provider."""

    # should only be used for a "custom" provider
    # for default providers, the built-in model names are used
    model_names: list[str] | None
class FullLLMProvider(LLMProvider):
    """Complete admin-facing provider record, including its DB id and the
    resolved list of selectable model names."""

    id: int
    is_default_provider: bool | None
    model_names: list[str]

    @classmethod
    def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider":
        """Build the full admin-facing view from the DB row."""
        # prefer explicitly stored model names, then the well-known list for
        # this provider, then fall back to just the default model
        resolved_model_names = (
            llm_provider_model.model_names
            or fetch_models_for_provider(llm_provider_model.name)
            or [llm_provider_model.default_model_name]
        )
        return cls(
            id=llm_provider_model.id,
            name=llm_provider_model.name,
            api_key=llm_provider_model.api_key,
            api_base=llm_provider_model.api_base,
            api_version=llm_provider_model.api_version,
            custom_config=llm_provider_model.custom_config,
            default_model_name=llm_provider_model.default_model_name,
            fast_default_model_name=llm_provider_model.fast_default_model_name,
            is_default_provider=llm_provider_model.is_default_provider,
            model_names=resolved_model_names,
        )

View File

@ -3,6 +3,7 @@ alembic==1.10.4
asyncpg==0.27.0
atlassian-python-api==3.37.0
beautifulsoup4==4.12.2
boto3==1.34.84
celery==5.3.4
chardet==5.2.0
dask==2023.8.1

7
web/package-lock.json generated
View File

@ -15,12 +15,14 @@
"@radix-ui/react-popover": "^1.0.7",
"@tremor/react": "^3.9.2",
"@types/js-cookie": "^3.0.3",
"@types/lodash": "^4.17.0",
"@types/node": "18.15.11",
"@types/react": "18.0.32",
"@types/react-dom": "18.0.11",
"autoprefixer": "^10.4.14",
"formik": "^2.2.9",
"js-cookie": "^3.0.5",
"lodash": "^4.17.21",
"mdast-util-find-and-replace": "^3.0.1",
"next": "^14.1.0",
"postcss": "^8.4.31",
@ -1740,6 +1742,11 @@
"integrity": "sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==",
"dev": true
},
"node_modules/@types/lodash": {
"version": "4.17.0",
"resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.17.0.tgz",
"integrity": "sha512-t7dhREVv6dbNj0q17X12j7yDG4bD/DHYX7o5/DbDxobP0HnGPgpRz2Ej77aL7TZT3DSw13fqUTj8J4mMnqa7WA=="
},
"node_modules/@types/mdast": {
"version": "4.0.3",
"resolved": "https://registry.npmjs.org/@types/mdast/-/mdast-4.0.3.tgz",

View File

@ -16,12 +16,14 @@
"@radix-ui/react-popover": "^1.0.7",
"@tremor/react": "^3.9.2",
"@types/js-cookie": "^3.0.3",
"@types/lodash": "^4.17.0",
"@types/node": "18.15.11",
"@types/react": "18.0.32",
"@types/react-dom": "18.0.11",
"autoprefixer": "^10.4.14",
"formik": "^2.2.9",
"js-cookie": "^3.0.5",
"lodash": "^4.17.21",
"mdast-util-find-and-replace": "^3.0.1",
"next": "^14.1.0",
"postcss": "^8.4.31",

View File

@ -31,6 +31,15 @@ import { Bubble } from "@/components/Bubble";
import { GroupsIcon } from "@/components/icons/icons";
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
import { FullLLMProvider } from "../models/llm/interfaces";
import { Option } from "@/components/Dropdown";
const DEFAULT_LLM_PROVIDER_TO_DISPLAY_NAME: Record<string, string> = {
openai: "OpenAI",
azure: "Azure OpenAI",
anthropic: "Anthropic",
bedrock: "AWS Bedrock",
};
function Label({ children }: { children: string | JSX.Element }) {
return (
@ -46,20 +55,18 @@ export function AssistantEditor({
existingPersona,
ccPairs,
documentSets,
llmOverrideOptions,
defaultLLM,
user,
defaultPublic,
redirectType,
llmProviders,
}: {
existingPersona?: Persona | null;
ccPairs: CCPairBasicInfo[];
documentSets: DocumentSet[];
llmOverrideOptions: string[];
defaultLLM: string;
user: User | null;
defaultPublic: boolean;
redirectType: SuccessfulPersonaUpdateRedirectType;
llmProviders: FullLLMProvider[];
}) {
const router = useRouter();
const { popup, setPopup } = usePopup();
@ -98,6 +105,21 @@ export function AssistantEditor({
}
}, []);
const defaultLLM = llmProviders.find(
(llmProvider) => llmProvider.is_default_provider
)?.default_model_name;
const modelOptionsByProvider = new Map<string, Option<string>[]>();
llmProviders.forEach((llmProvider) => {
const providerOptions = llmProvider.model_names.map((modelName) => {
return {
name: modelName,
value: modelName,
};
});
modelOptionsByProvider.set(llmProvider.name, providerOptions);
});
return (
<div>
{popup}
@ -118,6 +140,8 @@ export function AssistantEditor({
include_citations:
existingPersona?.prompts[0]?.include_citations ?? true,
llm_relevance_filter: existingPersona?.llm_relevance_filter ?? false,
llm_model_provider_override:
existingPersona?.llm_model_provider_override ?? null,
llm_model_version_override:
existingPersona?.llm_model_version_override ?? null,
starter_messages: existingPersona?.starter_messages ?? [],
@ -139,6 +163,7 @@ export function AssistantEditor({
include_citations: Yup.boolean().required(),
llm_relevance_filter: Yup.boolean().required(),
llm_model_version_override: Yup.string().nullable(),
llm_model_provider_override: Yup.string().nullable(),
starter_messages: Yup.array().of(
Yup.object().shape({
name: Yup.string().required(),
@ -178,6 +203,18 @@ export function AssistantEditor({
return;
}
if (
values.llm_model_provider_override &&
!values.llm_model_version_override
) {
setPopup({
type: "error",
message:
"Must select a model if a non-default LLM provider is chosen.",
});
return;
}
formikHelpers.setSubmitting(true);
// if disable_retrieval is set, set num_chunks to 0
@ -428,7 +465,7 @@ export function AssistantEditor({
</>
)}
{llmOverrideOptions.length > 0 && defaultLLM && (
{llmProviders.length > 0 && (
<>
<HidableSection
sectionTitle="[Advanced] Model Selection"
@ -452,17 +489,50 @@ export function AssistantEditor({
.
</Text>
<div className="w-96">
<SelectorFormField
name="llm_model_version_override"
options={llmOverrideOptions.map((llmOption) => {
return {
name: llmOption,
value: llmOption,
};
})}
includeDefault={true}
/>
<div className="flex mt-6">
<div className="w-96">
<SubLabel>LLM Provider</SubLabel>
<SelectorFormField
name="llm_model_provider_override"
options={llmProviders.map((llmProvider) => ({
name:
DEFAULT_LLM_PROVIDER_TO_DISPLAY_NAME[
llmProvider.name
] || llmProvider.name,
value: llmProvider.name,
}))}
includeDefault={true}
onSelect={(selected) => {
if (
selected !== values.llm_model_provider_override
) {
setFieldValue(
"llm_model_version_override",
null
);
}
setFieldValue(
"llm_model_provider_override",
selected
);
}}
/>
</div>
{values.llm_model_provider_override && (
<div className="w-96 ml-4">
<SubLabel>Model</SubLabel>
<SelectorFormField
name="llm_model_version_override"
options={
modelOptionsByProvider.get(
values.llm_model_provider_override
) || []
}
maxHeight="max-h-72"
/>
</div>
)}
</div>
</>
</HidableSection>

View File

@ -30,6 +30,7 @@ export interface Persona {
num_chunks?: number;
llm_relevance_filter?: boolean;
llm_filter_extraction?: boolean;
llm_model_provider_override?: string;
llm_model_version_override?: string;
starter_messages: StarterMessage[] | null;
default_persona: boolean;

View File

@ -10,6 +10,7 @@ interface PersonaCreationRequest {
include_citations: boolean;
is_public: boolean;
llm_relevance_filter: boolean | null;
llm_model_provider_override: string | null;
llm_model_version_override: string | null;
starter_messages: StarterMessage[] | null;
users?: string[];
@ -28,6 +29,7 @@ interface PersonaUpdateRequest {
include_citations: boolean;
is_public: boolean;
llm_relevance_filter: boolean | null;
llm_model_provider_override: string | null;
llm_model_version_override: string | null;
starter_messages: StarterMessage[] | null;
users?: string[];
@ -117,6 +119,7 @@ function buildPersonaAPIBody(
recency_bias: "base_decay",
prompt_ids: [promptId],
document_set_ids,
llm_model_provider_override: creationRequest.llm_model_provider_override,
llm_model_version_override: creationRequest.llm_model_version_override,
starter_messages: creationRequest.starter_messages,
users,

View File

@ -1,13 +1,10 @@
"use client";
import { LoadingAnimation, ThreeDotsLoader } from "@/components/Loading";
import { ThreeDotsLoader } from "@/components/Loading";
import { AdminPageTitle } from "@/components/admin/Title";
import { KeyIcon, TrashIcon } from "@/components/icons/icons";
import { ApiKeyForm } from "@/components/openai/ApiKeyForm";
import { GEN_AI_API_KEY_URL } from "@/components/openai/constants";
import { errorHandlingFetcher, fetcher } from "@/lib/fetcher";
import { Button, Card, Divider, Text, Title } from "@tremor/react";
import { FiCpu, FiPackage } from "react-icons/fi";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { Button, Card, Text, Title } from "@tremor/react";
import { FiPackage } from "react-icons/fi";
import useSWR, { mutate } from "swr";
import { ModelOption, ModelSelector } from "./ModelSelector";
import { useState } from "react";

View File

@ -0,0 +1,450 @@
import { LoadingAnimation } from "@/components/Loading";
import { Button, Divider, Text } from "@tremor/react";
import {
ArrayHelpers,
ErrorMessage,
Field,
FieldArray,
Form,
Formik,
} from "formik";
import { FiPlus, FiTrash, FiX } from "react-icons/fi";
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import {
Label,
SubLabel,
TextArrayField,
TextFormField,
} from "@/components/admin/connectors/Field";
import { useState } from "react";
import { useSWRConfig } from "swr";
import { FullLLMProvider } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";
/**
 * Convert the form's list-of-[key, value] pairs into the flat
 * `custom_config` object expected by the backend API.
 *
 * Later entries win when the same key appears more than once (matching
 * the behavior of the previous accumulate-into-object loop).
 */
function customConfigProcessing(customConfigsList: [string, string][]): {
  [key: string]: string;
} {
  // Object.fromEntries replaces the manual forEach/assign loop.
  return Object.fromEntries(customConfigsList);
}
export function CustomLLMProviderUpdateForm({
onClose,
existingLlmProvider,
shouldMarkAsDefault,
setPopup,
}: {
onClose: () => void;
existingLlmProvider?: FullLLMProvider;
shouldMarkAsDefault?: boolean;
setPopup?: (popup: PopupSpec) => void;
}) {
const { mutate } = useSWRConfig();
const [isTesting, setIsTesting] = useState(false);
const [testError, setTestError] = useState<string>("");
const [isTestSuccessful, setTestSuccessful] = useState(
existingLlmProvider ? true : false
);
// Define the initial values based on the provider's requirements
const initialValues = {
name: existingLlmProvider?.name ?? "",
api_key: existingLlmProvider?.api_key ?? "",
api_base: existingLlmProvider?.api_base ?? "",
api_version: existingLlmProvider?.api_version ?? "",
default_model_name: existingLlmProvider?.default_model_name ?? null,
default_fast_model_name:
existingLlmProvider?.fast_default_model_name ?? null,
model_names: existingLlmProvider?.model_names ?? [],
custom_config_list: existingLlmProvider?.custom_config
? Object.entries(existingLlmProvider.custom_config)
: [],
};
const [validatedConfig, setValidatedConfig] = useState(
existingLlmProvider ? initialValues : null
);
// Setup validation schema if required
const validationSchema = Yup.object({
name: Yup.string().required("Name is required"),
api_key: Yup.string(),
api_base: Yup.string(),
api_version: Yup.string(),
model_names: Yup.array(Yup.string().required("Model name is required")),
default_model_name: Yup.string().required("Model name is required"),
default_fast_model_name: Yup.string().nullable(),
custom_config_list: Yup.array(),
});
return (
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
// hijack this to re-enable testing on any change
validate={(values) => {
if (!isEqual(values, validatedConfig)) {
setTestSuccessful(false);
}
}}
onSubmit={async (values, { setSubmitting }) => {
setSubmitting(true);
if (!isTestSuccessful) {
setSubmitting(false);
return;
}
if (values.model_names.length === 0) {
const fullErrorMsg = "At least one model name is required";
if (setPopup) {
setPopup({
type: "error",
message: fullErrorMsg,
});
} else {
alert(fullErrorMsg);
}
setSubmitting(false);
return;
}
const response = await fetch(LLM_PROVIDERS_ADMIN_URL, {
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
...values,
custom_config: customConfigProcessing(values.custom_config_list),
}),
});
if (!response.ok) {
const errorMsg = (await response.json()).detail;
const fullErrorMsg = existingLlmProvider
? `Failed to update provider: ${errorMsg}`
: `Failed to enable provider: ${errorMsg}`;
if (setPopup) {
setPopup({
type: "error",
message: fullErrorMsg,
});
} else {
alert(fullErrorMsg);
}
return;
}
if (shouldMarkAsDefault) {
const newLlmProvider = (await response.json()) as FullLLMProvider;
const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
{
method: "POST",
}
);
if (!setDefaultResponse.ok) {
const errorMsg = (await setDefaultResponse.json()).detail;
const fullErrorMsg = `Failed to set provider as default: ${errorMsg}`;
if (setPopup) {
setPopup({
type: "error",
message: fullErrorMsg,
});
} else {
alert(fullErrorMsg);
}
return;
}
}
mutate(LLM_PROVIDERS_ADMIN_URL);
onClose();
const successMsg = existingLlmProvider
? "Provider updated successfully!"
: "Provider enabled successfully!";
if (setPopup) {
setPopup({
type: "success",
message: successMsg,
});
} else {
alert(successMsg);
}
setSubmitting(false);
}}
>
{({ values }) => (
<Form>
<TextFormField
name="name"
label="Provider Name"
subtext={
<>
Should be one of the providers listed at{" "}
<a
target="_blank"
href="https://docs.litellm.ai/docs/providers"
className="text-link"
>
https://docs.litellm.ai/docs/providers
</a>
.
</>
}
placeholder="Name of the custom provider"
/>
<Divider />
<SubLabel>
Fill in the following as is needed. Refer to the LiteLLM
documentation for the model provider name specified above in order
to determine which fields are required.
</SubLabel>
<TextFormField
name="api_key"
label="[Optional] API Key"
placeholder="API Key"
type="password"
/>
<TextFormField
name="api_base"
label="[Optional] API Base"
placeholder="API Base"
/>
<TextFormField
name="api_version"
label="[Optional] API Version"
placeholder="API Version"
/>
<Label>[Optional] Custom Configs</Label>
<SubLabel>
<>
<div>
Additional configurations needed by the model provider. Are
passed to litellm via environment variables.
</div>
<div className="mt-2">
For example, when configuring the Cloudflare provider, you would
need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your
Cloudflare account ID as the value.
</div>
</>
</SubLabel>
<FieldArray
name="custom_config_list"
render={(arrayHelpers: ArrayHelpers<any[]>) => (
<div>
{values.custom_config_list.map((_, index) => {
return (
<div key={index} className={index === 0 ? "mt-2" : "mt-6"}>
<div className="flex">
<div className="w-full mr-6 border border-border p-3 rounded">
<div>
<Label>Key</Label>
<Field
name={`custom_config_list[${index}][0]`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
name={`custom_config_list[${index}][0]`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
<div className="mt-3">
<Label>Value</Label>
<Field
name={`custom_config_list[${index}][1]`}
className={`
border
border-border
bg-background
rounded
w-full
py-2
px-3
mr-4
`}
autoComplete="off"
/>
<ErrorMessage
name={`custom_config_list[${index}][1]`}
component="div"
className="text-error text-sm mt-1"
/>
</div>
</div>
<div className="my-auto">
<FiX
className="my-auto w-10 h-10 cursor-pointer hover:bg-hover rounded p-2"
onClick={() => arrayHelpers.remove(index)}
/>
</div>
</div>
</div>
);
})}
<Button
onClick={() => {
arrayHelpers.push(["", ""]);
}}
className="mt-3"
color="green"
size="xs"
type="button"
icon={FiPlus}
>
Add New
</Button>
</div>
)}
/>
<Divider />
<TextArrayField
name="model_names"
label="Model Names"
values={values}
subtext={`List the individual models that you want to make
available as a part of this provider. At least one must be specified.
As an example, for OpenAI one model might be "gpt-4".`}
/>
<Divider />
<TextFormField
name="default_model_name"
subtext={`
The model to use by default for this provider unless
otherwise specified. Must be one of the models listed
above.`}
label="Default Model"
placeholder="E.g. gpt-4"
/>
<TextFormField
name="default_fast_model_name"
subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
for this provider. If not set, will use
the Default Model configured above.`}
label="[Optional] Fast Model"
placeholder="E.g. gpt-4"
/>
<Divider />
<div>
{/* NOTE: this is above the test button to make sure it's visible */}
{!isTestSuccessful && testError && (
<Text className="text-error mt-2">{testError}</Text>
)}
{isTestSuccessful && (
<Text className="text-success mt-2">
Test successful! LLM provider is ready to go.
</Text>
)}
<div className="flex w-full mt-4">
{isTestSuccessful ? (
<Button type="submit" size="xs">
{existingLlmProvider ? "Update" : "Enable"}
</Button>
) : (
<Button
type="button"
size="xs"
disabled={isTesting}
onClick={async () => {
setIsTesting(true);
console.log(values.custom_config_list);
const response = await fetch("/api/admin/llm/test", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
provider: values.name,
custom_config: customConfigProcessing(
values.custom_config_list
),
...values,
}),
});
setIsTesting(false);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
setTestError(errorMsg);
return;
}
setTestSuccessful(true);
setValidatedConfig(values);
}}
>
{isTesting ? <LoadingAnimation text="Testing" /> : "Test"}
</Button>
)}
{existingLlmProvider && (
<Button
type="button"
color="red"
className="ml-3"
size="xs"
icon={FiTrash}
onClick={async () => {
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
{
method: "DELETE",
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
alert(`Failed to delete provider: ${errorMsg}`);
return;
}
mutate(LLM_PROVIDERS_ADMIN_URL);
onClose();
}}
>
Delete
</Button>
)}
</div>
</div>
</Form>
)}
</Formik>
);
}

View File

@ -0,0 +1,234 @@
"use client";
import { Modal } from "@/components/Modal";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { useState } from "react";
import useSWR, { mutate } from "swr";
import { Badge, Button, Text, Title } from "@tremor/react";
import { ThreeDotsLoader } from "@/components/Loading";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm";
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm";
function LLMProviderUpdateModal({
llmProviderDescriptor,
onClose,
existingLlmProvider,
shouldMarkAsDefault,
setPopup,
}: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null;
onClose: () => void;
existingLlmProvider?: FullLLMProvider;
shouldMarkAsDefault?: boolean;
setPopup?: (popup: PopupSpec) => void;
}) {
const providerName =
llmProviderDescriptor?.display_name ||
llmProviderDescriptor?.name ||
existingLlmProvider?.name ||
"Custom LLM Provider";
return (
<Modal title={`Setup ${providerName}`} onOutsideClick={() => onClose()}>
<div className="max-h-[70vh] overflow-y-auto px-4">
{llmProviderDescriptor ? (
<LLMProviderUpdateForm
llmProviderDescriptor={llmProviderDescriptor}
onClose={onClose}
existingLlmProvider={existingLlmProvider}
shouldMarkAsDefault={shouldMarkAsDefault}
setPopup={setPopup}
/>
) : (
<CustomLLMProviderUpdateForm
onClose={onClose}
existingLlmProvider={existingLlmProvider}
shouldMarkAsDefault={shouldMarkAsDefault}
setPopup={setPopup}
/>
)}
</div>
</Modal>
);
}
/**
 * Card for a single LLM provider (built-in or custom): shows its name and
 * status badge, a "Set as default" action, and an Edit / Set up button that
 * opens the update modal.
 */
function LLMProviderDisplay({
  llmProviderDescriptor,
  existingLlmProvider,
  shouldMarkAsDefault,
}: {
  llmProviderDescriptor: WellKnownLLMProviderDescriptor | null;
  existingLlmProvider?: FullLLMProvider;
  shouldMarkAsDefault?: boolean;
}) {
  const [formIsVisible, setFormIsVisible] = useState(false);
  const { popup, setPopup } = usePopup();

  const providerName =
    llmProviderDescriptor?.display_name ||
    llmProviderDescriptor?.name ||
    existingLlmProvider?.name;

  // Promote this provider to be the system-wide default and report the
  // outcome via a popup.
  const makeDefault = async () => {
    if (!existingLlmProvider) {
      return;
    }
    const response = await fetch(
      `${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}/default`,
      {
        method: "POST",
      }
    );
    if (!response.ok) {
      const errorMsg = (await response.json()).detail;
      setPopup({
        type: "error",
        message: `Failed to set provider as default: ${errorMsg}`,
      });
      return;
    }

    mutate(LLM_PROVIDERS_ADMIN_URL);
    setPopup({
      type: "success",
      message: "Provider set as default successfully!",
    });
  };

  const statusBadge = existingLlmProvider?.is_default_provider ? (
    <Badge color="orange" className="ml-2" size="xs">
      Default
    </Badge>
  ) : (
    <Badge color="green" className="ml-2" size="xs">
      Enabled
    </Badge>
  );

  return (
    <div>
      {popup}
      <div className="border border-border p-3 rounded w-96 flex shadow-md">
        <div className="my-auto">
          <div className="font-bold">{providerName} </div>
          {existingLlmProvider && !existingLlmProvider.is_default_provider && (
            <div
              className="text-xs text-link cursor-pointer"
              onClick={makeDefault}
            >
              Set as default
            </div>
          )}
        </div>

        {existingLlmProvider && <div className="my-auto">{statusBadge}</div>}

        <div className="ml-auto">
          <Button
            color={existingLlmProvider ? "green" : "blue"}
            size="xs"
            onClick={() => setFormIsVisible(true)}
          >
            {existingLlmProvider ? "Edit" : "Set up"}
          </Button>
        </div>
      </div>
      {formIsVisible && (
        <LLMProviderUpdateModal
          llmProviderDescriptor={llmProviderDescriptor}
          onClose={() => setFormIsVisible(false)}
          existingLlmProvider={existingLlmProvider}
          shouldMarkAsDefault={shouldMarkAsDefault}
          setPopup={setPopup}
        />
      )}
    </div>
  );
}
/**
 * Button that opens the setup modal for a fully custom (LiteLLM passthrough)
 * provider.
 *
 * Takes no props — the previous empty object-pattern parameter `({})` was a
 * lint error (no-empty-pattern) and has been removed; callers using
 * `<AddCustomLLMProvider />` are unaffected.
 */
function AddCustomLLMProvider() {
  const [formIsVisible, setFormIsVisible] = useState(false);

  if (formIsVisible) {
    return (
      <Modal
        title={`Setup Custom LLM Provider`}
        onOutsideClick={() => setFormIsVisible(false)}
      >
        <div className="max-h-[70vh] overflow-y-auto px-4">
          <CustomLLMProviderUpdateForm
            onClose={() => setFormIsVisible(false)}
          />
        </div>
      </Modal>
    );
  }

  return (
    <Button size="xs" onClick={() => setFormIsVisible(true)}>
      Add Custom LLM Provider
    </Button>
  );
}
/**
 * Admin panel listing all LLM providers: the built-in "well known" providers
 * (with set-up/edit cards), any custom providers, and a button to add a new
 * custom provider.
 */
export function LLMConfiguration() {
  const { data: llmProviderDescriptors } = useSWR<
    WellKnownLLMProviderDescriptor[]
  >("/api/admin/llm/built-in/options", errorHandlingFetcher);
  const { data: existingLlmProviders } = useSWR<FullLLMProvider[]>(
    LLM_PROVIDERS_ADMIN_URL,
    errorHandlingFetcher
  );

  // Both fetches must resolve before anything useful can be rendered.
  if (!llmProviderDescriptors || !existingLlmProviders) {
    return <ThreeDotsLoader />;
  }

  // A configured provider whose name matches no built-in descriptor is a
  // "custom" provider.
  const wellKnownNames = new Set(
    llmProviderDescriptors.map((descriptor) => descriptor.name)
  );
  const customLLMProviders = existingLlmProviders.filter(
    (provider) => !wellKnownNames.has(provider.name)
  );

  return (
    <>
      <Text className="mb-4">
        If multiple LLM providers are enabled, the default provider will be used
        for all &quot;Default&quot; Personas. For user-created Personas, you can
        select the LLM provider/model that best fits the use case!
      </Text>

      <Title className="mb-2">Default Providers</Title>
      <div className="gap-y-4 flex flex-col">
        {llmProviderDescriptors.map((descriptor) => (
          <LLMProviderDisplay
            key={descriptor.name}
            llmProviderDescriptor={descriptor}
            existingLlmProvider={existingLlmProviders.find(
              (provider) => provider.name === descriptor.name
            )}
            shouldMarkAsDefault={existingLlmProviders.length === 0}
          />
        ))}
      </div>

      <Title className="mb-2 mt-4">Custom Providers</Title>
      {customLLMProviders.length > 0 && (
        <div className="gap-y-4 flex flex-col mb-4">
          {customLLMProviders.map((provider) => (
            <LLMProviderDisplay
              key={provider.id}
              llmProviderDescriptor={null}
              existingLlmProvider={provider}
            />
          ))}
        </div>
      )}

      <AddCustomLLMProvider />
    </>
  );
}

View File

@ -0,0 +1,347 @@
import { LoadingAnimation } from "@/components/Loading";
import { Button, Divider, Text } from "@tremor/react";
import { Form, Formik } from "formik";
import { FiTrash } from "react-icons/fi";
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import {
SelectorFormField,
TextFormField,
} from "@/components/admin/connectors/Field";
import { useState } from "react";
import { useSWRConfig } from "swr";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";
/**
 * Guided setup form for a "well known" LLM provider (OpenAI, Azure, etc.).
 * The descriptor drives which credential fields are shown/required and which
 * model names are selectable.
 *
 * The user must pass a successful test request (the "Test" button) before the
 * submit button appears; any edit after a successful test invalidates it
 * again (see the `validate` hijack below).
 *
 * Props:
 * - llmProviderDescriptor: static description of the provider's requirements.
 * - onClose: invoked after a successful create/update or delete.
 * - existingLlmProvider: when set, the form edits this provider in place.
 * - shouldMarkAsDefault: when set, the provider is marked as the system
 *   default right after a successful upsert.
 * - setPopup: optional toast helper; falls back to `alert` when absent.
 *
 * Fix vs. previous revision: the error-return paths now call
 * `setSubmitting(false)` so the form is not left stuck in the "submitting"
 * state after a failed request.
 */
export function LLMProviderUpdateForm({
  llmProviderDescriptor,
  onClose,
  existingLlmProvider,
  shouldMarkAsDefault,
  setPopup,
}: {
  llmProviderDescriptor: WellKnownLLMProviderDescriptor;
  onClose: () => void;
  existingLlmProvider?: FullLLMProvider;
  shouldMarkAsDefault?: boolean;
  setPopup?: (popup: PopupSpec) => void;
}) {
  const { mutate } = useSWRConfig();

  const [isTesting, setIsTesting] = useState(false);
  const [testError, setTestError] = useState<string>("");
  // An existing provider necessarily passed a test when it was created,
  // so start in the "test successful" state when editing.
  const [isTestSuccessful, setTestSuccessful] = useState(
    existingLlmProvider ? true : false
  );

  // Define the initial values based on the provider's requirements
  const initialValues = {
    api_key: existingLlmProvider?.api_key ?? "",
    api_base: existingLlmProvider?.api_base ?? "",
    api_version: existingLlmProvider?.api_version ?? "",
    default_model_name:
      existingLlmProvider?.default_model_name ??
      (llmProviderDescriptor.default_model ||
        llmProviderDescriptor.llm_names[0]),
    default_fast_model_name:
      existingLlmProvider?.fast_default_model_name ??
      (llmProviderDescriptor.default_fast_model || null),
    // pre-populate one empty entry per descriptor-declared custom config key
    custom_config:
      existingLlmProvider?.custom_config ??
      llmProviderDescriptor.custom_config_keys?.reduce(
        (acc, key) => {
          acc[key] = "";
          return acc;
        },
        {} as { [key: string]: string }
      ),
  };

  // Snapshot of the last configuration that passed a test; compared against
  // current values in `validate` to detect edits that invalidate the test.
  const [validatedConfig, setValidatedConfig] = useState(
    existingLlmProvider ? initialValues : null
  );

  // Setup validation schema if required
  const validationSchema = Yup.object({
    api_key: llmProviderDescriptor.api_key_required
      ? Yup.string().required("API Key is required")
      : Yup.string(),
    api_base: llmProviderDescriptor.api_base_required
      ? Yup.string().required("API Base is required")
      : Yup.string(),
    api_version: llmProviderDescriptor.api_version_required
      ? Yup.string().required("API Version is required")
      : Yup.string(),
    ...(llmProviderDescriptor.custom_config_keys
      ? {
          custom_config: Yup.object(
            llmProviderDescriptor.custom_config_keys.reduce(
              (acc, key) => {
                acc[key] = Yup.string().required(`${key} is required`);
                return acc;
              },
              {} as { [key: string]: Yup.StringSchema }
            )
          ),
        }
      : {}),
    default_model_name: Yup.string().required("Model name is required"),
    default_fast_model_name: Yup.string().nullable(),
  });

  return (
    <Formik
      initialValues={initialValues}
      validationSchema={validationSchema}
      // hijack this to re-enable testing on any change
      validate={(values) => {
        if (!isEqual(values, validatedConfig)) {
          setTestSuccessful(false);
        }
      }}
      onSubmit={async (values, { setSubmitting }) => {
        setSubmitting(true);

        if (!isTestSuccessful) {
          setSubmitting(false);
          return;
        }

        const response = await fetch(LLM_PROVIDERS_ADMIN_URL, {
          method: "PUT",
          headers: {
            "Content-Type": "application/json",
          },
          body: JSON.stringify({
            name: llmProviderDescriptor.name,
            ...values,
            // default the fast model to the main model when not set
            fast_default_model_name:
              values.default_fast_model_name || values.default_model_name,
          }),
        });

        if (!response.ok) {
          const errorMsg = (await response.json()).detail;
          const fullErrorMsg = existingLlmProvider
            ? `Failed to update provider: ${errorMsg}`
            : `Failed to enable provider: ${errorMsg}`;
          if (setPopup) {
            setPopup({
              type: "error",
              message: fullErrorMsg,
            });
          } else {
            alert(fullErrorMsg);
          }
          // don't leave the form stuck in the submitting state on failure
          setSubmitting(false);
          return;
        }

        if (shouldMarkAsDefault) {
          const newLlmProvider = (await response.json()) as FullLLMProvider;
          const setDefaultResponse = await fetch(
            `${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
            {
              method: "POST",
            }
          );
          if (!setDefaultResponse.ok) {
            const errorMsg = (await setDefaultResponse.json()).detail;
            const fullErrorMsg = `Failed to set provider as default: ${errorMsg}`;
            if (setPopup) {
              setPopup({
                type: "error",
                message: fullErrorMsg,
              });
            } else {
              alert(fullErrorMsg);
            }
            // don't leave the form stuck in the submitting state on failure
            setSubmitting(false);
            return;
          }
        }

        mutate(LLM_PROVIDERS_ADMIN_URL);
        onClose();

        const successMsg = existingLlmProvider
          ? "Provider updated successfully!"
          : "Provider enabled successfully!";
        if (setPopup) {
          setPopup({
            type: "success",
            message: successMsg,
          });
        } else {
          alert(successMsg);
        }

        setSubmitting(false);
      }}
    >
      {({ values }) => (
        <Form>
          {llmProviderDescriptor.api_key_required && (
            <TextFormField
              name="api_key"
              label="API Key"
              placeholder="API Key"
              type="password"
            />
          )}

          {llmProviderDescriptor.api_base_required && (
            <TextFormField
              name="api_base"
              label="API Base"
              placeholder="API Base"
            />
          )}

          {llmProviderDescriptor.api_version_required && (
            <TextFormField
              name="api_version"
              label="API Version"
              placeholder="API Version"
            />
          )}

          {llmProviderDescriptor.custom_config_keys?.map((key) => (
            <div key={key}>
              <TextFormField name={`custom_config.${key}`} label={key} />
            </div>
          ))}

          <Divider />

          {llmProviderDescriptor.llm_names.length > 0 ? (
            <SelectorFormField
              name="default_model_name"
              subtext="The model to use by default for this provider unless otherwise specified."
              label="Default Model"
              options={llmProviderDescriptor.llm_names.map((name) => ({
                name,
                value: name,
              }))}
              direction="up"
              maxHeight="max-h-56"
            />
          ) : (
            <TextFormField
              name="default_model_name"
              subtext="The model to use by default for this provider unless otherwise specified."
              label="Default Model"
              placeholder="E.g. gpt-4"
            />
          )}

          {llmProviderDescriptor.llm_names.length > 0 ? (
            <SelectorFormField
              name="default_fast_model_name"
              subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
                for this provider. If \`Default\` is specified, will use
                the Default Model configured above.`}
              label="[Optional] Fast Model"
              options={llmProviderDescriptor.llm_names.map((name) => ({
                name,
                value: name,
              }))}
              includeDefault
              direction="up"
              maxHeight="max-h-56"
            />
          ) : (
            <TextFormField
              name="default_fast_model_name"
              subtext={`The model to use for lighter flows like \`LLM Chunk Filter\`
                for this provider. If \`Default\` is specified, will use
                the Default Model configured above.`}
              label="[Optional] Fast Model"
              placeholder="E.g. gpt-4"
            />
          )}

          <Divider />

          <div>
            {/* NOTE: this is above the test button to make sure it's visible */}
            {!isTestSuccessful && testError && (
              <Text className="text-error mt-2">{testError}</Text>
            )}
            {isTestSuccessful && (
              <Text className="text-success mt-2">
                Test successful! LLM provider is ready to go.
              </Text>
            )}

            <div className="flex w-full mt-4">
              {isTestSuccessful ? (
                <Button type="submit" size="xs">
                  {existingLlmProvider ? "Update" : "Enable"}
                </Button>
              ) : (
                <Button
                  type="button"
                  size="xs"
                  disabled={isTesting}
                  onClick={async () => {
                    setIsTesting(true);

                    const response = await fetch("/api/admin/llm/test", {
                      method: "POST",
                      headers: {
                        "Content-Type": "application/json",
                      },
                      body: JSON.stringify({
                        provider: llmProviderDescriptor.name,
                        ...values,
                      }),
                    });
                    setIsTesting(false);

                    if (!response.ok) {
                      const errorMsg = (await response.json()).detail;
                      setTestError(errorMsg);
                      return;
                    }

                    setTestSuccessful(true);
                    setValidatedConfig(values);
                  }}
                >
                  {isTesting ? <LoadingAnimation text="Testing" /> : "Test"}
                </Button>
              )}
              {existingLlmProvider && (
                <Button
                  type="button"
                  color="red"
                  className="ml-3"
                  size="xs"
                  icon={FiTrash}
                  onClick={async () => {
                    const response = await fetch(
                      `${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}`,
                      {
                        method: "DELETE",
                      }
                    );
                    if (!response.ok) {
                      const errorMsg = (await response.json()).detail;
                      alert(`Failed to delete provider: ${errorMsg}`);
                      return;
                    }

                    mutate(LLM_PROVIDERS_ADMIN_URL);
                    onClose();
                  }}
                >
                  Delete
                </Button>
              )}
            </div>
          </div>
        </Form>
      )}
    </Formik>
  );
}

View File

@ -0,0 +1 @@
// Admin API root for CRUD operations on configured LLM providers; also used
// as the SWR cache key to revalidate after mutations.
export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";

View File

@ -0,0 +1,29 @@
// Static description of a provider the backend knows how to configure out of
// the box (served from /api/admin/llm/built-in/options). Drives which fields
// the guided setup form shows/requires.
export interface WellKnownLLMProviderDescriptor {
  name: string;
  // Human-friendly label; callers fall back to `name` when null.
  display_name: string | null;

  // Which credential fields the setup form must require for this provider.
  api_key_required: boolean;
  api_base_required: boolean;
  api_version_required: boolean;
  // Extra provider-specific config keys (e.g. account IDs); null when none.
  custom_config_keys: string[] | null;

  // Models this provider offers; empty means free-form model-name entry.
  llm_names: string[];
  default_model: string | null;
  default_fast_model: string | null;
}

// A configured LLM provider as sent to / received from the admin API
// (without server-assigned fields — see FullLLMProvider).
export interface LLMProvider {
  name: string;
  api_key: string | null;
  api_base: string | null;
  api_version: string | null;
  // Extra provider-specific settings, passed through as env vars server-side.
  custom_config: { [key: string]: string } | null;
  default_model_name: string;
  // Model for lightweight flows; null means "use default_model_name".
  fast_default_model_name: string | null;
}

// LLMProvider plus the server-assigned fields returned on read.
export interface FullLLMProvider extends LLMProvider {
  id: number;
  // True for the single system-wide default provider; null/false otherwise.
  is_default_provider: boolean | null;
  model_names: string[];
}

View File

@ -2,7 +2,6 @@
import { Form, Formik } from "formik";
import { useEffect, useState } from "react";
import { LoadingAnimation } from "@/components/Loading";
import { AdminPageTitle } from "@/components/admin/Title";
import {
BooleanFormField,
@ -10,52 +9,9 @@ import {
TextFormField,
} from "@/components/admin/connectors/Field";
import { Popup } from "@/components/admin/connectors/Popup";
import { TrashIcon } from "@/components/icons/icons";
import { ApiKeyForm } from "@/components/openai/ApiKeyForm";
import { GEN_AI_API_KEY_URL } from "@/components/openai/constants";
import { fetcher } from "@/lib/fetcher";
import { Button, Divider, Text, Title } from "@tremor/react";
import { Button, Divider, Text } from "@tremor/react";
import { FiCpu } from "react-icons/fi";
import useSWR, { mutate } from "swr";
const ExistingKeys = () => {
const { data, isLoading, error } = useSWR<{ api_key: string }>(
GEN_AI_API_KEY_URL,
fetcher
);
if (isLoading) {
return <LoadingAnimation text="Loading" />;
}
if (error) {
return <div className="text-error">Error loading existing keys</div>;
}
if (!data?.api_key) {
return null;
}
return (
<div>
<Title className="mb-2">Existing Key</Title>
<div className="flex mb-1">
<p className="text-sm italic my-auto">sk- ****...**{data?.api_key}</p>
<button
className="ml-1 my-auto hover:bg-hover rounded p-1"
onClick={async () => {
await fetch(GEN_AI_API_KEY_URL, {
method: "DELETE",
});
window.location.reload();
}}
>
<TrashIcon />
</button>
</div>
</div>
);
};
import { LLMConfiguration } from "./LLMConfiguration";
const LLMOptions = () => {
const [popup, setPopup] = useState<{
@ -219,27 +175,12 @@ const Page = () => {
return (
<div className="mx-auto container">
<AdminPageTitle
title="LLM Options"
title="LLM Setup"
icon={<FiCpu size={32} className="my-auto" />}
/>
<SectionHeader>LLM Keys</SectionHeader>
<LLMConfiguration />
<ExistingKeys />
<Title className="mb-2 mt-6">Update Key</Title>
<Text className="mb-2">
Specify an OpenAI API key and click the &quot;Submit&quot; button.
</Text>
<div className="border rounded-md border-border p-3">
<ApiKeyForm
handleResponse={(response) => {
if (response.ok) {
mutate(GEN_AI_API_KEY_URL);
}
}}
/>
</div>
<LLMOptions />
</div>
);

View File

@ -20,7 +20,7 @@ import {
WelcomeModal,
hasCompletedWelcomeFlowSS,
} from "@/components/initialSetup/welcome/WelcomeModalWrapper";
import { ApiKeyModal } from "@/components/openai/ApiKeyModal";
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
import { cookies } from "next/headers";
import { DOCUMENT_SIDEBAR_WIDTH_COOKIE_NAME } from "@/components/resizable/contants";
import { personaComparator } from "../admin/assistants/lib";
@ -169,9 +169,9 @@ export default async function Page({
<>
<InstantSSRAutoRefresh />
{shouldShowWelcomeModal && <WelcomeModal />}
{shouldShowWelcomeModal && <WelcomeModal user={user} />}
{!shouldShowWelcomeModal && !shouldDisplaySourcesIncompleteModal && (
<ApiKeyModal />
<ApiKeyModal user={user} />
)}
{shouldDisplaySourcesIncompleteModal && (
<NoCompleteSourcesModal ccPairs={ccPairs} />

View File

@ -7,7 +7,7 @@ import {
} from "@/lib/userSS";
import { redirect } from "next/navigation";
import { HealthCheckBanner } from "@/components/health/healthcheck";
import { ApiKeyModal } from "@/components/openai/ApiKeyModal";
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
import { fetchSS } from "@/lib/utilsSS";
import { CCPairBasicInfo, DocumentSet, Tag, User } from "@/lib/types";
import { cookies } from "next/headers";
@ -147,10 +147,10 @@ export default async function Home() {
<div className="m-3">
<HealthCheckBanner />
</div>
{shouldShowWelcomeModal && <WelcomeModal />}
{shouldShowWelcomeModal && <WelcomeModal user={user} />}
{!shouldShowWelcomeModal &&
!shouldDisplayNoSourcesModal &&
!shouldDisplaySourcesIncompleteModal && <ApiKeyModal />}
!shouldDisplaySourcesIncompleteModal && <ApiKeyModal user={user} />}
{shouldDisplayNoSourcesModal && <NoSourcesModal />}
{shouldDisplaySourcesIncompleteModal && (
<NoCompleteSourcesModal ccPairs={ccPairs} />

View File

@ -1,6 +1,7 @@
import { ChangeEvent, FC, useEffect, useRef, useState } from "react";
import { ChevronDownIcon } from "./icons/icons";
import { FiCheck, FiChevronDown } from "react-icons/fi";
import { Popover } from "./popover/Popover";
export interface Option<T> {
name: string;
@ -181,9 +182,11 @@ export function SearchMultiSelectDropdown({
export const CustomDropdown = ({
children,
dropdown,
direction = "down", // Default to 'down' if not specified
}: {
children: JSX.Element | string;
dropdown: JSX.Element | string;
direction?: "up" | "down";
}) => {
const [isOpen, setIsOpen] = useState(false);
const dropdownRef = useRef<HTMLDivElement>(null);
@ -211,7 +214,9 @@ export const CustomDropdown = ({
{isOpen && (
<div
onClick={() => setIsOpen(!isOpen)}
className="pt-2 absolute bottom w-full z-30 box-shadow"
className={`absolute ${
direction === "up" ? "bottom-full pb-2" : "pt-2 "
} w-full z-30 box-shadow`}
>
{dropdown}
</div>
@ -281,16 +286,21 @@ export function DefaultDropdown({
selected,
onSelect,
includeDefault = false,
direction = "down",
maxHeight,
}: {
options: StringOrNumberOption[];
selected: string | null;
onSelect: (value: string | number | null) => void;
includeDefault?: boolean;
direction?: "up" | "down";
maxHeight?: string;
}) {
const selectedOption = options.find((option) => option.value === selected);
return (
<CustomDropdown
direction={direction}
dropdown={
<div
className={`
@ -300,7 +310,7 @@ export function DefaultDropdown({
flex
flex-col
bg-background
max-h-96
${maxHeight || "max-h-96"}
overflow-y-auto
overscroll-contain`}
>

View File

@ -154,7 +154,7 @@ export async function Layout({ children }: { children: React.ReactNode }) {
<div className="ml-1">LLM</div>
</div>
),
link: "/admin/keys/openai",
link: "/admin/models/llm",
},
{
name: (

View File

@ -233,6 +233,9 @@ interface SelectorFormFieldProps {
options: StringOrNumberOption[];
subtext?: string | JSX.Element;
includeDefault?: boolean;
direction?: "up" | "down";
maxHeight?: string;
onSelect?: (selected: string | number | null) => void;
}
export function SelectorFormField({
@ -241,6 +244,9 @@ export function SelectorFormField({
options,
subtext,
includeDefault = false,
direction = "down",
maxHeight,
onSelect,
}: SelectorFormFieldProps) {
const [field] = useField<string>(name);
const { setFieldValue } = useFormikContext();
@ -254,8 +260,10 @@ export function SelectorFormField({
<DefaultDropdown
options={options}
selected={field.value}
onSelect={(selected) => setFieldValue(name, selected)}
onSelect={onSelect || ((selected) => setFieldValue(name, selected))}
includeDefault={includeDefault}
direction={direction}
maxHeight={maxHeight}
/>
</div>

View File

@ -37,7 +37,7 @@ export function NoCompleteSourcesModal({
title="⏳ None of your connectors have finished a full sync yet"
onOutsideClick={() => setIsHidden(true)}
>
<div className="text-base">
<div className="text-sm">
<div>
<div>
You&apos;ve connected some sources, but none of them have finished

View File

@ -1,6 +1,6 @@
"use client";
import { Button, Divider } from "@tremor/react";
import { Button, Divider, Text } from "@tremor/react";
import { Modal } from "../../Modal";
import Link from "next/link";
import { FiMessageSquare, FiShare2 } from "react-icons/fi";
@ -21,11 +21,11 @@ export function NoSourcesModal() {
>
<div className="text-base">
<div>
<p>
<Text>
Before using Search you&apos;ll need to connect at least one source.
Without any connected knowledge sources, there isn&apos;t anything
to search over.
</p>
</Text>
<Link href="/admin/add-connector">
<Button className="mt-3" size="xs" icon={FiShare2}>
Connect a Source!
@ -33,11 +33,11 @@ export function NoSourcesModal() {
</Link>
<Divider />
<div>
<p>
<Text>
Or, if you&apos;re looking for a pure ChatGPT-like experience
without any organization specific knowledge, then you can head
over to the Chat page and start chatting with Danswer right away!
</p>
</Text>
<Link href="/chat">
<Button className="mt-3" size="xs" icon={FiMessageSquare}>
Start Chatting!

View File

@ -9,8 +9,10 @@ import { COMPLETED_WELCOME_FLOW_COOKIE } from "./constants";
import { FiCheckCircle, FiMessageSquare, FiShare2 } from "react-icons/fi";
import { useEffect, useState } from "react";
import { BackButton } from "@/components/BackButton";
import { ApiKeyForm } from "@/components/openai/ApiKeyForm";
import { checkApiKey } from "@/components/openai/ApiKeyModal";
import { ApiKeyForm } from "@/components/llm/ApiKeyForm";
import { WellKnownLLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";
import { checkLlmProvider } from "./lib";
import { User } from "@/lib/types";
function setWelcomeFlowComplete() {
Cookies.set(COMPLETED_WELCOME_FLOW_COOKIE, "true", { expires: 365 });
@ -52,20 +54,26 @@ function UsageTypeSection({
);
}
export function _WelcomeModal() {
export function _WelcomeModal({ user }: { user: User | null }) {
const router = useRouter();
const [selectedFlow, setSelectedFlow] = useState<null | "search" | "chat">(
null
);
const [isHidden, setIsHidden] = useState(false);
const [apiKeyVerified, setApiKeyVerified] = useState<boolean>(false);
const [providerOptions, setProviderOptions] = useState<
WellKnownLLMProviderDescriptor[]
>([]);
useEffect(() => {
checkApiKey().then((error) => {
if (!error) {
setApiKeyVerified(true);
}
});
async function fetchProviderInfo() {
const { providers, options, defaultCheckSuccessful } =
await checkLlmProvider(user);
setApiKeyVerified(providers.length > 0 && defaultCheckSuccessful);
setProviderOptions(options);
}
fetchProviderInfo();
}, []);
if (isHidden) {
@ -78,30 +86,27 @@ export function _WelcomeModal() {
case "search":
title = undefined;
body = (
<>
<div className="max-h-[85vh] overflow-y-auto px-4 pb-4">
<BackButton behaviorOverride={() => setSelectedFlow(null)} />
<div className="mt-3">
<Text className="font-bold mt-6 mb-2 flex">
<Text className="font-bold flex">
{apiKeyVerified && (
<FiCheckCircle className="my-auto mr-2 text-success" />
)}
Step 1: Provide OpenAI API Key
Step 1: Setup an LLM
</Text>
<div>
{apiKeyVerified ? (
<div>
API Key setup complete!
<Text className="mt-2">
LLM setup complete!
<br /> <br />
If you want to change the key later, you&apos;ll be able to
easily do so in the Admin Panel.
</div>
</Text>
) : (
<ApiKeyForm
handleResponse={async (response) => {
if (response.ok) {
setApiKeyVerified(true);
}
}}
onSuccess={() => setApiKeyVerified(true)}
providerOptions={providerOptions}
/>
)}
</div>
@ -109,12 +114,12 @@ export function _WelcomeModal() {
Step 2: Connect Data Sources
</Text>
<div>
<p>
<Text>
Connectors are the way that Danswer gets data from your
organization&apos;s various data sources. Once setup, we&apos;ll
automatically sync data from your apps and docs into Danswer, so
you can search through all of them in one place.
</p>
</Text>
<div className="flex mt-3">
<Link
@ -133,59 +138,37 @@ export function _WelcomeModal() {
</div>
</div>
</div>
</>
</div>
);
break;
case "chat":
title = undefined;
body = (
<>
<div className="mt-3 max-h-[85vh] overflow-y-auto px-4 pb-4">
<BackButton behaviorOverride={() => setSelectedFlow(null)} />
<div className="mt-3">
<div>
To start using Danswer as a secure ChatGPT, we just need to
configure our LLM!
<br />
<br />
Danswer supports connections with a wide range of LLMs, including
self-hosted open-source LLMs. For more details, check out the{" "}
<a
className="text-link"
href="https://docs.danswer.dev/gen_ai_configs/overview"
>
documentation
</a>
.
<br />
<br />
If you haven&apos;t done anything special with the Gen AI configs,
then we default to use OpenAI.
</div>
<Text className="font-bold mt-6 mb-2 flex">
<Text className="font-bold flex">
{apiKeyVerified && (
<FiCheckCircle className="my-auto mr-2 text-success" />
)}
Step 1: Provide LLM API Key
Step 1: Setup an LLM
</Text>
<div>
{apiKeyVerified ? (
<div>
<Text className="mt-2">
LLM setup complete!
<br /> <br />
If you want to change the key later or choose a different LLM,
you&apos;ll be able to easily to do so in the Admin Panel / by
changing some environment variables.
</div>
you&apos;ll be able to easily do so in the Admin Panel.
</Text>
) : (
<ApiKeyForm
handleResponse={async (response) => {
if (response.ok) {
setApiKeyVerified(true);
}
}}
/>
<div>
<ApiKeyForm
onSuccess={() => setApiKeyVerified(true)}
providerOptions={providerOptions}
/>
</div>
)}
</div>
@ -193,7 +176,7 @@ export function _WelcomeModal() {
Step 2: Start Chatting!
</Text>
<div>
<Text>
Click the button below to start chatting with the LLM setup above!
Don&apos;t worry, if you do decide later on you want to connect
your organization&apos;s knowledge, you can always do that in the{" "}
@ -209,7 +192,7 @@ export function _WelcomeModal() {
Admin Panel
</Link>
.
</div>
</Text>
<div className="flex mt-3">
<Link
@ -228,7 +211,7 @@ export function _WelcomeModal() {
</Link>
</div>
</div>
</>
</div>
);
break;
default:
@ -236,17 +219,17 @@ export function _WelcomeModal() {
body = (
<>
<div>
<p>How are you planning on using Danswer?</p>
<Text>How are you planning on using Danswer?</Text>
</div>
<Divider />
<UsageTypeSection
title="Search / Chat with Knowledge"
description={
<div>
<Text>
If you&apos;re looking to search through, chat with, or ask
direct questions of your organization&apos;s knowledge, then
this is the option for you!
</div>
</Text>
}
callToAction="Get Started"
onClick={() => setSelectedFlow("search")}
@ -255,10 +238,10 @@ export function _WelcomeModal() {
<UsageTypeSection
title="Secure ChatGPT"
description={
<>
<Text>
If you&apos;re looking for a pure ChatGPT-like experience, then
this is the option for you!
</>
</Text>
}
icon={FiMessageSquare}
callToAction="Get Started"

View File

@ -4,6 +4,7 @@ import {
_WelcomeModal,
} from "./WelcomeModal";
import { COMPLETED_WELCOME_FLOW_COOKIE } from "./constants";
import { User } from "@/lib/types";
export function hasCompletedWelcomeFlowSS() {
const cookieStore = cookies();
@ -13,11 +14,11 @@ export function hasCompletedWelcomeFlowSS() {
);
}
export function WelcomeModal() {
export function WelcomeModal({ user }: { user: User | null }) {
const hasCompletedWelcomeFlow = hasCompletedWelcomeFlowSS();
if (hasCompletedWelcomeFlow) {
return <_CompletedWelcomeFlowDummyComponent />;
}
return <_WelcomeModal />;
return <_WelcomeModal user={user} />;
}

View File

@ -0,0 +1,56 @@
import {
FullLLMProvider,
WellKnownLLMProviderDescriptor,
} from "@/app/admin/models/llm/interfaces";
import { User } from "@/lib/types";
const DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY = "defaultLlmProviderTestComplete";
// True once a successful test of the default LLM provider has been recorded
// in this browser (persisted via localStorage across page loads).
function checkDefaultLLMProviderTestComplete() {
  const stored = localStorage.getItem(DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY);
  return stored === "true";
}
// Record (in localStorage) that the default LLM provider passed its test,
// so the test POST is not re-issued on every page load.
function setDefaultLLMProviderTestComplete() {
  localStorage.setItem(DEFAULT_LLM_PROVIDER_TEST_COMPLETE_KEY, "true");
}
// Decide whether to run the default-provider test: skip it if it already
// succeeded before; otherwise only run it for admins (or when there is no
// user, i.e. auth is disabled).
function shouldCheckDefaultLLMProvider(user: User | null) {
  if (checkDefaultLLMProviderTestComplete()) {
    return false;
  }
  return user === null || user.role === "admin";
}
/**
 * Fetch the currently-configured LLM providers and the built-in provider
 * options, and (for admins, at most once per browser) test the default
 * provider.
 *
 * NOTE: should only be called on the client side, after initial render —
 * it reads/writes localStorage.
 */
export async function checkLlmProvider(user: User | null) {
  const checkDefault = shouldCheckDefaultLLMProvider(user);

  const providerPromise = fetch("/api/llm/provider");
  const optionsPromise = fetch("/api/admin/llm/built-in/options");
  // Skip the (potentially slow) default-provider test unless needed.
  const defaultCheckPromise: Promise<Response | null> = checkDefault
    ? fetch("/api/admin/llm/test/default", { method: "POST" })
    : Promise.resolve(null);

  const [providerResponse, optionsResponse, defaultCheckResponse] =
    await Promise.all([providerPromise, optionsPromise, defaultCheckPromise]);

  let providers: FullLLMProvider[] = [];
  if (providerResponse?.ok) {
    providers = await providerResponse.json();
  }

  let options: WellKnownLLMProviderDescriptor[] = [];
  if (optionsResponse?.ok) {
    options = await optionsResponse.json();
  }

  // If we didn't need to run the check, treat it as successful.
  const defaultCheckSuccessful =
    !checkDefault || Boolean(defaultCheckResponse?.ok);
  if (defaultCheckSuccessful) {
    setDefaultLLMProviderTestComplete();
  }

  return { providers, options, defaultCheckSuccessful };
}

View File

@ -0,0 +1,77 @@
import { Popup } from "../admin/connectors/Popup";
import { useState } from "react";
import { TabGroup, TabList, Tab, TabPanels, TabPanel } from "@tremor/react";
import { WellKnownLLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";
import { LLMProviderUpdateForm } from "@/app/admin/models/llm/LLMProviderUpdateForm";
import { CustomLLMProviderUpdateForm } from "@/app/admin/models/llm/CustomLLMProviderUpdateForm";
/**
 * LLM provider setup form used during initial onboarding.
 *
 * Renders one tab per well-known provider descriptor plus a trailing
 * "Custom" tab, and calls `onSuccess` once a provider form completes
 * (the underlying update forms invoke `onClose` on success).
 */
export const ApiKeyForm = ({
  onSuccess,
  providerOptions,
}: {
  onSuccess: () => void;
  providerOptions: WellKnownLLMProviderDescriptor[];
}) => {
  const [popup, setPopup] = useState<{
    message: string;
    type: "success" | "error";
  } | null>(null);

  // First well-known provider is selected by default.
  // NOTE(review): `providerOptions[0]?.name` may be undefined when the list
  // is empty, despite the `string` state type below — confirm callers always
  // pass a non-empty list (or that the "custom" tab alone is acceptable).
  const defaultProvider = providerOptions[0]?.name;

  // Tremor's TabGroup is index-based, so maintain name <-> index maps.
  const providerNameToIndexMap = new Map<string, number>();
  providerOptions.forEach((provider, index) => {
    providerNameToIndexMap.set(provider.name, index);
  });
  // The "Custom" tab always occupies the last index.
  providerNameToIndexMap.set("custom", providerOptions.length);

  const providerIndexToNameMap = new Map<number, string>();
  Array.from(providerNameToIndexMap.keys()).forEach((key) => {
    providerIndexToNameMap.set(providerNameToIndexMap.get(key)!, key);
  });

  const [providerName, setProviderName] = useState<string>(defaultProvider);

  return (
    <div>
      {popup && <Popup message={popup.message} type={popup.type} />}

      <TabGroup
        index={providerNameToIndexMap.get(providerName) || 0}
        onIndexChange={(index) =>
          setProviderName(providerIndexToNameMap.get(index) || defaultProvider)
        }
      >
        <TabList className="mt-3 mb-4">
          <>
            {providerOptions.map((provider) => (
              <Tab key={provider.name}>
                {provider.display_name || provider.name}
              </Tab>
            ))}
            <Tab key="custom">Custom</Tab>
          </>
        </TabList>
        <TabPanels>
          {providerOptions.map((provider) => {
            return (
              <TabPanel key={provider.name}>
                <LLMProviderUpdateForm
                  llmProviderDescriptor={provider}
                  onClose={() => onSuccess()}
                  shouldMarkAsDefault
                  setPopup={setPopup}
                />
              </TabPanel>
            );
          })}
          <TabPanel key="custom">
            <CustomLLMProviderUpdateForm
              onClose={() => onSuccess()}
              shouldMarkAsDefault
              setPopup={setPopup}
            />
          </TabPanel>
        </TabPanels>
      </TabGroup>
    </div>
  );
};

View File

@ -0,0 +1,74 @@
"use client";
import { useState, useEffect } from "react";
import { ApiKeyForm } from "./ApiKeyForm";
import { Modal } from "../Modal";
import { Divider } from "@tremor/react";
import {
FullLLMProvider,
WellKnownLLMProviderDescriptor,
} from "@/app/admin/models/llm/interfaces";
import { checkLlmProvider } from "../initialSetup/welcome/lib";
import { User } from "@/lib/types";
/**
 * Modal prompting the user to set up an LLM provider if none is configured
 * and working yet. Fetches provider state on mount via `checkLlmProvider`.
 */
export const ApiKeyModal = ({ user }: { user: User | null }) => {
  const [forceHidden, setForceHidden] = useState<boolean>(false);
  // Start optimistic (true) so the modal doesn't flash before the fetch
  // completes.
  const [validProviderExists, setValidProviderExists] = useState<boolean>(true);
  const [providerOptions, setProviderOptions] = useState<
    WellKnownLLMProviderDescriptor[]
  >([]);

  useEffect(() => {
    async function fetchProviderInfo() {
      const { providers, options, defaultCheckSuccessful } =
        await checkLlmProvider(user);
      setValidProviderExists(providers.length > 0 && defaultCheckSuccessful);
      setProviderOptions(options);
    }

    fetchProviderInfo();
    // NOTE(review): `user` is read above but omitted from the dependency
    // array — presumably intentional (fetch once on mount); confirm.
  }, []);

  // don't show if
  // (1) a valid provider has been setup or
  // (2) there are no provider options (e.g. user isn't an admin) or
  // (3) user explicitly hides the modal
  if (validProviderExists || !providerOptions.length || forceHidden) {
    return null;
  }

  return (
    <Modal
      title="LLM Key Setup"
      className="max-w-4xl"
      onOutsideClick={() => setForceHidden(true)}
    >
      <div className="max-h-[75vh] overflow-y-auto flex flex-col px-4">
        <div>
          <div className="mb-5 text-sm">
            Please setup an LLM below in order to start using Danswer Search or
            Danswer Chat. Don&apos;t worry, you can always change this later in
            the Admin Panel.
            <br />
            <br />
            Or if you&apos;d rather look around first,{" "}
            <strong
              onClick={() => setForceHidden(true)}
              className="text-link cursor-pointer"
            >
              skip this step
            </strong>
            .
          </div>

          <ApiKeyForm
            onSuccess={() => {
              setForceHidden(true);
            }}
            providerOptions={providerOptions}
          />
        </div>
      </div>
    </Modal>
  );
};

View File

@ -1,82 +0,0 @@
import { Form, Formik } from "formik";
import { Popup } from "../admin/connectors/Popup";
import { useState } from "react";
import { TextFormField } from "../admin/connectors/Field";
import { GEN_AI_API_KEY_URL } from "./constants";
import { LoadingAnimation } from "../Loading";
import { Button } from "@tremor/react";
interface Props {
handleResponse?: (response: Response) => void;
}
export const ApiKeyForm = ({ handleResponse }: Props) => {
const [popup, setPopup] = useState<{
message: string;
type: "success" | "error";
} | null>(null);
return (
<div>
{popup && <Popup message={popup.message} type={popup.type} />}
<Formik
initialValues={{ apiKey: "" }}
onSubmit={async ({ apiKey }, formikHelpers) => {
const response = await fetch(GEN_AI_API_KEY_URL, {
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ api_key: apiKey }),
});
if (handleResponse) {
handleResponse(response);
}
if (response.ok) {
setPopup({
message: "Updated API key!",
type: "success",
});
formikHelpers.resetForm();
} else {
const body = await response.json();
if (body.detail) {
setPopup({ message: body.detail, type: "error" });
} else {
setPopup({
message:
"Unable to set API key. Check if the provided key is valid.",
type: "error",
});
}
setTimeout(() => {
setPopup(null);
}, 10000);
}
}}
>
{({ isSubmitting }) =>
isSubmitting ? (
<div className="text-base">
<LoadingAnimation text="Validating API key" />
</div>
) : (
<Form>
<TextFormField name="apiKey" type="password" label="API Key:" />
<div className="flex">
<Button
size="xs"
type="submit"
disabled={isSubmitting}
className="w-48 mx-auto"
>
Submit
</Button>
</div>
</Form>
)
}
</Formik>
</div>
);
};

View File

@ -1,68 +0,0 @@
"use client";
import { useState, useEffect } from "react";
import { ApiKeyForm } from "./ApiKeyForm";
import { Modal } from "../Modal";
import { Divider, Text } from "@tremor/react";
export async function checkApiKey() {
const response = await fetch("/api/manage/admin/genai-api-key/validate");
if (!response.ok && (response.status === 404 || response.status === 400)) {
const jsonResponse = await response.json();
return jsonResponse.detail;
}
return null;
}
export const ApiKeyModal = () => {
const [errorMsg, setErrorMsg] = useState<string | null>(null);
useEffect(() => {
checkApiKey().then((error) => {
if (error) {
setErrorMsg(error);
}
});
}, []);
if (!errorMsg) {
return null;
}
return (
<Modal
title="LLM Key Setup"
className="max-w-4xl"
onOutsideClick={() => setErrorMsg(null)}
>
<div>
<div>
<div className="mb-2.5 text-base">
Please provide a valid OpenAI API key below in order to start using
Danswer Search or Danswer Chat.
<br />
<br />
Or if you&apos;d rather look around first,{" "}
<strong
onClick={() => setErrorMsg(null)}
className="text-link cursor-pointer"
>
skip this step
</strong>
.
</div>
<Divider />
<ApiKeyForm
handleResponse={(response) => {
if (response.ok) {
setErrorMsg(null);
}
}}
/>
</div>
</div>
</Modal>
);
};

View File

@ -1 +0,0 @@
export const GEN_AI_API_KEY_URL = "/api/manage/admin/genai-api-key";

View File

@ -2,6 +2,7 @@ import { Persona } from "@/app/admin/assistants/interfaces";
import { CCPairBasicInfo, DocumentSet, User } from "../types";
import { getCurrentUserSS } from "../userSS";
import { fetchSS } from "../utilsSS";
import { FullLLMProvider } from "@/app/admin/models/llm/interfaces";
export async function fetchAssistantEditorInfoSS(
personaId?: number | string
@ -10,8 +11,7 @@ export async function fetchAssistantEditorInfoSS(
{
ccPairs: CCPairBasicInfo[];
documentSets: DocumentSet[];
llmOverrideOptions: string[];
defaultLLM: string;
llmProviders: FullLLMProvider[];
user: User | null;
existingPersona: Persona | null;
},
@ -22,8 +22,7 @@ export async function fetchAssistantEditorInfoSS(
const tasks = [
fetchSS("/manage/indexing-status"),
fetchSS("/manage/document-set"),
fetchSS("/persona/utils/list-available-models"),
fetchSS("/persona/utils/default-model"),
fetchSS("/llm/provider"),
// duplicate fetch, but shouldn't be too big of a deal
// this page is not a high traffic page
getCurrentUserSS(),
@ -37,15 +36,13 @@ export async function fetchAssistantEditorInfoSS(
const [
ccPairsInfoResponse,
documentSetsResponse,
llmOverridesResponse,
defaultLLMResponse,
llmProvidersResponse,
user,
personaResponse,
] = (await Promise.all(tasks)) as [
Response,
Response,
Response,
Response,
User | null,
Response | null,
];
@ -66,21 +63,13 @@ export async function fetchAssistantEditorInfoSS(
}
const documentSets = (await documentSetsResponse.json()) as DocumentSet[];
if (!llmOverridesResponse.ok) {
if (!llmProvidersResponse.ok) {
return [
null,
`Failed to fetch LLM override options - ${await llmOverridesResponse.text()}`,
`Failed to fetch LLM providers - ${await llmProvidersResponse.text()}`,
];
}
const llmOverrideOptions = (await llmOverridesResponse.json()) as string[];
if (!defaultLLMResponse.ok) {
return [
null,
`Failed to fetch default LLM - ${await defaultLLMResponse.text()}`,
];
}
const defaultLLM = (await defaultLLMResponse.json()) as string;
const llmProviders = (await llmProvidersResponse.json()) as FullLLMProvider[];
if (personaId && personaResponse && !personaResponse.ok) {
return [null, `Failed to fetch Persona - ${await personaResponse.text()}`];
@ -93,8 +82,7 @@ export async function fetchAssistantEditorInfoSS(
{
ccPairs,
documentSets,
llmOverrideOptions,
defaultLLM,
llmProviders,
user,
existingPersona,
},