From f5b3333df32c36af326e7a5580965b236f710bd9 Mon Sep 17 00:00:00 2001
From: Weves
Date: Sat, 13 Apr 2024 18:07:13 -0700
Subject: [PATCH] Add UI-based LLM selection

---
 ...1c1ac29467_add_tables_for_ui_based_llm_.py |  49 ++
 backend/danswer/chat/load_yamls.py            |   1 +
 backend/danswer/chat/process_message.py       |  15 +-
 backend/danswer/configs/constants.py          |   5 +-
 .../slack/handlers/handle_message.py          |  58 ++-
 backend/danswer/db/chat.py                    |   3 +
 backend/danswer/db/engine.py                  |   2 +-
 backend/danswer/db/llm.py                     |  96 ++++
 backend/danswer/db/models.py                  |  34 +-
 backend/danswer/db/persona.py                 |   7 +
 backend/danswer/db/slack_bot_config.py        |   1 +
 .../danswer/dynamic_configs/port_configs.py   |  75 +++
 backend/danswer/llm/answering/answer.py       |  17 +-
 backend/danswer/llm/answering/doc_pruning.py  |   2 +-
 backend/danswer/llm/answering/models.py       |  34 --
 .../llm/answering/prompts/citations_prompt.py |  11 +-
 backend/danswer/llm/chat_llm.py               |  53 ++-
 backend/danswer/llm/factory.py                |  92 +++-
 backend/danswer/llm/gpt_4_all.py              |  11 +-
 backend/danswer/llm/interfaces.py             |  12 +
 backend/danswer/llm/options.py                | 109 +++++
 backend/danswer/llm/utils.py                  |  63 +--
 backend/danswer/main.py                       |  25 +-
 .../one_shot_answer/answer_question.py        |   4 +-
 .../danswer/server/features/persona/api.py    |  48 --
 .../danswer/server/features/persona/models.py |   3 +
 .../danswer/server/manage/administrative.py   | 102 +---
 backend/danswer/server/manage/llm/api.py      | 157 ++++++
 backend/danswer/server/manage/llm/models.py   |  89 ++++
 backend/requirements/default.txt              |   1 +
 web/package-lock.json                         |   7 +
 web/package.json                              |   2 +
 .../app/admin/assistants/AssistantEditor.tsx  | 102 +++-
 web/src/app/admin/assistants/interfaces.ts    |   1 +
 web/src/app/admin/assistants/lib.ts           |   3 +
 web/src/app/admin/models/embedding/page.tsx   |  11 +-
 .../llm/CustomLLMProviderUpdateForm.tsx       | 450 ++++++++++++++++++
 .../app/admin/models/llm/LLMConfiguration.tsx | 234 +++++++++
 .../models/llm/LLMProviderUpdateForm.tsx      | 347 ++++++++++++++
 web/src/app/admin/models/llm/constants.ts     |   1 +
 web/src/app/admin/models/llm/interfaces.ts    |  29 ++
 .../{keys/openai => models/llm}/page.tsx      |  67 +--
 web/src/app/chat/page.tsx                     |   6 +-
 web/src/app/search/page.tsx                   |   6 +-
 web/src/components/Dropdown.tsx               |  14 +-
 web/src/components/admin/Layout.tsx           |   2 +-
 web/src/components/admin/connectors/Field.tsx |  10 +-
 .../search/NoCompleteSourceModal.tsx          |   2 +-
 .../initialSetup/search/NoSourcesModal.tsx    |  10 +-
 .../initialSetup/welcome/WelcomeModal.tsx     | 111 ++---
 .../welcome/WelcomeModalWrapper.tsx           |   5 +-
 .../components/initialSetup/welcome/lib.ts    |  56 +++
 web/src/components/llm/ApiKeyForm.tsx         |  77 +++
 web/src/components/llm/ApiKeyModal.tsx        |  74 +++
 web/src/components/openai/ApiKeyForm.tsx      |  82 ----
 web/src/components/openai/ApiKeyModal.tsx     |  68 ---
 web/src/components/openai/constants.ts        |   1 -
 .../assistants/fetchPersonaEditorInfoSS.ts    |  28 +-
 58 files changed, 2284 insertions(+), 701 deletions(-)
 create mode 100644 backend/alembic/versions/401c1ac29467_add_tables_for_ui_based_llm_.py
 create mode 100644 backend/danswer/db/llm.py
 create mode 100644 backend/danswer/llm/options.py
 create mode 100644 backend/danswer/server/manage/llm/api.py
 create mode 100644 backend/danswer/server/manage/llm/models.py
 create mode 100644 web/src/app/admin/models/llm/CustomLLMProviderUpdateForm.tsx
 create mode 100644 web/src/app/admin/models/llm/LLMConfiguration.tsx
 create mode 100644 web/src/app/admin/models/llm/LLMProviderUpdateForm.tsx
 create mode 100644 web/src/app/admin/models/llm/constants.ts
 create mode 100644 web/src/app/admin/models/llm/interfaces.ts
 rename web/src/app/admin/{keys/openai => models/llm}/page.tsx (76%)
 create mode 100644 web/src/components/initialSetup/welcome/lib.ts
 create mode 100644 web/src/components/llm/ApiKeyForm.tsx
 create mode 100644 web/src/components/llm/ApiKeyModal.tsx
 delete mode 100644 web/src/components/openai/ApiKeyForm.tsx
 delete mode 100644 web/src/components/openai/ApiKeyModal.tsx
 delete mode 100644 web/src/components/openai/constants.ts

diff --git a/backend/alembic/versions/401c1ac29467_add_tables_for_ui_based_llm_.py b/backend/alembic/versions/401c1ac29467_add_tables_for_ui_based_llm_.py
new file mode 100644
index 000000000..2a50e1d0f
--- /dev/null
+++ b/backend/alembic/versions/401c1ac29467_add_tables_for_ui_based_llm_.py
@@ -0,0 +1,49 @@
+"""Add tables for UI-based LLM configuration
+
+Revision ID: 401c1ac29467
+Revises: 703313b75876
+Create Date: 2024-04-13 18:07:29.153817
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = "401c1ac29467"
+down_revision = "703313b75876"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        "llm_provider",
+        sa.Column("id", sa.Integer(), nullable=False),
+        sa.Column("name", sa.String(), nullable=False),
+        sa.Column("api_key", sa.String(), nullable=True),
+        sa.Column("api_base", sa.String(), nullable=True),
+        sa.Column("api_version", sa.String(), nullable=True),
+        sa.Column(
+            "custom_config",
+            postgresql.JSONB(astext_type=sa.Text()),
+            nullable=True,
+        ),
+        sa.Column("default_model_name", sa.String(), nullable=False),
+        sa.Column("fast_default_model_name", sa.String(), nullable=True),
+        sa.Column("is_default_provider", sa.Boolean(), unique=True, nullable=True),
+        sa.Column("model_names", postgresql.ARRAY(sa.String()), nullable=True),
+        sa.PrimaryKeyConstraint("id"),
+        sa.UniqueConstraint("name"),
+    )
+
+    op.add_column(
+        "persona",
+        sa.Column("llm_model_provider_override", sa.String(), nullable=True),
+    )
+
+
+def downgrade() -> None:
+    op.drop_column("persona", "llm_model_provider_override")
+
+    op.drop_table("llm_provider")
diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py
index 1b1e615bb..abb10461a 100644
--- a/backend/danswer/chat/load_yamls.py
+++ b/backend/danswer/chat/load_yamls.py
@@ -88,6 +88,7 @@ def load_personas_from_yaml(
             llm_relevance_filter=persona.get("llm_relevance_filter"),
             starter_messages=persona.get("starter_messages"),
             llm_filter_extraction=persona.get("llm_filter_extraction"),
+            llm_model_provider_override=None,
             llm_model_version_override=None,
             recency_bias=RecencyBiasSetting(persona["recency_bias"]),
             prompts=cast(list[PromptDBModel] | None, prompts),
diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py
index 21b87296f..c02dc81f2 100644
--- a/backend/danswer/chat/process_message.py
+++ b/backend/danswer/chat/process_message.py
@@ -34,11 +34,10 @@ from danswer.llm.answering.answer import Answer
 from danswer.llm.answering.models import AnswerStyleConfig
 from danswer.llm.answering.models import CitationConfig
 from danswer.llm.answering.models import DocumentPruningConfig
-from danswer.llm.answering.models import LLMConfig
 from danswer.llm.answering.models import PreviousMessage
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_llm_for_persona
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.search.models import OptionalSearchSetting
 from danswer.search.models import SearchRequest
@@ -135,8 +134,8 @@ def stream_chat_message_objects(
     )

     try:
-        llm = get_default_llm(
-            gen_ai_model_version_override=persona.llm_model_version_override
+        llm = get_llm_for_persona(
+            persona, new_msg_req.llm_override or chat_session.llm_override
         )
     except GenAIDisabledException:
         llm = None
@@ -373,9 +372,11 @@ def stream_chat_message_objects(
                 new_msg_req.prompt_override or chat_session.prompt_override
             ),
         ),
-        llm_config=LLMConfig.from_persona(
-            persona,
-            llm_override=(new_msg_req.llm_override or chat_session.llm_override),
+        llm=(
+            llm
+            or get_llm_for_persona(
+                persona, new_msg_req.llm_override or chat_session.llm_override
+            )
         ),
         doc_relevance_list=llm_relevance_list,
         message_history=[
diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py
index 584c846b2..0628911be 100644
--- a/backend/danswer/configs/constants.py
+++ b/backend/danswer/configs/constants.py
@@ -24,6 +24,7 @@ MATCH_HIGHLIGHTS = "match_highlights"
 # not be used for QA. For example, Google Drive file types which can't be parsed
 # are still useful as a search result but not for QA.
 IGNORE_FOR_QA = "ignore_for_qa"
+# NOTE: deprecated, only used for porting key from old system
 GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
 PUBLIC_DOC_PAT = "PUBLIC"
 PUBLIC_DOCUMENT_SET = "__PUBLIC"
@@ -52,10 +53,6 @@ SECTION_SEPARATOR = "\n\n"
 INDEX_SEPARATOR = "==="


-# Key-Value store constants
-GEN_AI_DETECTED_MODEL = "gen_ai_detected_model"
-
-
 # Messages
 DISABLED_GEN_AI_MSG = (
     "Your System Admin has disabled the Generative AI functionalities of Danswer.\n"
diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py
index fc1c038ae..c24d28847 100644
--- a/backend/danswer/danswerbot/slack/handlers/handle_message.py
+++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py
@@ -36,13 +36,15 @@ from danswer.danswerbot.slack.utils import slack_usage_report
 from danswer.danswerbot.slack.utils import SlackRateLimiter
 from danswer.danswerbot.slack.utils import update_emote_react
 from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.models import Persona
 from danswer.db.models import SlackBotConfig
 from danswer.db.models import SlackBotResponseType
+from danswer.db.persona import fetch_persona_by_id
 from danswer.llm.answering.prompts.citations_prompt import (
     compute_max_document_tokens_for_persona,
 )
+from danswer.llm.factory import get_llm_for_persona
 from danswer.llm.utils import check_number_of_tokens
-from danswer.llm.utils import get_default_llm_version
 from danswer.llm.utils import get_max_input_tokens
 from danswer.one_shot_answer.answer_question import get_search_answer
 from danswer.one_shot_answer.models import DirectQARequest
@@ -237,32 +239,38 @@ def handle_message(
     max_document_tokens: int | None = None
     max_history_tokens: int | None = None

-    if len(new_message_request.messages) > 1:
-        llm_name = get_default_llm_version()[0]
-        if persona and persona.llm_model_version_override:
-            llm_name = persona.llm_model_version_override
-
-        # In cases of threads, split the available tokens between docs and thread context
-        input_tokens = get_max_input_tokens(model_name=llm_name)
-        max_history_tokens = int(input_tokens * thread_context_percent)
-
-        remaining_tokens = input_tokens - max_history_tokens
-
-        query_text = new_message_request.messages[0].message
-        if persona:
-            max_document_tokens = compute_max_document_tokens_for_persona(
-                persona=persona,
-                actual_user_input=query_text,
-                max_llm_token_override=remaining_tokens,
-            )
-        else:
-            max_document_tokens = (
-                remaining_tokens
-                - 512  # Needs to be more than any of the QA prompts
-                - check_number_of_tokens(query_text)
-            )

     with Session(get_sqlalchemy_engine()) as db_session:
+        if len(new_message_request.messages) > 1:
+            persona = cast(
+                Persona,
+                fetch_persona_by_id(db_session, new_message_request.persona_id),
+            )
+            llm = get_llm_for_persona(persona)
+
+            # In cases of threads, split the available tokens between docs and thread context
+            input_tokens = get_max_input_tokens(
+                model_name=llm.config.model_name,
+                model_provider=llm.config.model_provider,
+            )
+            max_history_tokens = int(input_tokens * thread_context_percent)
+
+            remaining_tokens = input_tokens - max_history_tokens
+
+            query_text = new_message_request.messages[0].message
+            if persona:
+                max_document_tokens = compute_max_document_tokens_for_persona(
+                    persona=persona,
+                    actual_user_input=query_text,
+                    max_llm_token_override=remaining_tokens,
+                )
+            else:
+                max_document_tokens = (
+                    remaining_tokens
+                    - 512  # Needs to be more than any of the QA prompts
+                    - check_number_of_tokens(query_text)
+                )
+
         # This also handles creating the query event in postgres
         answer = get_search_answer(
             query_req=new_message_request,
diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py
index d45fe95a7..8000123a2 100644
--- a/backend/danswer/db/chat.py
+++ b/backend/danswer/db/chat.py
@@ -498,6 +498,7 @@ def upsert_persona(
     recency_bias: RecencyBiasSetting,
     prompts: list[Prompt] | None,
     document_sets: list[DBDocumentSet] | None,
+    llm_model_provider_override: str | None,
     llm_model_version_override: str | None,
     starter_messages: list[StarterMessage] | None,
     is_public: bool,
@@ -524,6 +525,7 @@ def upsert_persona(
         persona.llm_filter_extraction = llm_filter_extraction
         persona.recency_bias = recency_bias
         persona.default_persona = default_persona
+        persona.llm_model_provider_override = llm_model_provider_override
         persona.llm_model_version_override = llm_model_version_override
         persona.starter_messages = starter_messages
         persona.deleted = False  # Un-delete if previously deleted
@@ -553,6 +555,7 @@ def upsert_persona(
             default_persona=default_persona,
             prompts=prompts or [],
             document_sets=document_sets or [],
+            llm_model_provider_override=llm_model_provider_override,
             llm_model_version_override=llm_model_version_override,
             starter_messages=starter_messages,
         )
diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py
index 1be57179c..39d37c2b9 100644
--- a/backend/danswer/db/engine.py
+++ b/backend/danswer/db/engine.py
@@ -71,7 +71,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
     return _ASYNC_ENGINE


-def get_session_context_manager() -> ContextManager:
+def get_session_context_manager() -> ContextManager[Session]:
     return contextlib.contextmanager(get_session)()
diff --git a/backend/danswer/db/llm.py b/backend/danswer/db/llm.py
new file mode 100644
index 000000000..a502180db
--- /dev/null
+++ b/backend/danswer/db/llm.py
@@ -0,0 +1,96 @@
+from sqlalchemy import delete
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from danswer.db.models import LLMProvider as LLMProviderModel
+from danswer.server.manage.llm.models import FullLLMProvider
+from danswer.server.manage.llm.models import LLMProviderUpsertRequest
+
+
+def upsert_llm_provider(
+    db_session: Session, llm_provider: LLMProviderUpsertRequest
+) -> FullLLMProvider:
+    existing_llm_provider = db_session.scalar(
+        select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
+    )
+    if existing_llm_provider:
+        existing_llm_provider.api_key = llm_provider.api_key
+        existing_llm_provider.api_base = llm_provider.api_base
+        existing_llm_provider.api_version = llm_provider.api_version
+        existing_llm_provider.custom_config = llm_provider.custom_config
+        existing_llm_provider.default_model_name = llm_provider.default_model_name
+        existing_llm_provider.fast_default_model_name = (
+            llm_provider.fast_default_model_name
+        )
+        existing_llm_provider.model_names = llm_provider.model_names
+        db_session.commit()
+        return FullLLMProvider.from_model(existing_llm_provider)
+
+    # if it does not exist, create a new entry
+    llm_provider_model = LLMProviderModel(
+        name=llm_provider.name,
+        api_key=llm_provider.api_key,
+        api_base=llm_provider.api_base,
+        api_version=llm_provider.api_version,
+        custom_config=llm_provider.custom_config,
+        default_model_name=llm_provider.default_model_name,
+        fast_default_model_name=llm_provider.fast_default_model_name,
+        model_names=llm_provider.model_names,
+        is_default_provider=None,
+    )
+    db_session.add(llm_provider_model)
+    db_session.commit()
+
+    return FullLLMProvider.from_model(llm_provider_model)
+
+
+def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
+    return list(db_session.scalars(select(LLMProviderModel)).all())
+
+
+def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
+    provider_model = db_session.scalar(
+        select(LLMProviderModel).where(
+            LLMProviderModel.is_default_provider == True  # noqa: E712
+        )
+    )
+    if not provider_model:
+        return None
+    return FullLLMProvider.from_model(provider_model)
+
+
+def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
+    provider_model = db_session.scalar(
+        select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
+    )
+    if not provider_model:
+        return None
+    return FullLLMProvider.from_model(provider_model)
+
+
+def remove_llm_provider(db_session: Session, provider_id: int) -> None:
+    db_session.execute(
+        delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
+    )
+    db_session.commit()
+
+
+def update_default_provider(db_session: Session, provider_id: int) -> None:
+    new_default = db_session.scalar(
+        select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
+    )
+    if not new_default:
+        raise ValueError(f"LLM Provider with id {provider_id} does not exist")
+
+    existing_default = db_session.scalar(
+        select(LLMProviderModel).where(
+            LLMProviderModel.is_default_provider == True  # noqa: E712
+        )
+    )
+    if existing_default:
+        existing_default.is_default_provider = None
+        # required to ensure that the below does not cause a unique constraint violation
+        db_session.flush()
+
+    new_default.is_default_provider = True
+    db_session.commit()
diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py
index 8d1722138..ec5d249c5 100644
--- a/backend/danswer/db/models.py
+++ b/backend/danswer/db/models.py
@@ -8,8 +8,8 @@ from typing import Optional
 from typing import TypedDict
 from uuid import UUID

-from fastapi_users.db import SQLAlchemyBaseOAuthAccountTableUUID
-from fastapi_users.db import SQLAlchemyBaseUserTableUUID
+from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID
+from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID
 from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
 from sqlalchemy import Boolean
 from sqlalchemy import DateTime
@@ -695,6 +695,33 @@
 Structures, Organizational, Configurations Tables
 """


+class LLMProvider(Base):
+    __tablename__ = "llm_provider"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True)
+    name: Mapped[str] = mapped_column(String, unique=True)
+    api_key: Mapped[str | None] = mapped_column(String, nullable=True)
+    api_base: Mapped[str | None] = mapped_column(String, nullable=True)
+    api_version: Mapped[str | None] = mapped_column(String, nullable=True)
+    # custom configs that should be passed to the LLM provider at inference time
+    # (e.g. `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, etc. for bedrock)
+    custom_config: Mapped[dict[str, str] | None] = mapped_column(
+        postgresql.JSONB(), nullable=True
+    )
+    default_model_name: Mapped[str] = mapped_column(String)
+    fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
+
+    # The LLMs that are available for this provider. Only required if not a default provider.
+    # If a default provider, then the LLM options are pulled from the `options.py` file.
+    # If needed, can be pulled out as a separate table in the future.
+    model_names: Mapped[list[str] | None] = mapped_column(
+        postgresql.ARRAY(String), nullable=True
+    )
+
+    # should only be set for a single provider
+    is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
+
+
 class DocumentSet(Base):
     __tablename__ = "document_set"

@@ -792,6 +819,9 @@ class Persona(Base):
     # globally via env variables. For flexibility, validity is not currently enforced
     # NOTE: only is applied on the actual response generation - is not used for things like
     # auto-detected time filters, relevance filters, etc.
+    llm_model_provider_override: Mapped[str | None] = mapped_column(
+        String, nullable=True
+    )
     llm_model_version_override: Mapped[str | None] = mapped_column(
         String, nullable=True
     )
diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py
index 7b1116b5f..a2dece7fb 100644
--- a/backend/danswer/db/persona.py
+++ b/backend/danswer/db/persona.py
@@ -1,11 +1,13 @@
 from uuid import UUID

 from fastapi import HTTPException
+from sqlalchemy import select
 from sqlalchemy.orm import Session

 from danswer.db.chat import get_prompts_by_ids
 from danswer.db.chat import upsert_persona
 from danswer.db.document_set import get_document_sets_by_ids
+from danswer.db.models import Persona
 from danswer.db.models import Persona__User
 from danswer.db.models import User
 from danswer.server.features.persona.models import CreatePersonaRequest
@@ -69,6 +71,7 @@ def create_update_persona(
             recency_bias=create_persona_request.recency_bias,
             prompts=prompts,
             document_sets=document_sets,
+            llm_model_provider_override=create_persona_request.llm_model_provider_override,
             llm_model_version_override=create_persona_request.llm_model_version_override,
             starter_messages=create_persona_request.starter_messages,
             is_public=create_persona_request.is_public,
@@ -91,3 +94,7 @@ def create_update_persona(
         logger.exception("Failed to create persona")
         raise HTTPException(status_code=400, detail=str(e))
     return PersonaSnapshot.from_model(persona)
+
+
+def fetch_persona_by_id(db_session: Session, persona_id: int) -> Persona | None:
+    return db_session.scalar(select(Persona).where(Persona.id == persona_id))
diff --git a/backend/danswer/db/slack_bot_config.py b/backend/danswer/db/slack_bot_config.py
index 9b792ff08..973d76244 100644
--- a/backend/danswer/db/slack_bot_config.py
+++ b/backend/danswer/db/slack_bot_config.py
@@ -59,6 +59,7 @@ def create_slack_bot_persona(
         recency_bias=RecencyBiasSetting.AUTO,
         prompts=None,
         document_sets=document_sets,
+        llm_model_provider_override=None,
         llm_model_version_override=None,
         starter_messages=None,
         is_public=True,
diff --git a/backend/danswer/dynamic_configs/port_configs.py b/backend/danswer/dynamic_configs/port_configs.py
index 34abcff74..d0c55e698 100644
--- a/backend/danswer/dynamic_configs/port_configs.py
+++ b/backend/danswer/dynamic_configs/port_configs.py
@@ -1,9 +1,27 @@
 import json
 from pathlib import Path
+from typing import cast

 from danswer.configs.app_configs import DYNAMIC_CONFIG_DIR_PATH
+from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
+from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
+from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
+from danswer.configs.model_configs import GEN_AI_API_KEY
+from danswer.configs.model_configs import GEN_AI_API_VERSION
+from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.db.engine import get_session_context_manager
+from danswer.db.llm import fetch_existing_llm_providers
+from danswer.db.llm import update_default_provider
+from danswer.db.llm import upsert_llm_provider
+from danswer.dynamic_configs.factory import get_dynamic_config_store
 from danswer.dynamic_configs.factory import PostgresBackedDynamicConfigStore
 from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.server.manage.llm.models import LLMProviderUpsertRequest
+from danswer.utils.logger import setup_logger
+
+
+logger = setup_logger()


 def read_file_system_store(directory_path: str) -> dict:
@@ -38,3 +56,60 @@ def insert_into_postgres(store_data: dict) -> None:
 def port_filesystem_to_postgres(directory_path: str = DYNAMIC_CONFIG_DIR_PATH) -> None:
     store_data = read_file_system_store(directory_path)
     insert_into_postgres(store_data)
+
+
+def port_api_key_to_postgres() -> None:
+    # can't port over custom, no longer supported
+    if GEN_AI_MODEL_PROVIDER == "custom":
+        return
+
+    with get_session_context_manager() as db_session:
+        # if we have already ported things over / set up providers in the db, don't do anything
+        if len(fetch_existing_llm_providers(db_session)) > 0:
+            return
+
+        api_key = GEN_AI_API_KEY
+        try:
+            api_key = cast(
+                str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY)
+            )
+        except ConfigNotFoundError:
+            pass
+
+        # if no API key set, don't port anything over
+        if not api_key:
+            return
+
+        default_model_name = GEN_AI_MODEL_VERSION
+        if GEN_AI_MODEL_PROVIDER == "openai" and not default_model_name:
+            default_model_name = "gpt-4"
+
+        # if no default model name found, don't port anything over
+        if not default_model_name:
+            return
+
+        default_fast_model_name = FAST_GEN_AI_MODEL_VERSION
+        if GEN_AI_MODEL_PROVIDER == "openai" and not default_fast_model_name:
+            default_fast_model_name = "gpt-3.5-turbo"
+
+        llm_provider_upsert = LLMProviderUpsertRequest(
+            name=GEN_AI_MODEL_PROVIDER,
+            api_key=api_key,
+            api_base=GEN_AI_API_ENDPOINT,
+            api_version=GEN_AI_API_VERSION,
+            # can't port over any custom configs, since we don't know
+            # all the possible keys and values that could be in there
+            custom_config=None,
+            default_model_name=default_model_name,
+            fast_default_model_name=default_fast_model_name,
+            model_names=None,
+        )
+        llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
+        update_default_provider(db_session, llm_provider.id)
+        logger.info(f"Ported over LLM provider:\n\n{llm_provider}")
+
+        # delete the old API key
+        try:
+            get_dynamic_config_store().delete(GEN_AI_API_KEY_STORAGE_KEY)
+        except ConfigNotFoundError:
+            pass
diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py
index 6c07eccda..e6c1437b6 100644
--- a/backend/danswer/llm/answering/answer.py
+++ b/backend/danswer/llm/answering/answer.py
@@ -12,7 +12,6 @@ from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
 from danswer.configs.chat_configs import QA_TIMEOUT
 from danswer.llm.answering.doc_pruning import prune_documents
 from danswer.llm.answering.models import AnswerStyleConfig
-from danswer.llm.answering.models import LLMConfig
 from danswer.llm.answering.models import PreviousMessage
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.answering.models import StreamProcessor
@@ -26,7 +25,7 @@ from danswer.llm.answering.stream_processing.citation_processing import (
 from danswer.llm.answering.stream_processing.quotes_processing import (
     build_quotes_processor,
 )
-from danswer.llm.factory import get_default_llm
+from danswer.llm.interfaces import LLM
 from danswer.llm.utils import get_default_llm_tokenizer


@@ -53,7 +52,7 @@ class Answer:
         question: str,
         docs: list[LlmDoc],
         answer_style_config: AnswerStyleConfig,
-        llm_config: LLMConfig,
+        llm: LLM,
         prompt_config: PromptConfig,
         # must be the same length as `docs`. If None, all docs are considered "relevant"
         doc_relevance_list: list[bool] | None = None,
@@ -74,15 +73,9 @@ class Answer:
         self.single_message_history = single_message_history

         self.answer_style_config = answer_style_config
-        self.llm_config = llm_config
         self.prompt_config = prompt_config

-        self.llm = get_default_llm(
-            gen_ai_model_provider=self.llm_config.model_provider,
-            gen_ai_model_version_override=self.llm_config.model_version,
-            timeout=timeout,
-            temperature=self.llm_config.temperature,
-        )
+        self.llm = llm
         self.llm_tokenizer = get_default_llm_tokenizer()

         self._final_prompt: list[BaseMessage] | None = None
@@ -101,7 +94,7 @@ class Answer:
                 docs=self.docs,
                 doc_relevance_list=self.doc_relevance_list,
                 prompt_config=self.prompt_config,
-                llm_config=self.llm_config,
+                llm_config=self.llm.config,
                 question=self.question,
                 document_pruning_config=self.answer_style_config.document_pruning_config,
             )
@@ -116,7 +109,7 @@ class Answer:
             self._final_prompt = build_citations_prompt(
                 question=self.question,
                 message_history=self.message_history,
-                llm_config=self.llm_config,
+                llm_config=self.llm.config,
                 prompt_config=self.prompt_config,
                 context_docs=self.pruned_docs,
                 all_doc_useful=self.answer_style_config.citation_config.all_docs_useful,
diff --git a/backend/danswer/llm/answering/doc_pruning.py b/backend/danswer/llm/answering/doc_pruning.py
index bf0f2be25..fa243895a 100644
--- a/backend/danswer/llm/answering/doc_pruning.py
+++ b/backend/danswer/llm/answering/doc_pruning.py
@@ -7,9 +7,9 @@ from danswer.chat.models import (
 from danswer.configs.constants import IGNORE_FOR_QA
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.llm.answering.models import DocumentPruningConfig
-from danswer.llm.answering.models import LLMConfig
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
+from danswer.llm.interfaces import LLMConfig
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.llm.utils import tokenizer_trim_content
 from danswer.prompts.prompt_utils import build_doc_context_str
diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py
index 42d1218cf..47282de88 100644
--- a/backend/danswer/llm/answering/models.py
+++ b/backend/danswer/llm/answering/models.py
@@ -9,15 +9,11 @@ from pydantic import root_validator

 from danswer.chat.models import AnswerQuestionStreamReturn
 from danswer.configs.constants import MessageType
-from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
-from danswer.llm.override_models import LLMOverride
 from danswer.llm.override_models import PromptOverride
-from danswer.llm.utils import get_default_llm_version

 if TYPE_CHECKING:
     from danswer.db.models import ChatMessage
     from danswer.db.models import Prompt
-    from danswer.db.models import Persona


 StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
@@ -86,36 +82,6 @@ class AnswerStyleConfig(BaseModel):
         return values


-class LLMConfig(BaseModel):
-    """Final representation of the LLM configuration passed into
-    the `Answer` object."""
-
-    model_provider: str
-    model_version: str
-    temperature: float
-
-    @classmethod
-    def from_persona(
-        cls, persona: "Persona", llm_override: LLMOverride | None = None
-    ) -> "LLMConfig":
-        model_provider_override = llm_override.model_provider if llm_override else None
-        model_version_override = llm_override.model_version if llm_override else None
-        temperature_override = llm_override.temperature if llm_override else None
-
-        return cls(
-            model_provider=model_provider_override or GEN_AI_MODEL_PROVIDER,
-            model_version=(
-                model_version_override
-                or persona.llm_model_version_override
-                or get_default_llm_version()[0]
-            ),
-            temperature=temperature_override or 0.0,
-        )
-
-    class Config:
-        frozen = True
-
-
 class PromptConfig(BaseModel):
     """Final representation of the Prompt configuration passed into
     the `Answer` object."""
diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/llm/answering/prompts/citations_prompt.py
index 88bb30c9c..c807f0672 100644
--- a/backend/danswer/llm/answering/prompts/citations_prompt.py
+++ b/backend/danswer/llm/answering/prompts/citations_prompt.py
@@ -11,9 +11,10 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
 from danswer.db.chat import get_default_prompt
 from danswer.db.models import Persona
-from danswer.llm.answering.models import LLMConfig
 from danswer.llm.answering.models import PreviousMessage
 from danswer.llm.answering.models import PromptConfig
+from danswer.llm.factory import get_llm_for_persona
+from danswer.llm.interfaces import LLMConfig
 from danswer.llm.utils import check_number_of_tokens
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.llm.utils import get_max_input_tokens
@@ -131,7 +132,9 @@ def compute_max_document_tokens(
     max_input_tokens = (
         max_llm_token_override
         if max_llm_token_override
-        else get_max_input_tokens(model_name=llm_config.model_version)
+        else get_max_input_tokens(
+            model_name=llm_config.model_name, model_provider=llm_config.model_provider
+        )
     )
     prompt_tokens = get_prompt_tokens(prompt_config)

@@ -152,7 +155,7 @@ def compute_max_document_tokens_for_persona(
     prompt = persona.prompts[0] if persona.prompts else get_default_prompt()
     return compute_max_document_tokens(
         prompt_config=PromptConfig.from_model(prompt),
-        llm_config=LLMConfig.from_persona(persona),
+        llm_config=get_llm_for_persona(persona).config,
         actual_user_input=actual_user_input,
         max_llm_token_override=max_llm_token_override,
     )
@@ -162,7 +165,7 @@ def compute_max_llm_input_tokens(llm_config: LLMConfig) -> int:
     """Maximum tokens allowed in the input to the LLM (of any type)."""

     input_tokens = get_max_input_tokens(
-        model_name=llm_config.model_version, model_provider=llm_config.model_provider
+        model_name=llm_config.model_name, model_provider=llm_config.model_provider
     )
     return input_tokens - _MISC_BUFFER
diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py
index c61db946d..9d151a98e 100644
--- a/backend/danswer/llm/chat_llm.py
+++ b/backend/danswer/llm/chat_llm.py
@@ -1,4 +1,5 @@
 import abc
+import os
 from collections.abc import Iterator

 import litellm  # type:ignore
@@ -12,11 +13,9 @@ from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_API_VERSION
 from danswer.configs.model_configs import GEN_AI_LLM_PROVIDER_TYPE
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
-from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.llm.interfaces import LLM
-from danswer.llm.utils import get_default_llm_version
+from danswer.llm.interfaces import LLMConfig
 from danswer.llm.utils import message_generator_to_string_generator
 from danswer.llm.utils import should_be_verbose
 from danswer.utils.logger import setup_logger
@@ -69,6 +68,7 @@ class LangChainChatLLM(LLM, abc.ABC):

     def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
         if LOG_ALL_MODEL_INTERACTIONS:
+            self.log_model_configs()
             self._log_prompt(prompt)

         if DISABLE_LITELLM_STREAMING:
@@ -85,23 +85,6 @@ class LangChainChatLLM(LLM, abc.ABC):
             logger.debug(f"Raw Model Output:\n{full_output}")


-def _get_model_str(
-    model_provider: str | None,
-    model_version: str | None,
-) -> str:
-    if model_provider and model_version:
-        return model_provider + "/" + model_version
-
-    if model_version:
-        # Litellm defaults to openai if no provider specified
-        # It's implicit so no need to specify here either
-        return model_version
-
-    # User specified something wrong, just use Danswer default
-    base, _ = get_default_llm_version()
-    return base
-
-
 class DefaultMultiLLM(LangChainChatLLM):
     """Uses Litellm library to allow easy configuration to use a multitude of LLMs
     See https://python.langchain.com/docs/integrations/chat/litellm"""
@@ -115,25 +98,35 @@ class DefaultMultiLLM(LangChainChatLLM):
         self,
         api_key: str | None,
         timeout: int,
-        model_provider: str = GEN_AI_MODEL_PROVIDER,
-        model_version: str | None = GEN_AI_MODEL_VERSION,
+        model_provider: str,
+        model_name: str,
         api_base: str | None = GEN_AI_API_ENDPOINT,
         api_version: str | None = GEN_AI_API_VERSION,
         custom_llm_provider: str | None = GEN_AI_LLM_PROVIDER_TYPE,
         max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
         temperature: float = GEN_AI_TEMPERATURE,
+        custom_config: dict[str, str] | None = None,
     ):
+        self._model_provider = model_provider
+        self._model_version = model_name
+        self._temperature = temperature
+
         # Litellm Langchain integration currently doesn't take in the api key param
         # Can place this in the call below once integration is in
         litellm.api_key = api_key or "dummy-key"
        litellm.api_version = api_version

-        model_version = model_version or get_default_llm_version()[0]
+        # NOTE: have to set these as environment variables for Litellm, since
+        # not all of them can be passed in directly, but they are always
+        # supported when set as env variables
+        if custom_config:
+            for k, v in custom_config.items():
+                os.environ[k] = v
         self._llm = ChatLiteLLM(  # type: ignore
-            model=model_version
-            if custom_llm_provider
-            else _get_model_str(model_provider, model_version),
+            model=(
+                model_name if custom_llm_provider else f"{model_provider}/{model_name}"
+            ),
             api_base=api_base,
             custom_llm_provider=custom_llm_provider,
             max_tokens=max_output_tokens,
@@ -148,6 +141,14 @@ class DefaultMultiLLM(LangChainChatLLM):
             max_retries=0,  # retries are handled outside of langchain
         )

+    @property
+    def config(self) -> LLMConfig:
+        return LLMConfig(
+            model_provider=self._model_provider,
+            model_name=self._model_version,
+            temperature=self._temperature,
+        )
+
     @property
     def llm(self) -> ChatLiteLLM:
         return self._llm
diff --git a/backend/danswer/llm/factory.py b/backend/danswer/llm/factory.py
index f274aa790..d1838b618 100644
--- a/backend/danswer/llm/factory.py
+++ b/backend/danswer/llm/factory.py
@@ -1,48 +1,90 @@
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.chat_configs import QA_TIMEOUT
-from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
+from danswer.db.engine import get_session_context_manager
+from danswer.db.llm import fetch_default_provider
+from danswer.db.llm import fetch_provider
+from danswer.db.models import Persona
 from danswer.llm.chat_llm import DefaultMultiLLM
-from danswer.llm.custom_llm import CustomModelServer
 from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.gpt_4_all import DanswerGPT4All
 from danswer.llm.interfaces import LLM
-from danswer.llm.utils import get_default_llm_version
-from danswer.llm.utils import get_gen_ai_api_key
+from danswer.llm.override_models import LLMOverride
+
+
+def get_llm_for_persona(
+    persona: Persona, llm_override: LLMOverride | None = None
+) -> LLM:
+    model_provider_override = llm_override.model_provider if llm_override else None
+    model_version_override = llm_override.model_version if llm_override else None
+    temperature_override = llm_override.temperature if llm_override else None
+
+    return get_default_llm(
+        gen_ai_model_provider=model_provider_override
+        or persona.llm_model_provider_override,
+        gen_ai_model_version_override=(
+            model_version_override or persona.llm_model_version_override
+        ),
+        temperature=temperature_override or GEN_AI_TEMPERATURE,
+    )


 def get_default_llm(
-    gen_ai_model_provider: str = GEN_AI_MODEL_PROVIDER,
-    api_key: str | None = None,
     timeout: int = QA_TIMEOUT,
     temperature: float = GEN_AI_TEMPERATURE,
     use_fast_llm: bool = False,
+    gen_ai_model_provider: str | None = None,
     gen_ai_model_version_override: str | None = None,
 ) -> LLM:
-    """A single place to fetch the configured LLM for Danswer
-    Also allows overriding certain LLM defaults"""
     if DISABLE_GENERATIVE_AI:
         raise GenAIDisabledException()

-    if gen_ai_model_version_override:
-        model_version = gen_ai_model_version_override
-    else:
-        base, fast = get_default_llm_version()
-        model_version = fast if use_fast_llm else base
-    if api_key is None:
-        api_key = get_gen_ai_api_key()
+    # TODO: pass this in
+    with get_session_context_manager() as session:
+        if gen_ai_model_provider is None:
+            llm_provider = fetch_default_provider(session)
+        else:
+            llm_provider = fetch_provider(session, gen_ai_model_provider)

-    if gen_ai_model_provider.lower() == "custom":
-        return CustomModelServer(api_key=api_key, timeout=timeout)
+    if not llm_provider:
+        raise ValueError("No default LLM provider found")

-    if gen_ai_model_provider.lower() == "gpt4all":
-        return DanswerGPT4All(
-            model_version=model_version, timeout=timeout, temperature=temperature
-        )
+    model_name = gen_ai_model_version_override or (
+        (llm_provider.fast_default_model_name or llm_provider.default_model_name)
+        if use_fast_llm
+        else llm_provider.default_model_name
+    )
+    if not model_name:
+        raise ValueError("No default model name found")

-    return DefaultMultiLLM(
-        model_version=model_version,
-        api_key=api_key,
+    return get_llm(
+        provider=llm_provider.name,
+        model=model_name,
+        api_key=llm_provider.api_key,
+        api_base=llm_provider.api_base,
+        api_version=llm_provider.api_version,
+        custom_config=llm_provider.custom_config,
         timeout=timeout,
         temperature=temperature,
     )
+
+
+def get_llm(
+    provider: str,
+    model: str,
+    api_key: str | None = None,
+    api_base: str | None = None,
+    api_version: str | None = None,
+    custom_config: dict[str, str] | None = None,
+    temperature: float = GEN_AI_TEMPERATURE,
+    timeout: int = QA_TIMEOUT,
+) -> LLM:
+    return DefaultMultiLLM(
+        model_provider=provider,
+        model_name=model,
+        api_key=api_key,
+        api_base=api_base,
+        api_version=api_version,
+        timeout=timeout,
+        temperature=temperature,
+        custom_config=custom_config,
+    )
diff --git a/backend/danswer/llm/gpt_4_all.py b/backend/danswer/llm/gpt_4_all.py
index 78e5d8bac..c7cf6a615 100644
--- a/backend/danswer/llm/gpt_4_all.py
+++ b/backend/danswer/llm/gpt_4_all.py
@@ -4,11 +4,9 @@ from typing import Any
 from langchain.schema.language_model import LanguageModelInput

 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import convert_lm_input_to_basic_string
-from danswer.llm.utils import get_default_llm_version
 from danswer.utils.logger import setup_logger


@@ -38,7 +36,10 @@ except ImportError:

 class DanswerGPT4All(LLM):
     """Option to run an LLM locally, however this is significantly slower and
-    answers tend to be much worse"""
+    answers tend to be much worse
+
+    NOTE: currently unused, but kept for future reference / if we want to add this back.
+    """

     @property
     def requires_warm_up(self) -> bool:
         return True
@@ -53,14 +54,14 @@
     def __init__(
         self,
         timeout: int,
-        model_version: str | None = GEN_AI_MODEL_VERSION,
+        model_version: str,
         max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
         temperature: float = GEN_AI_TEMPERATURE,
     ):
         self.timeout = timeout
         self.max_output_tokens = max_output_tokens
         self.temperature = temperature
-        self.gpt4all_model = GPT4All(model_version or get_default_llm_version()[0])
+        self.gpt4all_model = GPT4All(model_version)

     def log_model_configs(self) -> None:
         logger.debug(
diff --git a/backend/danswer/llm/interfaces.py b/backend/danswer/llm/interfaces.py
index 41fe428bb..c1cbe6253 100644
--- a/backend/danswer/llm/interfaces.py
+++ b/backend/danswer/llm/interfaces.py
@@ -2,6 +2,7 @@ import abc
 from collections.abc import Iterator

 from langchain.schema.language_model import LanguageModelInput
+from pydantic import BaseModel

 from danswer.utils.logger import setup_logger

@@ -9,6 +10,12 @@ from danswer.utils.logger import setup_logger
 logger = setup_logger()


+class LLMConfig(BaseModel):
+    model_provider: str
+    model_name: str
+    temperature: float
+
+
 class LLM(abc.ABC):
     """Mimics the LangChain LLM / BaseChatModel interfaces to make it easy
     to use these implementations to connect to a variety of LLM providers."""
@@ -22,6 +29,11 @@ class LLM(abc.ABC):
     def requires_api_key(self) -> bool:
         return True

+    @property
+    @abc.abstractmethod
+    def config(self) -> LLMConfig:
+        raise NotImplementedError
+
     @abc.abstractmethod
     def log_model_configs(self) -> None:
         raise NotImplementedError
diff --git a/backend/danswer/llm/options.py b/backend/danswer/llm/options.py
new file mode 100644
index 000000000..835f1c74e
--- /dev/null
+++ b/backend/danswer/llm/options.py
@@ -0,0 +1,109 @@
+import litellm  # type: ignore
+from pydantic import BaseModel
+
+
+class WellKnownLLMProviderDescriptor(BaseModel):
+    name: str
+    display_name: str | None = None
+    api_key_required: bool
+    api_base_required: bool
+    api_version_required: bool
+    custom_config_keys: list[str] | None = None
+
+    llm_names: list[str]
+    default_model: str | None = None
+    default_fast_model: str | None = None
+
+
+OPENAI_PROVIDER_NAME = "openai"
+OPEN_AI_MODEL_NAMES = [
+    "gpt-4",
+    "gpt-4-turbo-preview",
+    "gpt-4-1106-preview",
+    "gpt-4-32k",
+    "gpt-4-0613",
+    "gpt-4-32k-0613",
+    "gpt-4-0314",
+    "gpt-4-32k-0314",
+    "gpt-3.5-turbo",
+    "gpt-3.5-turbo-0125",
+    "gpt-3.5-turbo-1106",
+    "gpt-3.5-turbo-16k",
+    "gpt-3.5-turbo-0613",
+    "gpt-3.5-turbo-16k-0613",
+    "gpt-3.5-turbo-0301",
+]
+
+BEDROCK_PROVIDER_NAME = "bedrock"
+# need to remove all the weird "bedrock/eu-central-1/anthropic.claude-v1" named
+# models
+BEDROCK_MODEL_NAMES = [model for model in litellm.bedrock_models if "/" not in model][
+    ::-1
+]
+
+ANTHROPIC_PROVIDER_NAME = "anthropic"
+ANTHROPIC_MODEL_NAMES = [model for model in litellm.anthropic_models][::-1]
+
+AZURE_PROVIDER_NAME = "azure"
+
+
+_PROVIDER_TO_MODELS_MAP = {
+    OPENAI_PROVIDER_NAME: OPEN_AI_MODEL_NAMES,
+    BEDROCK_PROVIDER_NAME: BEDROCK_MODEL_NAMES,
+    ANTHROPIC_PROVIDER_NAME: ANTHROPIC_MODEL_NAMES,
+}
+
+
+def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
+    return [
+        WellKnownLLMProviderDescriptor(
+            name=OPENAI_PROVIDER_NAME,
+            display_name="OpenAI",
+            api_key_required=True,
+            api_base_required=False,
+            api_version_required=False,
+            custom_config_keys=[],
+            llm_names=fetch_models_for_provider(OPENAI_PROVIDER_NAME),
+            default_model="gpt-4",
+            default_fast_model="gpt-3.5-turbo",
+        ),
+        WellKnownLLMProviderDescriptor(
+            name=ANTHROPIC_PROVIDER_NAME,
+            display_name="Anthropic",
+            api_key_required=True,
+            api_base_required=False,
+            api_version_required=False,
+            custom_config_keys=[],
+            llm_names=fetch_models_for_provider(ANTHROPIC_PROVIDER_NAME),
+            default_model="claude-3-opus-20240229",
+            default_fast_model="claude-3-sonnet-20240229",
+        ),
+        WellKnownLLMProviderDescriptor(
+            name=AZURE_PROVIDER_NAME,
+            display_name="Azure OpenAI",
+            api_key_required=True,
+            api_base_required=True,
+            api_version_required=True,
+            custom_config_keys=[],
+            llm_names=fetch_models_for_provider(AZURE_PROVIDER_NAME),
+        ),
+        WellKnownLLMProviderDescriptor(
+            name=BEDROCK_PROVIDER_NAME,
+            display_name="AWS Bedrock",
+            api_key_required=False,
+            api_base_required=False,
+            api_version_required=False,
+            custom_config_keys=[
+                "AWS_ACCESS_KEY_ID",
+                "AWS_SECRET_ACCESS_KEY",
+                "AWS_REGION_NAME",
+            ],
+            llm_names=fetch_models_for_provider(BEDROCK_PROVIDER_NAME),
+            default_model="anthropic.claude-3-sonnet-20240229-v1:0",
+            default_fast_model="anthropic.claude-3-haiku-20240307-v1:0",
+        ),
+    ]
+
+
+def fetch_models_for_provider(provider_name: str) -> list[str]:
+    return _PROVIDER_TO_MODELS_MAP.get(provider_name, [])
diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py
index d9c59c7b6..595ccc687 100644
--- a/backend/danswer/llm/utils.py
+++ b/backend/danswer/llm/utils.py
@@ -1,9 +1,7 @@
 from collections.abc import Callable
 from collections.abc import Iterator
 from copy import copy
-from functools import lru_cache
 from typing import Any
-from typing import cast
 from typing import TYPE_CHECKING
 from typing import Union

@@ -20,19 +18,12 @@ from langchain.schema.messages import HumanMessage
 from langchain.schema.messages import SystemMessage
 from tiktoken.core import Encoding

-from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
-from danswer.configs.constants import GEN_AI_DETECTED_MODEL
 from danswer.configs.constants import MessageType
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
-from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
-from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.db.models import ChatMessage
-from danswer.dynamic_configs.factory import get_dynamic_config_store
-from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.llm.interfaces import LLM
 from danswer.search.models import InferenceChunk
 from danswer.utils.logger import setup_logger
@@ -47,34 +38,6 @@ _LLM_TOKENIZER: Any = None
 _LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None


-@lru_cache()
-def get_default_llm_version() -> tuple[str, str]:
-    default_openai_model = "gpt-3.5-turbo-16k-0613"
-    if GEN_AI_MODEL_VERSION:
-        llm_version = GEN_AI_MODEL_VERSION
-    else:
-        if GEN_AI_MODEL_PROVIDER != "openai":
-            logger.warning("No LLM Model Version set")
-            # Either this value is unused or it will throw an error
-            llm_version = default_openai_model
-        else:
-            kv_store = get_dynamic_config_store()
-            try:
-                llm_version = cast(str, kv_store.load(GEN_AI_DETECTED_MODEL))
-            except ConfigNotFoundError:
-                llm_version = default_openai_model
-
-    if FAST_GEN_AI_MODEL_VERSION:
-        fast_llm_version = FAST_GEN_AI_MODEL_VERSION
-    else:
-        if GEN_AI_MODEL_PROVIDER == "openai":
-            fast_llm_version = default_openai_model
-        else:
-            fast_llm_version = llm_version
-
-    return llm_version, fast_llm_version
-
-
 def get_default_llm_tokenizer() -> Encoding:
     """Currently only supports the OpenAI default tokenizer: tiktoken"""
     global _LLM_TOKENIZER
@@ -219,17 +182,6 @@ def check_number_of_tokens(
     return len(encode_fn(text))


-def get_gen_ai_api_key() -> str | None:
-    # first check if the key has been provided by the UI
-    try:
-        return cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
-    except ConfigNotFoundError:
-        pass
-
-    # if not provided by the UI, fallback to the env variable
-    return GEN_AI_API_KEY
-
-
 def test_llm(llm: LLM) -> str | None:
     # try for up to 2 timeouts (e.g. 10 seconds in total)
     error_msg = None
@@ -246,7 +198,7 @@ def test_llm(llm: LLM) -> str | None:

 def get_llm_max_tokens(
     model_map: dict,
-    model_name: str | None = GEN_AI_MODEL_VERSION,
+    model_name: str,
     model_provider: str = GEN_AI_MODEL_PROVIDER,
 ) -> int:
     """Best effort attempt to get the max tokens for the LLM"""
@@ -254,13 +206,10 @@ def get_llm_max_tokens(
         # This is an override, so always return this
         return GEN_AI_MAX_TOKENS

-    model_name = model_name or get_default_llm_version()[0]
-
     try:
-        if model_provider == "openai":
+        model_obj = model_map.get(f"{model_provider}/{model_name}")
+        if not model_obj:
             model_obj = model_map[model_name]
-        else:
-            model_obj = model_map[f"{model_provider}/{model_name}"]

         if "max_input_tokens" in model_obj:
             return model_obj["max_input_tokens"]
@@ -277,8 +226,8 @@ def get_llm_max_tokens(


 def get_max_input_tokens(
-    model_name: str | None = GEN_AI_MODEL_VERSION,
-    model_provider: str = GEN_AI_MODEL_PROVIDER,
+    model_name: str,
+    model_provider: str,
     output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
 ) -> int:
     # NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually
@@ -289,8 +238,6 @@ def get_max_input_tokens(
     # model_map is litellm.model_cost
     litellm_model_map = litellm.model_cost

-    model_name = model_name or get_default_llm_version()[0]
-
     input_toks = (
         get_llm_max_tokens(
             model_name=model_name,
diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index 8e27482ae..c43833e56 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -33,8 +33,6 @@ from danswer.configs.app_configs import SECRET
 from danswer.configs.app_configs import WEB_DOMAIN
 from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.configs.constants import AuthType
-from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
-from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.db.chat import delete_old_default_personas
 from danswer.db.connector import create_initial_default_connector
 from danswer.db.connector_credential_pair import associate_default_cc_pair
@@ -48,9 +46,8 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model
 from danswer.db.index_attempt import expire_index_attempts
 from danswer.db.swap_index import check_index_swap
 from danswer.document_index.factory import get_default_document_index
+from danswer.dynamic_configs.port_configs import port_api_key_to_postgres
 from danswer.dynamic_configs.port_configs import port_filesystem_to_postgres
-from danswer.llm.factory import get_default_llm
-from danswer.llm.utils import get_default_llm_version
 from danswer.search.retrieval.search_runner import download_nltk_data
 from danswer.search.search_nlp_models import warm_up_encoders
 from danswer.server.auth_check import check_router_auth
@@ -67,6 +64,8 @@ from danswer.server.features.prompt.api import basic_router as prompt_router
 from danswer.server.gpts.api import router as gpts_router
 from danswer.server.manage.administrative import router as admin_router
 from danswer.server.manage.get_state import router as state_router
+from danswer.server.manage.llm.api import admin_router as llm_admin_router
+from danswer.server.manage.llm.api import basic_router as llm_router
 from danswer.server.manage.secondary_index import router as secondary_index_router
 from danswer.server.manage.slack_bot import router as slack_bot_management_router
 from danswer.server.manage.users import router as user_router
@@ -156,17 +155,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:

     if DISABLE_GENERATIVE_AI:
         logger.info("Generative AI Q&A disabled")
-    else:
-        logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
-        base, fast = get_default_llm_version()
-        logger.info(f"Using LLM Model Version: {base}")
-        if base != fast:
-            logger.info(f"Using Fast LLM Model Version: {fast}")
-        if GEN_AI_API_ENDPOINT:
-            logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
-
-        # Any additional model configs logged here
-        get_default_llm().log_model_configs()

     if MULTILINGUAL_QUERY_EXPANSION:
         logger.info(
@@ -180,6 +168,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
             "Skipping port of persistent volumes. Maybe these have already been removed?"
         )

+    try:
+        port_api_key_to_postgres()
+    except Exception as e:
+        logger.debug(f"Failed to port API keys. Exception: {e}. Continuing...")
+
     with Session(engine) as db_session:
         check_index_swap(db_session=db_session)
         db_embedding_model = get_current_db_embedding_model(db_session)
@@ -281,6 +274,8 @@ def get_application() -> FastAPI:
     include_router_with_global_prefix_prepended(application, gpts_router)
     include_router_with_global_prefix_prepended(application, settings_router)
     include_router_with_global_prefix_prepended(application, settings_admin_router)
+    include_router_with_global_prefix_prepended(application, llm_admin_router)
+    include_router_with_global_prefix_prepended(application, llm_router)

     if AUTH_TYPE == AuthType.DISABLED:
         # Server logs this during auth setup verification step
diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py
index ff6e04a21..d4707d12d 100644
--- a/backend/danswer/one_shot_answer/answer_question.py
+++ b/backend/danswer/one_shot_answer/answer_question.py
@@ -28,9 +28,9 @@ from danswer.llm.answering.answer import Answer
 from danswer.llm.answering.models import AnswerStyleConfig
 from danswer.llm.answering.models import CitationConfig
 from danswer.llm.answering.models import DocumentPruningConfig
-from danswer.llm.answering.models import LLMConfig
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.answering.models import QuotesConfig
+from danswer.llm.factory import get_llm_for_persona
 from danswer.llm.utils import get_default_llm_token_encode
 from danswer.one_shot_answer.models import DirectQARequest
 from danswer.one_shot_answer.models import OneShotQAResponse
@@ -212,7 +212,7 @@ def stream_answer_objects(
         docs=[llm_doc_from_inference_section(section) for section in top_sections],
         answer_style_config=answer_config,
         prompt_config=PromptConfig.from_model(prompt),
-        llm_config=LLMConfig.from_persona(chat_session.persona),
+        llm=get_llm_for_persona(persona=chat_session.persona),
         doc_relevance_list=search_pipeline.section_relevance_list,
         single_message_history=history_str,
         timeout=timeout,
diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py
index bfaea792f..f316560d0 100644
--- a/backend/danswer/server/features/persona/api.py
+++ b/backend/danswer/server/features/persona/api.py
@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session

 from danswer.auth.users import current_admin_user
 from danswer.auth.users import current_user
-from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.db.chat import get_persona_by_id
 from danswer.db.chat import get_personas
 from danswer.db.chat import mark_persona_as_deleted
@@ -15,7 +14,6 @@ from danswer.db.engine import get_session
 from danswer.db.models import User
 from danswer.db.persona import create_update_persona
 from danswer.llm.answering.prompts.utils import build_dummy_prompt
-from danswer.llm.utils import get_default_llm_version
 from danswer.server.features.persona.models import CreatePersonaRequest
 from danswer.server.features.persona.models import PersonaSnapshot
 from danswer.server.features.persona.models import PromptTemplateResponse
@@ -168,49 +166,3 @@ def build_final_template_prompt(
             retrieval_disabled=retrieval_disabled,
         )
     )
-
-
-"""Utility endpoints for selecting which model to use for a persona.
-Putting here for now, since we have no other flows which use this."""
-
-GPT_4_MODEL_VERSIONS = [
-    "gpt-4",
-    "gpt-4-turbo-preview",
-    "gpt-4-1106-preview",
-    "gpt-4-32k",
-    "gpt-4-0613",
-    "gpt-4-32k-0613",
-    "gpt-4-0314",
-    "gpt-4-32k-0314",
-]
-GPT_3_5_TURBO_MODEL_VERSIONS = [
-    "gpt-3.5-turbo",
-    "gpt-3.5-turbo-0125",
-    "gpt-3.5-turbo-1106",
-    "gpt-3.5-turbo-16k",
-    "gpt-3.5-turbo-0613",
-    "gpt-3.5-turbo-16k-0613",
-    "gpt-3.5-turbo-0301",
-]
-
-
-@basic_router.get("/utils/list-available-models")
-def list_available_model_versions(
-    _: User | None = Depends(current_user),
-) -> list[str]:
-    # currently only support selecting different models for OpenAI
-    if GEN_AI_MODEL_PROVIDER != "openai":
-        return []
-
-    return GPT_4_MODEL_VERSIONS + GPT_3_5_TURBO_MODEL_VERSIONS
-
-
-@basic_router.get("/utils/default-model")
-def get_default_model(
-    _: User | None = Depends(current_user),
-) -> str:
-    # currently only support selecting different models for OpenAI
-    if GEN_AI_MODEL_PROVIDER != "openai":
-        return ""
-
-    return get_default_llm_version()[0]
diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py
index 8826be2c3..d313e54d8 100644
--- a/backend/danswer/server/features/persona/models.py
+++ b/backend/danswer/server/features/persona/models.py
@@ -20,6 +20,7 @@ class CreatePersonaRequest(BaseModel):
     recency_bias: RecencyBiasSetting
     prompt_ids: list[int]
     document_set_ids: list[int]
+    llm_model_provider_override: str | None = None
     llm_model_version_override: str | None = None
     starter_messages: list[StarterMessage] | None = None
     # For Private Personas, who should be able to access these
@@ -38,6 +39,7 @@ class PersonaSnapshot(BaseModel):
     num_chunks: float | None
     llm_relevance_filter: bool
     llm_filter_extraction: bool
+    llm_model_provider_override: str | None
     llm_model_version_override: str | None
     starter_messages: list[StarterMessage] | None
     default_persona: bool
@@ -66,6 +68,7 @@ class PersonaSnapshot(BaseModel):
             num_chunks=persona.num_chunks,
             llm_relevance_filter=persona.llm_relevance_filter,
             llm_filter_extraction=persona.llm_filter_extraction,
+            llm_model_provider_override=persona.llm_model_provider_override,
             llm_model_version_override=persona.llm_model_version_override,
             starter_messages=persona.starter_messages,
             default_persona=persona.default_persona,
diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py
index 02d980b04..9b3f8df21 100644
--- a/backend/danswer/server/manage/administrative.py
+++ b/backend/danswer/server/manage/administrative.py
@@ -1,5 +1,4 @@
 import json
-from collections.abc import Callable
 from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
@@ -16,13 +15,9 @@ from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
 from danswer.configs.app_configs import TOKEN_BUDGET_GLOBALLY_ENABLED
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import ENABLE_TOKEN_BUDGET
-from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
-from danswer.configs.constants import GEN_AI_DETECTED_MODEL
 from danswer.configs.constants import TOKEN_BUDGET
 from danswer.configs.constants import TOKEN_BUDGET_SETTINGS
 from danswer.configs.constants import TOKEN_BUDGET_TIME_PERIOD
-from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
-from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.db.connector_credential_pair import get_connector_credential_pair
 from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
 from danswer.db.engine import get_session
@@ -35,18 +30,13 @@ from danswer.document_index.document_index_utils import get_both_index_names
 from danswer.document_index.factory import get_default_document_index
 from danswer.dynamic_configs.factory import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
-from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
-from danswer.llm.factory import get_default_llm_version
-from danswer.llm.utils import get_gen_ai_api_key
 from danswer.llm.utils import test_llm
 from danswer.server.documents.models import ConnectorCredentialPairIdentifier
 from danswer.server.manage.models import BoostDoc
 from danswer.server.manage.models import BoostUpdateRequest
 from danswer.server.manage.models import HiddenUpdateRequest
-from danswer.server.models import ApiKey
 from danswer.utils.logger import setup_logger
-from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

 router = APIRouter(prefix="/manage")
 logger = setup_logger()
@@ -123,52 +113,6 @@ def document_hidden_update(
         raise HTTPException(status_code=400, detail=str(e))


-def _validate_llm_key(genai_api_key: str | None) -> None:
-    # Checking new API key, may change things for this if kv-store is updated
-    get_default_llm_version.cache_clear()
-    kv_store = get_dynamic_config_store()
-    try:
-        kv_store.delete(GEN_AI_DETECTED_MODEL)
-    except ConfigNotFoundError:
-        pass
-
-    gpt_4_version = "gpt-4"  # 32k is not available to most people
-    gpt4_llm = None
-    try:
-        llm = get_default_llm(api_key=genai_api_key, timeout=10)
-        if GEN_AI_MODEL_PROVIDER == "openai" and not GEN_AI_MODEL_VERSION:
-            gpt4_llm = get_default_llm(
-                gen_ai_model_version_override=gpt_4_version,
-                api_key=genai_api_key,
-                timeout=10,
-            )
-    except GenAIDisabledException:
-        return
-
-    functions_with_args: list[tuple[Callable, tuple]] = [(test_llm, (llm,))]
-    if gpt4_llm:
-        functions_with_args.append((test_llm, (gpt4_llm,)))
-
-    parallel_results = run_functions_tuples_in_parallel(
-        functions_with_args, allow_failures=False
-    )
-
-    error_msg = parallel_results[0]
-
-    if error_msg:
-        if genai_api_key is None:
-            raise HTTPException(status_code=404, detail="Key not found")
-        raise HTTPException(status_code=400, detail=error_msg)
-
-    # Mark check as successful
-    curr_time = datetime.now(tz=timezone.utc)
-    kv_store.store(GEN_AI_KEY_CHECK_TIME, curr_time.timestamp())
-
-    # None for no errors
-    if gpt4_llm and parallel_results[1] is None:
-        kv_store.store(GEN_AI_DETECTED_MODEL, gpt_4_version)
-
-
 @router.get("/admin/genai-api-key/validate")
 def validate_existing_genai_api_key(
     _: User = Depends(current_admin_user),
@@ -187,46 +131,18 @@ def validate_existing_genai_api_key(
         # First time checking the key, nothing unusual
         pass

-    genai_api_key = get_gen_ai_api_key()
-    _validate_llm_key(genai_api_key)
-
-
-@router.get("/admin/genai-api-key", response_model=ApiKey)
-def get_gen_ai_api_key_from_dynamic_config_store(
-    _: User = Depends(current_admin_user),
-) -> ApiKey:
-    """
-    NOTE: Only gets value from dynamic config store as to not expose env variables.
-    """
     try:
-        # only get last 4 characters of key to not expose full key
-        return ApiKey(
-            api_key=cast(
-                str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY)
-            )[-4:]
-        )
-    except ConfigNotFoundError:
-        raise HTTPException(status_code=404, detail="Key not found")
+        llm = get_default_llm(timeout=10)
+    except ValueError:
+        raise HTTPException(status_code=404, detail="LLM not setup")

+    error = test_llm(llm)
+    if error:
+        raise HTTPException(status_code=400, detail=error)

-@router.put("/admin/genai-api-key")
-def store_genai_api_key(
-    request: ApiKey,
-    _: User = Depends(current_admin_user),
-) -> None:
-    if not request.api_key:
-        raise HTTPException(400, "No API key provided")
-
-    _validate_llm_key(request.api_key)
-
-    get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
-
-
-@router.delete("/admin/genai-api-key")
-def delete_genai_api_key(
-    _: User = Depends(current_admin_user),
-) -> None:
-    get_dynamic_config_store().delete(GEN_AI_API_KEY_STORAGE_KEY)
+    # Mark check as successful
+    curr_time = datetime.now(tz=timezone.utc)
+    kv_store.store(GEN_AI_KEY_CHECK_TIME, curr_time.timestamp())


 @router.post("/admin/deletion-attempt")
diff --git a/backend/danswer/server/manage/llm/api.py b/backend/danswer/server/manage/llm/api.py
new file mode 100644
index 000000000..3b4673522
--- /dev/null
+++ b/backend/danswer/server/manage/llm/api.py
@@ -0,0 +1,157 @@
+from collections.abc import Callable
+
+from fastapi import APIRouter
+from fastapi import Depends
+from fastapi import HTTPException
+from sqlalchemy.orm import Session
+
+from danswer.auth.users import current_admin_user
+from danswer.auth.users import current_user
+from danswer.db.engine import get_session
+from danswer.db.llm import fetch_existing_llm_providers
+from danswer.db.llm import remove_llm_provider
+from danswer.db.llm import update_default_provider
+from danswer.db.llm import upsert_llm_provider
+from danswer.db.models import User
+from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_llm
+from danswer.llm.options import fetch_available_well_known_llms
+from danswer.llm.options import WellKnownLLMProviderDescriptor
+from danswer.llm.utils import test_llm
+from danswer.server.manage.llm.models import FullLLMProvider
+from danswer.server.manage.llm.models import LLMProviderDescriptor
+from danswer.server.manage.llm.models import LLMProviderUpsertRequest
+from danswer.server.manage.llm.models import TestLLMRequest
+from danswer.utils.logger import setup_logger
+from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
+
+logger = setup_logger()
+
+
+admin_router =
APIRouter(prefix="/admin/llm") +basic_router = APIRouter(prefix="/llm") + + +@admin_router.get("/built-in/options") +def fetch_llm_options( + _: User | None = Depends(current_admin_user), +) -> list[WellKnownLLMProviderDescriptor]: + return fetch_available_well_known_llms() + + +@admin_router.post("/test") +def test_llm_configuration( + test_llm_request: TestLLMRequest, + _: User | None = Depends(current_admin_user), +) -> None: + llm = get_llm( + provider=test_llm_request.provider, + model=test_llm_request.default_model_name, + api_key=test_llm_request.api_key, + api_base=test_llm_request.api_base, + api_version=test_llm_request.api_version, + custom_config=test_llm_request.custom_config, + ) + functions_with_args: list[tuple[Callable, tuple]] = [(test_llm, (llm,))] + + if ( + test_llm_request.default_fast_model_name + and test_llm_request.default_fast_model_name + != test_llm_request.default_model_name + ): + fast_llm = get_llm( + provider=test_llm_request.provider, + model=test_llm_request.default_fast_model_name, + api_key=test_llm_request.api_key, + api_base=test_llm_request.api_base, + api_version=test_llm_request.api_version, + custom_config=test_llm_request.custom_config, + ) + functions_with_args.append((test_llm, (fast_llm,))) + + parallel_results = run_functions_tuples_in_parallel( + functions_with_args, allow_failures=False + ) + error = parallel_results[0] or ( + parallel_results[1] if len(parallel_results) > 1 else None + ) + + if error: + raise HTTPException(status_code=400, detail=error) + + +@admin_router.post("/test/default") +def test_default_provider( + _: User | None = Depends(current_admin_user), +) -> None: + try: + llm = get_default_llm() + fast_llm = get_default_llm(use_fast_llm=True) + except ValueError: + logger.exception("Failed to fetch default LLM Provider") + raise HTTPException(status_code=400, detail="No LLM Provider setup") + + functions_with_args: list[tuple[Callable, tuple]] = [ + (test_llm, (llm,)), + (test_llm, (fast_llm,)), + ] + parallel_results = run_functions_tuples_in_parallel( + functions_with_args, allow_failures=False + ) + error = parallel_results[0] or ( + parallel_results[1] if len(parallel_results) > 1 else None + ) + if error: + raise HTTPException(status_code=400, detail=error) + + +@admin_router.get("/provider") +def list_llm_providers( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> list[FullLLMProvider]: + return [ + FullLLMProvider.from_model(llm_provider_model) + for llm_provider_model in fetch_existing_llm_providers(db_session) + ] + + +@admin_router.put("/provider") +def put_llm_provider( + llm_provider: LLMProviderUpsertRequest, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> FullLLMProvider: + return upsert_llm_provider(db_session, llm_provider) + + +@admin_router.delete("/provider/{provider_id}") +def delete_llm_provider( + provider_id: int, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + remove_llm_provider(db_session, provider_id) + + +@admin_router.post("/provider/{provider_id}/default") +def set_provider_as_default( + provider_id: int, + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> None: + update_default_provider(db_session, provider_id) + + +"""Endpoints for all""" + + +@basic_router.get("/provider") +def list_llm_provider_basics( + _: User | None = Depends(current_user), + db_session: Session = 
Depends(get_session), +) -> list[LLMProviderDescriptor]: + return [ + LLMProviderDescriptor.from_model(llm_provider_model) + for llm_provider_model in fetch_existing_llm_providers(db_session) + ] diff --git a/backend/danswer/server/manage/llm/models.py b/backend/danswer/server/manage/llm/models.py new file mode 100644 index 000000000..628bb0a7a --- /dev/null +++ b/backend/danswer/server/manage/llm/models.py @@ -0,0 +1,89 @@ +from typing import TYPE_CHECKING + +from pydantic import BaseModel + +from danswer.llm.options import fetch_models_for_provider + +if TYPE_CHECKING: + from danswer.db.models import LLMProvider as LLMProviderModel + + +class TestLLMRequest(BaseModel): + # provider level + provider: str + api_key: str | None = None + api_base: str | None = None + api_version: str | None = None + custom_config: dict[str, str] | None = None + + # model level + default_model_name: str + default_fast_model_name: str | None = None + + +class LLMProviderDescriptor(BaseModel): + """A descriptor for an LLM provider that can be safely viewed by + non-admin users. Used when giving a list of available LLMs.""" + + name: str + model_names: list[str] + default_model_name: str + fast_default_model_name: str | None + is_default_provider: bool | None + + @classmethod + def from_model( + cls, llm_provider_model: "LLMProviderModel" + ) -> "LLMProviderDescriptor": + return cls( + name=llm_provider_model.name, + default_model_name=llm_provider_model.default_model_name, + fast_default_model_name=llm_provider_model.fast_default_model_name, + is_default_provider=llm_provider_model.is_default_provider, + model_names=( + llm_provider_model.model_names + or fetch_models_for_provider(llm_provider_model.name) + or [llm_provider_model.default_model_name] + ), + ) + + +class LLMProvider(BaseModel): + name: str + api_key: str | None + api_base: str | None + api_version: str | None + custom_config: dict[str, str] | None + default_model_name: str + fast_default_model_name: str | None + + +class LLMProviderUpsertRequest(LLMProvider): + # should only be used for a "custom" provider + # for default providers, the built-in model names are used + model_names: list[str] | None + + +class FullLLMProvider(LLMProvider): + id: int + is_default_provider: bool | None + model_names: list[str] + + @classmethod + def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider": + return cls( + id=llm_provider_model.id, + name=llm_provider_model.name, + api_key=llm_provider_model.api_key, + api_base=llm_provider_model.api_base, + api_version=llm_provider_model.api_version, + custom_config=llm_provider_model.custom_config, + default_model_name=llm_provider_model.default_model_name, + fast_default_model_name=llm_provider_model.fast_default_model_name, + is_default_provider=llm_provider_model.is_default_provider, + model_names=( + llm_provider_model.model_names + or fetch_models_for_provider(llm_provider_model.name) + or [llm_provider_model.default_model_name] + ), + ) diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index eab0f8935..fea64bded 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -3,6 +3,7 @@ alembic==1.10.4 asyncpg==0.27.0 atlassian-python-api==3.37.0 beautifulsoup4==4.12.2 +boto3==1.34.84 celery==5.3.4 chardet==5.2.0 dask==2023.8.1 diff --git a/web/package-lock.json b/web/package-lock.json index 85323bd85..d78585315 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -15,12 +15,14 @@ "@radix-ui/react-popover": "^1.0.7", 
"@tremor/react": "^3.9.2", "@types/js-cookie": "^3.0.3", + "@types/lodash": "^4.17.0", "@types/node": "18.15.11", "@types/react": "18.0.32", "@types/react-dom": "18.0.11", "autoprefixer": "^10.4.14", "formik": "^2.2.9", "js-cookie": "^3.0.5", + "lodash": "^4.17.21", "mdast-util-find-and-replace": "^3.0.1", "next": "^14.1.0", "postcss": "^8.4.31", @@ -1740,6 +1742,11 @@ "integrity": "sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==", "dev": true }, + "node_modules/@types/lodash": { + "version": "4.17.0", + "resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.17.0.tgz", + "integrity": "sha512-t7dhREVv6dbNj0q17X12j7yDG4bD/DHYX7o5/DbDxobP0HnGPgpRz2Ej77aL7TZT3DSw13fqUTj8J4mMnqa7WA==" + }, "node_modules/@types/mdast": { "version": "4.0.3", "resolved": "https://registry.npmjs.org/@types/mdast/-/mdast-4.0.3.tgz", diff --git a/web/package.json b/web/package.json index 37788280d..d25c82232 100644 --- a/web/package.json +++ b/web/package.json @@ -16,12 +16,14 @@ "@radix-ui/react-popover": "^1.0.7", "@tremor/react": "^3.9.2", "@types/js-cookie": "^3.0.3", + "@types/lodash": "^4.17.0", "@types/node": "18.15.11", "@types/react": "18.0.32", "@types/react-dom": "18.0.11", "autoprefixer": "^10.4.14", "formik": "^2.2.9", "js-cookie": "^3.0.5", + "lodash": "^4.17.21", "mdast-util-find-and-replace": "^3.0.1", "next": "^14.1.0", "postcss": "^8.4.31", diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index 05c5a71a2..bcded7031 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -31,6 +31,15 @@ import { Bubble } from "@/components/Bubble"; import { GroupsIcon } from "@/components/icons/icons"; import { SuccessfulPersonaUpdateRedirectType } from "./enums"; import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable"; +import { FullLLMProvider } from "../models/llm/interfaces"; +import { Option } from "@/components/Dropdown"; + +const DEFAULT_LLM_PROVIDER_TO_DISPLAY_NAME: Record = { + openai: "OpenAI", + azure: "Azure OpenAI", + anthropic: "Anthropic", + bedrock: "AWS Bedrock", +}; function Label({ children }: { children: string | JSX.Element }) { return ( @@ -46,20 +55,18 @@ export function AssistantEditor({ existingPersona, ccPairs, documentSets, - llmOverrideOptions, - defaultLLM, user, defaultPublic, redirectType, + llmProviders, }: { existingPersona?: Persona | null; ccPairs: CCPairBasicInfo[]; documentSets: DocumentSet[]; - llmOverrideOptions: string[]; - defaultLLM: string; user: User | null; defaultPublic: boolean; redirectType: SuccessfulPersonaUpdateRedirectType; + llmProviders: FullLLMProvider[]; }) { const router = useRouter(); const { popup, setPopup } = usePopup(); @@ -98,6 +105,21 @@ export function AssistantEditor({ } }, []); + const defaultLLM = llmProviders.find( + (llmProvider) => llmProvider.is_default_provider + )?.default_model_name; + + const modelOptionsByProvider = new Map[]>(); + llmProviders.forEach((llmProvider) => { + const providerOptions = llmProvider.model_names.map((modelName) => { + return { + name: modelName, + value: modelName, + }; + }); + modelOptionsByProvider.set(llmProvider.name, providerOptions); + }); + return (
{popup} @@ -118,6 +140,8 @@ export function AssistantEditor({ include_citations: existingPersona?.prompts[0]?.include_citations ?? true, llm_relevance_filter: existingPersona?.llm_relevance_filter ?? false, + llm_model_provider_override: + existingPersona?.llm_model_provider_override ?? null, llm_model_version_override: existingPersona?.llm_model_version_override ?? null, starter_messages: existingPersona?.starter_messages ?? [], @@ -139,6 +163,7 @@ export function AssistantEditor({ include_citations: Yup.boolean().required(), llm_relevance_filter: Yup.boolean().required(), llm_model_version_override: Yup.string().nullable(), + llm_model_provider_override: Yup.string().nullable(), starter_messages: Yup.array().of( Yup.object().shape({ name: Yup.string().required(), @@ -178,6 +203,18 @@ export function AssistantEditor({ return; } + if ( + values.llm_model_provider_override && + !values.llm_model_version_override + ) { + setPopup({ + type: "error", + message: + "Must select a model if a non-default LLM provider is chosen.", + }); + return; + } + formikHelpers.setSubmitting(true); // if disable_retrieval is set, set num_chunks to 0 @@ -428,7 +465,7 @@ export function AssistantEditor({ )} - {llmOverrideOptions.length > 0 && defaultLLM && ( + {llmProviders.length > 0 && ( <> -
- { - return { - name: llmOption, - value: llmOption, - }; - })} - includeDefault={true} - /> +
+
+ LLM Provider + ({ + name: + DEFAULT_LLM_PROVIDER_TO_DISPLAY_NAME[ + llmProvider.name + ] || llmProvider.name, + value: llmProvider.name, + }))} + includeDefault={true} + onSelect={(selected) => { + if ( + selected !== values.llm_model_provider_override + ) { + setFieldValue( + "llm_model_version_override", + null + ); + } + setFieldValue( + "llm_model_provider_override", + selected + ); + }} + /> +
+ + {values.llm_model_provider_override && ( +
+ Model + +
+ )}
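Reviewer note: the provider/model override flow above is easier to see in isolation. A minimal TypeScript sketch of the behavior the onSelect handler implements (LLMOverride and applyProviderSelection are illustrative names, not part of this patch):

interface LLMOverride {
  llm_model_provider_override: string | null;
  llm_model_version_override: string | null;
}

// Selecting a provider must clear any previously chosen model, since a
// model name is only meaningful within its own provider.
function applyProviderSelection(
  current: LLMOverride,
  selected: string | null
): LLMOverride {
  return {
    llm_model_provider_override: selected,
    llm_model_version_override:
      selected === current.llm_model_provider_override
        ? current.llm_model_version_override // same provider: keep the model
        : null, // provider changed: reset the model
  };
}

The submit-time check above then rejects any persona that names a non-default provider without also naming a model.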
diff --git a/web/src/app/admin/assistants/interfaces.ts b/web/src/app/admin/assistants/interfaces.ts index 6e694d80f..00544aacb 100644 --- a/web/src/app/admin/assistants/interfaces.ts +++ b/web/src/app/admin/assistants/interfaces.ts @@ -30,6 +30,7 @@ export interface Persona { num_chunks?: number; llm_relevance_filter?: boolean; llm_filter_extraction?: boolean; + llm_model_provider_override?: string; llm_model_version_override?: string; starter_messages: StarterMessage[] | null; default_persona: boolean; diff --git a/web/src/app/admin/assistants/lib.ts b/web/src/app/admin/assistants/lib.ts index e5075f33e..39672f0c9 100644 --- a/web/src/app/admin/assistants/lib.ts +++ b/web/src/app/admin/assistants/lib.ts @@ -10,6 +10,7 @@ interface PersonaCreationRequest { include_citations: boolean; is_public: boolean; llm_relevance_filter: boolean | null; + llm_model_provider_override: string | null; llm_model_version_override: string | null; starter_messages: StarterMessage[] | null; users?: string[]; @@ -28,6 +29,7 @@ interface PersonaUpdateRequest { include_citations: boolean; is_public: boolean; llm_relevance_filter: boolean | null; + llm_model_provider_override: string | null; llm_model_version_override: string | null; starter_messages: StarterMessage[] | null; users?: string[]; @@ -117,6 +119,7 @@ function buildPersonaAPIBody( recency_bias: "base_decay", prompt_ids: [promptId], document_set_ids, + llm_model_provider_override: creationRequest.llm_model_provider_override, llm_model_version_override: creationRequest.llm_model_version_override, starter_messages: creationRequest.starter_messages, users, diff --git a/web/src/app/admin/models/embedding/page.tsx b/web/src/app/admin/models/embedding/page.tsx index 0612fe2c6..ccda9af19 100644 --- a/web/src/app/admin/models/embedding/page.tsx +++ b/web/src/app/admin/models/embedding/page.tsx @@ -1,13 +1,10 @@ "use client"; -import { LoadingAnimation, ThreeDotsLoader } from "@/components/Loading"; +import { ThreeDotsLoader } from "@/components/Loading"; import { AdminPageTitle } from "@/components/admin/Title"; -import { KeyIcon, TrashIcon } from "@/components/icons/icons"; -import { ApiKeyForm } from "@/components/openai/ApiKeyForm"; -import { GEN_AI_API_KEY_URL } from "@/components/openai/constants"; -import { errorHandlingFetcher, fetcher } from "@/lib/fetcher"; -import { Button, Card, Divider, Text, Title } from "@tremor/react"; -import { FiCpu, FiPackage } from "react-icons/fi"; +import { errorHandlingFetcher } from "@/lib/fetcher"; +import { Button, Card, Text, Title } from "@tremor/react"; +import { FiPackage } from "react-icons/fi"; import useSWR, { mutate } from "swr"; import { ModelOption, ModelSelector } from "./ModelSelector"; import { useState } from "react"; diff --git a/web/src/app/admin/models/llm/CustomLLMProviderUpdateForm.tsx b/web/src/app/admin/models/llm/CustomLLMProviderUpdateForm.tsx new file mode 100644 index 000000000..b81730ac1 --- /dev/null +++ b/web/src/app/admin/models/llm/CustomLLMProviderUpdateForm.tsx @@ -0,0 +1,450 @@ +import { LoadingAnimation } from "@/components/Loading"; +import { Button, Divider, Text } from "@tremor/react"; +import { + ArrayHelpers, + ErrorMessage, + Field, + FieldArray, + Form, + Formik, +} from "formik"; +import { FiPlus, FiTrash, FiX } from "react-icons/fi"; +import { LLM_PROVIDERS_ADMIN_URL } from "./constants"; +import { + Label, + SubLabel, + TextArrayField, + TextFormField, +} from "@/components/admin/connectors/Field"; +import { useState } from "react"; +import { useSWRConfig } from "swr"; 
+import { FullLLMProvider } from "./interfaces";
+import { PopupSpec } from "@/components/admin/connectors/Popup";
+import * as Yup from "yup";
+import isEqual from "lodash/isEqual";
+
+function customConfigProcessing(customConfigsList: [string, string][]) {
+  const customConfig: { [key: string]: string } = {};
+  customConfigsList.forEach(([key, value]) => {
+    customConfig[key] = value;
+  });
+  return customConfig;
+}
+
+export function CustomLLMProviderUpdateForm({
+  onClose,
+  existingLlmProvider,
+  shouldMarkAsDefault,
+  setPopup,
+}: {
+  onClose: () => void;
+  existingLlmProvider?: FullLLMProvider;
+  shouldMarkAsDefault?: boolean;
+  setPopup?: (popup: PopupSpec) => void;
+}) {
+  const { mutate } = useSWRConfig();
+
+  const [isTesting, setIsTesting] = useState(false);
+  const [testError, setTestError] = useState("");
+  const [isTestSuccessful, setTestSuccessful] = useState(
+    existingLlmProvider ? true : false
+  );
+
+  // Define the initial values based on the provider's requirements
+  const initialValues = {
+    name: existingLlmProvider?.name ?? "",
+    api_key: existingLlmProvider?.api_key ?? "",
+    api_base: existingLlmProvider?.api_base ?? "",
+    api_version: existingLlmProvider?.api_version ?? "",
+    default_model_name: existingLlmProvider?.default_model_name ?? null,
+    default_fast_model_name:
+      existingLlmProvider?.fast_default_model_name ?? null,
+    model_names: existingLlmProvider?.model_names ?? [],
+    custom_config_list: existingLlmProvider?.custom_config
+      ? Object.entries(existingLlmProvider.custom_config)
+      : [],
+  };
+
+  const [validatedConfig, setValidatedConfig] = useState(
+    existingLlmProvider ? initialValues : null
+  );
+
+  // Set up the validation schema
+  const validationSchema = Yup.object({
+    name: Yup.string().required("Name is required"),
+    api_key: Yup.string(),
+    api_base: Yup.string(),
+    api_version: Yup.string(),
+    model_names: Yup.array(Yup.string().required("Model name is required")),
+    default_model_name: Yup.string().required("Model name is required"),
+    default_fast_model_name: Yup.string().nullable(),
+    custom_config_list: Yup.array(),
+  });
+
+  return (
+    {
+        if (!isEqual(values, validatedConfig)) {
+          setTestSuccessful(false);
+        }
+      }}
+      onSubmit={async (values, { setSubmitting }) => {
+        setSubmitting(true);
+
+        if (!isTestSuccessful) {
+          setSubmitting(false);
+          return;
+        }
+
+        if (values.model_names.length === 0) {
+          const fullErrorMsg = "At least one model name is required";
+          if (setPopup) {
+            setPopup({
+              type: "error",
+              message: fullErrorMsg,
+            });
+          } else {
+            alert(fullErrorMsg);
+          }
+          setSubmitting(false);
+          return;
+        }
+
+        const response = await fetch(LLM_PROVIDERS_ADMIN_URL, {
+          method: "PUT",
+          headers: {
+            "Content-Type": "application/json",
+          },
+          body: JSON.stringify({
+            ...values,
+            // map the form's field name onto the field name the backend's
+            // LLMProviderUpsertRequest expects, so the fast model is not
+            // silently dropped for custom providers
+            fast_default_model_name: values.default_fast_model_name,
+            custom_config: customConfigProcessing(values.custom_config_list),
+          }),
+        });
+
+        if (!response.ok) {
+          const errorMsg = (await response.json()).detail;
+          const fullErrorMsg = existingLlmProvider
+            ?
`Failed to update provider: ${errorMsg}` + : `Failed to enable provider: ${errorMsg}`; + if (setPopup) { + setPopup({ + type: "error", + message: fullErrorMsg, + }); + } else { + alert(fullErrorMsg); + } + return; + } + + if (shouldMarkAsDefault) { + const newLlmProvider = (await response.json()) as FullLLMProvider; + const setDefaultResponse = await fetch( + `${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`, + { + method: "POST", + } + ); + if (!setDefaultResponse.ok) { + const errorMsg = (await setDefaultResponse.json()).detail; + const fullErrorMsg = `Failed to set provider as default: ${errorMsg}`; + if (setPopup) { + setPopup({ + type: "error", + message: fullErrorMsg, + }); + } else { + alert(fullErrorMsg); + } + return; + } + } + + mutate(LLM_PROVIDERS_ADMIN_URL); + onClose(); + + const successMsg = existingLlmProvider + ? "Provider updated successfully!" + : "Provider enabled successfully!"; + if (setPopup) { + setPopup({ + type: "success", + message: successMsg, + }); + } else { + alert(successMsg); + } + + setSubmitting(false); + }} + > + {({ values }) => ( +
+ + Should be one of the providers listed at{" "} + + https://docs.litellm.ai/docs/providers + + . + + } + placeholder="Name of the custom provider" + /> + + + + + Fill in the following as is needed. Refer to the LiteLLM + documentation for the model provider name specified above in order + to determine which fields are required. + + + + + + + + + + + <> +
+
+                Additional configurations needed by the model provider. These
+                are passed to litellm via environment variables.
+ +
+ For example, when configuring the Cloudflare provider, you would + need to set `CLOUDFLARE_ACCOUNT_ID` as the key and your + Cloudflare account ID as the value. +
+ +
+ + ) => ( +
+ {values.custom_config_list.map((_, index) => { + return ( +
+
+
+
+ + + +
+ +
+ + + +
+
+
+ arrayHelpers.remove(index)} + /> +
+
+
+ ); + })} + + +
+ )} + /> + + + + + + + + + + + + + +
+ {/* NOTE: this is above the test button to make sure it's visible */} + {!isTestSuccessful && testError && ( + {testError} + )} + {isTestSuccessful && ( + + Test successful! LLM provider is ready to go. + + )} + +
+ {isTestSuccessful ? ( + + ) : ( + + )} + {existingLlmProvider && ( + + )} +
+
+ + )} +
+ ); +} diff --git a/web/src/app/admin/models/llm/LLMConfiguration.tsx b/web/src/app/admin/models/llm/LLMConfiguration.tsx new file mode 100644 index 000000000..744579677 --- /dev/null +++ b/web/src/app/admin/models/llm/LLMConfiguration.tsx @@ -0,0 +1,234 @@ +"use client"; + +import { Modal } from "@/components/Modal"; +import { errorHandlingFetcher } from "@/lib/fetcher"; +import { useState } from "react"; +import useSWR, { mutate } from "swr"; +import { Badge, Button, Text, Title } from "@tremor/react"; +import { ThreeDotsLoader } from "@/components/Loading"; +import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces"; +import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; +import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm"; +import { LLM_PROVIDERS_ADMIN_URL } from "./constants"; +import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm"; + +function LLMProviderUpdateModal({ + llmProviderDescriptor, + onClose, + existingLlmProvider, + shouldMarkAsDefault, + setPopup, +}: { + llmProviderDescriptor: WellKnownLLMProviderDescriptor | null; + onClose: () => void; + existingLlmProvider?: FullLLMProvider; + shouldMarkAsDefault?: boolean; + setPopup?: (popup: PopupSpec) => void; +}) { + const providerName = + llmProviderDescriptor?.display_name || + llmProviderDescriptor?.name || + existingLlmProvider?.name || + "Custom LLM Provider"; + return ( + onClose()}> +
+ {llmProviderDescriptor ? ( + + ) : ( + + )} +
+
+ ); +} + +function LLMProviderDisplay({ + llmProviderDescriptor, + existingLlmProvider, + shouldMarkAsDefault, +}: { + llmProviderDescriptor: WellKnownLLMProviderDescriptor | null; + existingLlmProvider?: FullLLMProvider; + shouldMarkAsDefault?: boolean; +}) { + const [formIsVisible, setFormIsVisible] = useState(false); + const { popup, setPopup } = usePopup(); + + const providerName = + llmProviderDescriptor?.display_name || + llmProviderDescriptor?.name || + existingLlmProvider?.name; + return ( +
+ {popup} +
+
+
{providerName}
+ {existingLlmProvider && !existingLlmProvider.is_default_provider && ( +
{ + const response = await fetch( + `${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}/default`, + { + method: "POST", + } + ); + if (!response.ok) { + const errorMsg = (await response.json()).detail; + setPopup({ + type: "error", + message: `Failed to set provider as default: ${errorMsg}`, + }); + return; + } + + mutate(LLM_PROVIDERS_ADMIN_URL); + setPopup({ + type: "success", + message: "Provider set as default successfully!", + }); + }} + > + Set as default +
+ )} +
+ + {existingLlmProvider && ( +
+ {existingLlmProvider.is_default_provider ? ( + + Default + + ) : ( + + Enabled + + )} +
+ )} + +
+ +
+
+ {formIsVisible && ( + setFormIsVisible(false)} + existingLlmProvider={existingLlmProvider} + shouldMarkAsDefault={shouldMarkAsDefault} + setPopup={setPopup} + /> + )} +
+ ); +} + +function AddCustomLLMProvider({}) { + const [formIsVisible, setFormIsVisible] = useState(false); + + if (formIsVisible) { + return ( + setFormIsVisible(false)} + > +
+ setFormIsVisible(false)} + /> +
+
+ ); + } + + return ( + + ); +} + +export function LLMConfiguration() { + const { data: llmProviderDescriptors } = useSWR< + WellKnownLLMProviderDescriptor[] + >("/api/admin/llm/built-in/options", errorHandlingFetcher); + const { data: existingLlmProviders } = useSWR( + LLM_PROVIDERS_ADMIN_URL, + errorHandlingFetcher + ); + + if (!llmProviderDescriptors || !existingLlmProviders) { + return ; + } + + const wellKnownLLMProviderNames = llmProviderDescriptors.map( + (llmProviderDescriptor) => llmProviderDescriptor.name + ); + const customLLMProviders = existingLlmProviders.filter( + (llmProvider) => !wellKnownLLMProviderNames.includes(llmProvider.name) + ); + + return ( + <> + + If multiple LLM providers are enabled, the default provider will be used + for all "Default" Personas. For user-created Personas, you can + select the LLM provider/model that best fits the use case! + + + Default Providers +
+ {llmProviderDescriptors.map((llmProviderDescriptor) => { + const existingLlmProvider = existingLlmProviders.find( + (llmProvider) => llmProvider.name === llmProviderDescriptor.name + ); + + return ( + + ); + })} +
+ + Custom Providers + {customLLMProviders.length > 0 && ( +
+ {customLLMProviders.map((llmProvider) => ( + + ))} +
+ )} + + + + ); +} diff --git a/web/src/app/admin/models/llm/LLMProviderUpdateForm.tsx b/web/src/app/admin/models/llm/LLMProviderUpdateForm.tsx new file mode 100644 index 000000000..a9ddb8e4d --- /dev/null +++ b/web/src/app/admin/models/llm/LLMProviderUpdateForm.tsx @@ -0,0 +1,347 @@ +import { LoadingAnimation } from "@/components/Loading"; +import { Button, Divider, Text } from "@tremor/react"; +import { Form, Formik } from "formik"; +import { FiTrash } from "react-icons/fi"; +import { LLM_PROVIDERS_ADMIN_URL } from "./constants"; +import { + SelectorFormField, + TextFormField, +} from "@/components/admin/connectors/Field"; +import { useState } from "react"; +import { useSWRConfig } from "swr"; +import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces"; +import { PopupSpec } from "@/components/admin/connectors/Popup"; +import * as Yup from "yup"; +import isEqual from "lodash/isEqual"; + +export function LLMProviderUpdateForm({ + llmProviderDescriptor, + onClose, + existingLlmProvider, + shouldMarkAsDefault, + setPopup, +}: { + llmProviderDescriptor: WellKnownLLMProviderDescriptor; + onClose: () => void; + existingLlmProvider?: FullLLMProvider; + shouldMarkAsDefault?: boolean; + setPopup?: (popup: PopupSpec) => void; +}) { + const { mutate } = useSWRConfig(); + + const [isTesting, setIsTesting] = useState(false); + const [testError, setTestError] = useState(""); + const [isTestSuccessful, setTestSuccessful] = useState( + existingLlmProvider ? true : false + ); + + // Define the initial values based on the provider's requirements + const initialValues = { + api_key: existingLlmProvider?.api_key ?? "", + api_base: existingLlmProvider?.api_base ?? "", + api_version: existingLlmProvider?.api_version ?? "", + default_model_name: + existingLlmProvider?.default_model_name ?? + (llmProviderDescriptor.default_model || + llmProviderDescriptor.llm_names[0]), + default_fast_model_name: + existingLlmProvider?.fast_default_model_name ?? + (llmProviderDescriptor.default_fast_model || null), + custom_config: + existingLlmProvider?.custom_config ?? + llmProviderDescriptor.custom_config_keys?.reduce( + (acc, key) => { + acc[key] = ""; + return acc; + }, + {} as { [key: string]: string } + ), + }; + + const [validatedConfig, setValidatedConfig] = useState( + existingLlmProvider ? initialValues : null + ); + + // Setup validation schema if required + const validationSchema = Yup.object({ + api_key: llmProviderDescriptor.api_key_required + ? Yup.string().required("API Key is required") + : Yup.string(), + api_base: llmProviderDescriptor.api_base_required + ? Yup.string().required("API Base is required") + : Yup.string(), + api_version: llmProviderDescriptor.api_version_required + ? Yup.string().required("API Version is required") + : Yup.string(), + ...(llmProviderDescriptor.custom_config_keys + ? 
{ + custom_config: Yup.object( + llmProviderDescriptor.custom_config_keys.reduce( + (acc, key) => { + acc[key] = Yup.string().required(`${key} is required`); + return acc; + }, + {} as { [key: string]: Yup.StringSchema } + ) + ), + } + : {}), + default_model_name: Yup.string().required("Model name is required"), + default_fast_model_name: Yup.string().nullable(), + }); + + return ( + { + if (!isEqual(values, validatedConfig)) { + setTestSuccessful(false); + } + }} + onSubmit={async (values, { setSubmitting }) => { + setSubmitting(true); + + if (!isTestSuccessful) { + setSubmitting(false); + return; + } + + const response = await fetch(LLM_PROVIDERS_ADMIN_URL, { + method: "PUT", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + name: llmProviderDescriptor.name, + ...values, + fast_default_model_name: + values.default_fast_model_name || values.default_model_name, + }), + }); + + if (!response.ok) { + const errorMsg = (await response.json()).detail; + const fullErrorMsg = existingLlmProvider + ? `Failed to update provider: ${errorMsg}` + : `Failed to enable provider: ${errorMsg}`; + if (setPopup) { + setPopup({ + type: "error", + message: fullErrorMsg, + }); + } else { + alert(fullErrorMsg); + } + return; + } + + if (shouldMarkAsDefault) { + const newLlmProvider = (await response.json()) as FullLLMProvider; + const setDefaultResponse = await fetch( + `${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`, + { + method: "POST", + } + ); + if (!setDefaultResponse.ok) { + const errorMsg = (await setDefaultResponse.json()).detail; + const fullErrorMsg = `Failed to set provider as default: ${errorMsg}`; + if (setPopup) { + setPopup({ + type: "error", + message: fullErrorMsg, + }); + } else { + alert(fullErrorMsg); + } + return; + } + } + + mutate(LLM_PROVIDERS_ADMIN_URL); + onClose(); + + const successMsg = existingLlmProvider + ? "Provider updated successfully!" + : "Provider enabled successfully!"; + if (setPopup) { + setPopup({ + type: "success", + message: successMsg, + }); + } else { + alert(successMsg); + } + + setSubmitting(false); + }} + > + {({ values }) => ( +
+ {llmProviderDescriptor.api_key_required && ( + + )} + + {llmProviderDescriptor.api_base_required && ( + + )} + + {llmProviderDescriptor.api_version_required && ( + + )} + + {llmProviderDescriptor.custom_config_keys?.map((key) => ( +
+ +
+ ))} + + + + {llmProviderDescriptor.llm_names.length > 0 ? ( + ({ + name, + value: name, + }))} + direction="up" + maxHeight="max-h-56" + /> + ) : ( + + )} + + {llmProviderDescriptor.llm_names.length > 0 ? ( + ({ + name, + value: name, + }))} + includeDefault + direction="up" + maxHeight="max-h-56" + /> + ) : ( + + )} + + + +
+ {/* NOTE: this is above the test button to make sure it's visible */} + {!isTestSuccessful && testError && ( + {testError} + )} + {isTestSuccessful && ( + + Test successful! LLM provider is ready to go. + + )} + +
+ {isTestSuccessful ? ( + + ) : ( + + )} + {existingLlmProvider && ( + + )} +
+
+ + )} +
+ ); +} diff --git a/web/src/app/admin/models/llm/constants.ts b/web/src/app/admin/models/llm/constants.ts new file mode 100644 index 000000000..2db434ee9 --- /dev/null +++ b/web/src/app/admin/models/llm/constants.ts @@ -0,0 +1 @@ +export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider"; diff --git a/web/src/app/admin/models/llm/interfaces.ts b/web/src/app/admin/models/llm/interfaces.ts new file mode 100644 index 000000000..f9a5bf558 --- /dev/null +++ b/web/src/app/admin/models/llm/interfaces.ts @@ -0,0 +1,29 @@ +export interface WellKnownLLMProviderDescriptor { + name: string; + display_name: string | null; + + api_key_required: boolean; + api_base_required: boolean; + api_version_required: boolean; + custom_config_keys: string[] | null; + + llm_names: string[]; + default_model: string | null; + default_fast_model: string | null; +} + +export interface LLMProvider { + name: string; + api_key: string | null; + api_base: string | null; + api_version: string | null; + custom_config: { [key: string]: string } | null; + default_model_name: string; + fast_default_model_name: string | null; +} + +export interface FullLLMProvider extends LLMProvider { + id: number; + is_default_provider: boolean | null; + model_names: string[]; +} diff --git a/web/src/app/admin/keys/openai/page.tsx b/web/src/app/admin/models/llm/page.tsx similarity index 76% rename from web/src/app/admin/keys/openai/page.tsx rename to web/src/app/admin/models/llm/page.tsx index 0d122e80f..330718ee6 100644 --- a/web/src/app/admin/keys/openai/page.tsx +++ b/web/src/app/admin/models/llm/page.tsx @@ -2,7 +2,6 @@ import { Form, Formik } from "formik"; import { useEffect, useState } from "react"; -import { LoadingAnimation } from "@/components/Loading"; import { AdminPageTitle } from "@/components/admin/Title"; import { BooleanFormField, @@ -10,52 +9,9 @@ import { TextFormField, } from "@/components/admin/connectors/Field"; import { Popup } from "@/components/admin/connectors/Popup"; -import { TrashIcon } from "@/components/icons/icons"; -import { ApiKeyForm } from "@/components/openai/ApiKeyForm"; -import { GEN_AI_API_KEY_URL } from "@/components/openai/constants"; -import { fetcher } from "@/lib/fetcher"; -import { Button, Divider, Text, Title } from "@tremor/react"; +import { Button, Divider, Text } from "@tremor/react"; import { FiCpu } from "react-icons/fi"; -import useSWR, { mutate } from "swr"; - -const ExistingKeys = () => { - const { data, isLoading, error } = useSWR<{ api_key: string }>( - GEN_AI_API_KEY_URL, - fetcher - ); - - if (isLoading) { - return ; - } - - if (error) { - return
Error loading existing keys
; - } - - if (!data?.api_key) { - return null; - } - - return ( -
- Existing Key -
-

sk- ****...**{data?.api_key}

- -
-
- ); -}; +import { LLMConfiguration } from "./LLMConfiguration"; const LLMOptions = () => { const [popup, setPopup] = useState<{ @@ -219,27 +175,12 @@ const Page = () => { return (
} /> - LLM Keys + - - - Update Key - - Specify an OpenAI API key and click the "Submit" button. - -
- { - if (response.ok) { - mutate(GEN_AI_API_KEY_URL); - } - }} - /> -
); diff --git a/web/src/app/chat/page.tsx b/web/src/app/chat/page.tsx index f99887048..ae846d395 100644 --- a/web/src/app/chat/page.tsx +++ b/web/src/app/chat/page.tsx @@ -20,7 +20,7 @@ import { WelcomeModal, hasCompletedWelcomeFlowSS, } from "@/components/initialSetup/welcome/WelcomeModalWrapper"; -import { ApiKeyModal } from "@/components/openai/ApiKeyModal"; +import { ApiKeyModal } from "@/components/llm/ApiKeyModal"; import { cookies } from "next/headers"; import { DOCUMENT_SIDEBAR_WIDTH_COOKIE_NAME } from "@/components/resizable/contants"; import { personaComparator } from "../admin/assistants/lib"; @@ -169,9 +169,9 @@ export default async function Page({ <> - {shouldShowWelcomeModal && } + {shouldShowWelcomeModal && } {!shouldShowWelcomeModal && !shouldDisplaySourcesIncompleteModal && ( - + )} {shouldDisplaySourcesIncompleteModal && ( diff --git a/web/src/app/search/page.tsx b/web/src/app/search/page.tsx index f4f217a97..47e97cf51 100644 --- a/web/src/app/search/page.tsx +++ b/web/src/app/search/page.tsx @@ -7,7 +7,7 @@ import { } from "@/lib/userSS"; import { redirect } from "next/navigation"; import { HealthCheckBanner } from "@/components/health/healthcheck"; -import { ApiKeyModal } from "@/components/openai/ApiKeyModal"; +import { ApiKeyModal } from "@/components/llm/ApiKeyModal"; import { fetchSS } from "@/lib/utilsSS"; import { CCPairBasicInfo, DocumentSet, Tag, User } from "@/lib/types"; import { cookies } from "next/headers"; @@ -147,10 +147,10 @@ export default async function Home() {
- {shouldShowWelcomeModal && } + {shouldShowWelcomeModal && } {!shouldShowWelcomeModal && !shouldDisplayNoSourcesModal && - !shouldDisplaySourcesIncompleteModal && } + !shouldDisplaySourcesIncompleteModal && } {shouldDisplayNoSourcesModal && } {shouldDisplaySourcesIncompleteModal && ( diff --git a/web/src/components/Dropdown.tsx b/web/src/components/Dropdown.tsx index 3cb1ba70d..9b42ec56b 100644 --- a/web/src/components/Dropdown.tsx +++ b/web/src/components/Dropdown.tsx @@ -1,6 +1,7 @@ import { ChangeEvent, FC, useEffect, useRef, useState } from "react"; import { ChevronDownIcon } from "./icons/icons"; import { FiCheck, FiChevronDown } from "react-icons/fi"; +import { Popover } from "./popover/Popover"; export interface Option { name: string; @@ -181,9 +182,11 @@ export function SearchMultiSelectDropdown({ export const CustomDropdown = ({ children, dropdown, + direction = "down", // Default to 'down' if not specified }: { children: JSX.Element | string; dropdown: JSX.Element | string; + direction?: "up" | "down"; }) => { const [isOpen, setIsOpen] = useState(false); const dropdownRef = useRef(null); @@ -211,7 +214,9 @@ export const CustomDropdown = ({ {isOpen && (
setIsOpen(!isOpen)} - className="pt-2 absolute bottom w-full z-30 box-shadow" + className={`absolute ${ + direction === "up" ? "bottom-full pb-2" : "pt-2 " + } w-full z-30 box-shadow`} > {dropdown}
@@ -281,16 +286,21 @@ export function DefaultDropdown({ selected, onSelect, includeDefault = false, + direction = "down", + maxHeight, }: { options: StringOrNumberOption[]; selected: string | null; onSelect: (value: string | number | null) => void; includeDefault?: boolean; + direction?: "up" | "down"; + maxHeight?: string; }) { const selectedOption = options.find((option) => option.value === selected); return ( diff --git a/web/src/components/admin/Layout.tsx b/web/src/components/admin/Layout.tsx index f02295a98..e403b00a1 100644 --- a/web/src/components/admin/Layout.tsx +++ b/web/src/components/admin/Layout.tsx @@ -154,7 +154,7 @@ export async function Layout({ children }: { children: React.ReactNode }) {
LLM
), - link: "/admin/keys/openai", + link: "/admin/models/llm", }, { name: ( diff --git a/web/src/components/admin/connectors/Field.tsx b/web/src/components/admin/connectors/Field.tsx index 10816bd50..825dc3ad8 100644 --- a/web/src/components/admin/connectors/Field.tsx +++ b/web/src/components/admin/connectors/Field.tsx @@ -233,6 +233,9 @@ interface SelectorFormFieldProps { options: StringOrNumberOption[]; subtext?: string | JSX.Element; includeDefault?: boolean; + direction?: "up" | "down"; + maxHeight?: string; + onSelect?: (selected: string | number | null) => void; } export function SelectorFormField({ @@ -241,6 +244,9 @@ export function SelectorFormField({ options, subtext, includeDefault = false, + direction = "down", + maxHeight, + onSelect, }: SelectorFormFieldProps) { const [field] = useField(name); const { setFieldValue } = useFormikContext(); @@ -254,8 +260,10 @@ export function SelectorFormField({ setFieldValue(name, selected)} + onSelect={onSelect || ((selected) => setFieldValue(name, selected))} includeDefault={includeDefault} + direction={direction} + maxHeight={maxHeight} />
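Reviewer note: a short usage sketch for the extended SelectorFormField (a surrounding Formik form and a modelNames: string[] value are assumed; props not shown in the hunk above are omitted):

// Upward-opening, height-capped selector; the new onSelect prop intercepts
// the selection instead of writing it straight into Formik state.
<SelectorFormField
  name="llm_model_version_override"
  subtext="Overrides the default model for this persona"
  options={modelNames.map((name) => ({ name, value: name }))}
  includeDefault
  direction="up" // new prop: render the dropdown above the trigger
  maxHeight="max-h-56" // new prop: Tailwind class capping dropdown height
  onSelect={(selected) => {
    // custom handling; when onSelect is omitted the field falls back to
    // setFieldValue(name, selected), exactly as before this patch
  }}
/>

Because all three new props are optional, existing call sites compile unchanged.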
diff --git a/web/src/components/initialSetup/search/NoCompleteSourceModal.tsx b/web/src/components/initialSetup/search/NoCompleteSourceModal.tsx index 2a43a9e73..42aba7ab1 100644 --- a/web/src/components/initialSetup/search/NoCompleteSourceModal.tsx +++ b/web/src/components/initialSetup/search/NoCompleteSourceModal.tsx @@ -37,7 +37,7 @@ export function NoCompleteSourcesModal({ title="⏳ None of your connectors have finished a full sync yet" onOutsideClick={() => setIsHidden(true)} > -
+
You've connected some sources, but none of them have finished diff --git a/web/src/components/initialSetup/search/NoSourcesModal.tsx b/web/src/components/initialSetup/search/NoSourcesModal.tsx index 7510c9393..eb0cd9c02 100644 --- a/web/src/components/initialSetup/search/NoSourcesModal.tsx +++ b/web/src/components/initialSetup/search/NoSourcesModal.tsx @@ -1,6 +1,6 @@ "use client"; -import { Button, Divider } from "@tremor/react"; +import { Button, Divider, Text } from "@tremor/react"; import { Modal } from "../../Modal"; import Link from "next/link"; import { FiMessageSquare, FiShare2 } from "react-icons/fi"; @@ -21,11 +21,11 @@ export function NoSourcesModal() { >
-

+ Before using Search you'll need to connect at least one source. Without any connected knowledge sources, there isn't anything to search over. -

+ -
- - ) - } - -
- ); -}; diff --git a/web/src/components/openai/ApiKeyModal.tsx b/web/src/components/openai/ApiKeyModal.tsx deleted file mode 100644 index 1c38160e9..000000000 --- a/web/src/components/openai/ApiKeyModal.tsx +++ /dev/null @@ -1,68 +0,0 @@ -"use client"; - -import { useState, useEffect } from "react"; -import { ApiKeyForm } from "./ApiKeyForm"; -import { Modal } from "../Modal"; -import { Divider, Text } from "@tremor/react"; - -export async function checkApiKey() { - const response = await fetch("/api/manage/admin/genai-api-key/validate"); - if (!response.ok && (response.status === 404 || response.status === 400)) { - const jsonResponse = await response.json(); - return jsonResponse.detail; - } - return null; -} - -export const ApiKeyModal = () => { - const [errorMsg, setErrorMsg] = useState(null); - - useEffect(() => { - checkApiKey().then((error) => { - if (error) { - setErrorMsg(error); - } - }); - }, []); - - if (!errorMsg) { - return null; - } - - return ( - setErrorMsg(null)} - > -
-
-
- Please provide a valid OpenAI API key below in order to start using - Danswer Search or Danswer Chat. -
-
- Or if you'd rather look around first,{" "} - setErrorMsg(null)} - className="text-link cursor-pointer" - > - skip this step - - . -
- - - - { - if (response.ok) { - setErrorMsg(null); - } - }} - /> -
-
-
- ); -}; diff --git a/web/src/components/openai/constants.ts b/web/src/components/openai/constants.ts deleted file mode 100644 index 533e2bf54..000000000 --- a/web/src/components/openai/constants.ts +++ /dev/null @@ -1 +0,0 @@ -export const GEN_AI_API_KEY_URL = "/api/manage/admin/genai-api-key"; diff --git a/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts b/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts index bba6e8b53..614bcb81f 100644 --- a/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts +++ b/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts @@ -2,6 +2,7 @@ import { Persona } from "@/app/admin/assistants/interfaces"; import { CCPairBasicInfo, DocumentSet, User } from "../types"; import { getCurrentUserSS } from "../userSS"; import { fetchSS } from "../utilsSS"; +import { FullLLMProvider } from "@/app/admin/models/llm/interfaces"; export async function fetchAssistantEditorInfoSS( personaId?: number | string @@ -10,8 +11,7 @@ export async function fetchAssistantEditorInfoSS( { ccPairs: CCPairBasicInfo[]; documentSets: DocumentSet[]; - llmOverrideOptions: string[]; - defaultLLM: string; + llmProviders: FullLLMProvider[]; user: User | null; existingPersona: Persona | null; }, @@ -22,8 +22,7 @@ export async function fetchAssistantEditorInfoSS( const tasks = [ fetchSS("/manage/indexing-status"), fetchSS("/manage/document-set"), - fetchSS("/persona/utils/list-available-models"), - fetchSS("/persona/utils/default-model"), + fetchSS("/llm/provider"), // duplicate fetch, but shouldn't be too big of a deal // this page is not a high traffic page getCurrentUserSS(), @@ -37,15 +36,13 @@ export async function fetchAssistantEditorInfoSS( const [ ccPairsInfoResponse, documentSetsResponse, - llmOverridesResponse, - defaultLLMResponse, + llmProvidersResponse, user, personaResponse, ] = (await Promise.all(tasks)) as [ Response, Response, Response, - Response, User | null, Response | null, ]; @@ -66,21 +63,13 @@ export async function fetchAssistantEditorInfoSS( } const documentSets = (await documentSetsResponse.json()) as DocumentSet[]; - if (!llmOverridesResponse.ok) { + if (!llmProvidersResponse.ok) { return [ null, - `Failed to fetch LLM override options - ${await llmOverridesResponse.text()}`, + `Failed to fetch LLM providers - ${await llmProvidersResponse.text()}`, ]; } - const llmOverrideOptions = (await llmOverridesResponse.json()) as string[]; - - if (!defaultLLMResponse.ok) { - return [ - null, - `Failed to fetch default LLM - ${await defaultLLMResponse.text()}`, - ]; - } - const defaultLLM = (await defaultLLMResponse.json()) as string; + const llmProviders = (await llmProvidersResponse.json()) as FullLLMProvider[]; if (personaId && personaResponse && !personaResponse.ok) { return [null, `Failed to fetch Persona - ${await personaResponse.text()}`]; @@ -93,8 +82,7 @@ export async function fetchAssistantEditorInfoSS( { ccPairs, documentSets, - llmOverrideOptions, - defaultLLM, + llmProviders, user, existingPersona, },