diff --git a/backend/alembic/versions/baf71f781b9e_add_llm_model_version_override_to_.py b/backend/alembic/versions/baf71f781b9e_add_llm_model_version_override_to_.py new file mode 100644 index 000000000000..1939ae78ff58 --- /dev/null +++ b/backend/alembic/versions/baf71f781b9e_add_llm_model_version_override_to_.py @@ -0,0 +1,26 @@ +"""Add llm_model_version_override to Persona + +Revision ID: baf71f781b9e +Revises: 50b683a8295c +Create Date: 2023-12-06 21:56:50.286158 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "baf71f781b9e" +down_revision = "50b683a8295c" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "persona", + sa.Column("llm_model_version_override", sa.String(), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("persona", "llm_model_version_override") diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 2e132bdfac48..1fb61404bb36 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -324,6 +324,7 @@ def upsert_persona( persona_id: int | None = None, default_persona: bool = False, document_sets: list[DocumentSetDBModel] | None = None, + llm_model_version_override: str | None = None, commit: bool = True, overwrite_duplicate_named_persona: bool = False, ) -> Persona: @@ -355,6 +356,7 @@ def upsert_persona( persona.num_chunks = num_chunks persona.apply_llm_relevance_filter = apply_llm_relevance_filter persona.default_persona = default_persona + persona.llm_model_version_override = llm_model_version_override # Do not delete any associations manually added unless # a new updated list is provided @@ -375,6 +377,7 @@ def upsert_persona( apply_llm_relevance_filter=apply_llm_relevance_filter, default_persona=default_persona, document_sets=document_sets if document_sets else [], + llm_model_version_override=llm_model_version_override, ) db_session.add(persona) diff --git a/backend/danswer/db/models.py 
b/backend/danswer/db/models.py index f532930ed000..e458160e398e 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -603,6 +603,13 @@ class Persona(Base): apply_llm_relevance_filter: Mapped[bool | None] = mapped_column( Boolean, nullable=True ) + # allows the Persona to specify a different LLM version than is controlled + # globally via env variables. For flexibility, validity is not currently enforced + # NOTE: is only applied to the actual response generation - it is not used for things like + # auto-detected time filters, relevance filters, etc. + llm_model_version_override: Mapped[str | None] = mapped_column( + String, nullable=True + ) # Default personas are configured via backend during deployment # Treated specially (cannot be user edited etc.) default_persona: Mapped[bool] = mapped_column(Boolean, default=False) diff --git a/backend/danswer/direct_qa/factory.py b/backend/danswer/direct_qa/factory.py index 65e34e7b082b..dd471cbe8262 100644 --- a/backend/danswer/direct_qa/factory.py +++ b/backend/danswer/direct_qa/factory.py @@ -54,7 +54,11 @@ def get_qa_model_for_persona( timeout: int = QA_TIMEOUT, ) -> QAModel: return QABlock( - llm=get_default_llm(api_key=api_key, timeout=timeout), + llm=get_default_llm( + api_key=api_key, + timeout=timeout, + gen_ai_model_version_override=persona.llm_model_version_override, + ), qa_handler=PersonaBasedQAHandler( system_prompt=persona.system_text or "", task_prompt=persona.hint_text or "" ), diff --git a/backend/danswer/llm/factory.py b/backend/danswer/llm/factory.py index b4b06ac77500..a7a6a96a59c9 100644 --- a/backend/danswer/llm/factory.py +++ b/backend/danswer/llm/factory.py @@ -14,10 +14,16 @@ def get_default_llm( api_key: str | None = None, timeout: int = QA_TIMEOUT, use_fast_llm: bool = False, + gen_ai_model_version_override: str | None = None, ) -> LLM: """A single place to fetch the configured LLM for Danswer Also allows overriding certain LLM defaults""" - model_version = 
FAST_GEN_AI_MODEL_VERSION if use_fast_llm else GEN_AI_MODEL_VERSION + if gen_ai_model_version_override: + model_version = gen_ai_model_version_override + else: + model_version = ( + FAST_GEN_AI_MODEL_VERSION if use_fast_llm else GEN_AI_MODEL_VERSION + ) if api_key is None: api_key = get_gen_ai_api_key() diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index a3cd23594df1..919773120ea2 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -5,6 +5,8 @@ from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user from danswer.auth.users import current_user +from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER +from danswer.configs.model_configs import GEN_AI_MODEL_VERSION from danswer.db.chat import fetch_persona_by_id from danswer.db.chat import fetch_personas from danswer.db.chat import mark_persona_as_deleted @@ -50,6 +52,7 @@ def create_persona( num_chunks=create_persona_request.num_chunks, apply_llm_relevance_filter=create_persona_request.apply_llm_relevance_filter, document_sets=document_sets, + llm_model_version_override=create_persona_request.llm_model_version_override, ) except ValueError as e: logger.exception("Failed to update persona") @@ -84,6 +87,7 @@ def update_persona( num_chunks=update_persona_request.num_chunks, apply_llm_relevance_filter=update_persona_request.apply_llm_relevance_filter, document_sets=document_sets, + llm_model_version_override=update_persona_request.llm_model_version_override, persona_id=persona_id, ) except ValueError as e: @@ -134,3 +138,47 @@ def build_final_template_prompt( system_prompt=system_prompt, task_prompt=task_prompt ).build_dummy_prompt() ) + + +"""Utility endpoints for selecting which model to use for a persona. 
+Putting here for now, since we have no other flows which use this.""" + +GPT_4_MODEL_VERSIONS = [ + "gpt-4-1106-preview", + "gpt-4", + "gpt-4-32k", + "gpt-4-0613", + "gpt-4-32k-0613", + "gpt-4-0314", + "gpt-4-32k-0314", +] +GPT_3_5_TURBO_MODEL_VERSIONS = [ + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + "gpt-3.5-turbo-0301", +] + + +@router.get("/persona-utils/list-available-models") +def list_available_model_versions( + _: User | None = Depends(current_admin_user), +) -> list[str]: + # currently only support selecting different models for OpenAI + if GEN_AI_MODEL_PROVIDER != "openai": + return [] + + return GPT_4_MODEL_VERSIONS + GPT_3_5_TURBO_MODEL_VERSIONS + + +@router.get("/persona-utils/default-model") +def get_default_model( + _: User | None = Depends(current_admin_user), +) -> str: + # currently only support selecting different models for OpenAI + if GEN_AI_MODEL_PROVIDER != "openai": + return "" + + return GEN_AI_MODEL_VERSION diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py index 8c39c449aea1..fed8503a8f38 100644 --- a/backend/danswer/server/features/persona/models.py +++ b/backend/danswer/server/features/persona/models.py @@ -12,6 +12,7 @@ class CreatePersonaRequest(BaseModel): task_prompt: str num_chunks: int | None = None apply_llm_relevance_filter: bool | None = None + llm_model_version_override: str | None = None class PersonaSnapshot(BaseModel): @@ -21,6 +22,7 @@ class PersonaSnapshot(BaseModel): system_prompt: str task_prompt: str document_sets: list[DocumentSet] + llm_model_version_override: str | None @classmethod def from_model(cls, persona: Persona) -> "PersonaSnapshot": @@ -34,6 +36,7 @@ class PersonaSnapshot(BaseModel): DocumentSet.from_model(document_set_model) for document_set_model in persona.document_sets ], + llm_model_version_override=persona.llm_model_version_override, ) diff --git 
a/web/src/app/admin/personas/PersonaEditor.tsx b/web/src/app/admin/personas/PersonaEditor.tsx index cab3b11d703e..b34cee5ef679 100644 --- a/web/src/app/admin/personas/PersonaEditor.tsx +++ b/web/src/app/admin/personas/PersonaEditor.tsx @@ -1,15 +1,8 @@ "use client"; import { DocumentSet } from "@/lib/types"; -import { Button, Divider } from "@tremor/react"; -import { - ArrayHelpers, - ErrorMessage, - Field, - FieldArray, - Form, - Formik, -} from "formik"; +import { Button, Divider, Text } from "@tremor/react"; +import { ArrayHelpers, FieldArray, Form, Formik } from "formik"; import * as Yup from "yup"; import { buildFinalPrompt, createPersona, updatePersona } from "./lib"; @@ -20,6 +13,7 @@ import Link from "next/link"; import { useEffect, useState } from "react"; import { BooleanFormField, + SelectorFormField, TextFormField, } from "@/components/admin/connectors/Field"; @@ -40,9 +34,13 @@ function SubLabel({ children }: { children: string | JSX.Element }) { export function PersonaEditor({ existingPersona, documentSets, + llmOverrideOptions, + defaultLLM, }: { existingPersona?: Persona | null; documentSets: DocumentSet[]; + llmOverrideOptions: string[]; + defaultLLM: string; }) { const router = useRouter(); const { popup, setPopup } = usePopup(); @@ -74,6 +72,7 @@ export function PersonaEditor({
{popup} { formikHelpers.setSubmitting(true); @@ -259,6 +261,41 @@ export function PersonaEditor({ + {llmOverrideOptions.length > 0 && defaultLLM && ( + <> + [Advanced] Model Selection + + + Pick which LLM to use for this Persona. If left as Default, + will use {defaultLLM}. +
+
+ For more information on the different LLMs, checkout the{" "} + + OpenAI docs + + . +
+ + { + return { + name: llmOption, + value: llmOption, + }; + })} + includeDefault={true} + /> + + )} + + + [Advanced] Retrieval Customization + ); + } + + if (!defaultLLMResponse.ok) { + return ( + + ); + } + const documentSets = (await documentSetsResponse.json()) as DocumentSet[]; const persona = (await personaResponse.json()) as Persona; + const llmOverrideOptions = (await llmOverridesResponse.json()) as string[]; + const defaultLLM = (await defaultLLMResponse.json()) as string; return (
+ +

Edit Persona

- +
diff --git a/web/src/app/admin/personas/interfaces.ts b/web/src/app/admin/personas/interfaces.ts index 312d7da37ea8..aaa3f3d35aac 100644 --- a/web/src/app/admin/personas/interfaces.ts +++ b/web/src/app/admin/personas/interfaces.ts @@ -9,4 +9,5 @@ export interface Persona { document_sets: DocumentSet[]; num_chunks?: number; apply_llm_relevance_filter?: boolean; + llm_model_version_override?: string; } diff --git a/web/src/app/admin/personas/new/page.tsx b/web/src/app/admin/personas/new/page.tsx index ccb4de5a9e33..4eba0f570f5f 100644 --- a/web/src/app/admin/personas/new/page.tsx +++ b/web/src/app/admin/personas/new/page.tsx @@ -8,7 +8,12 @@ import { Card } from "@tremor/react"; import { AdminPageTitle } from "@/components/admin/Title"; export default async function Page() { - const documentSetsResponse = await fetchSS("/manage/document-set"); + const [documentSetsResponse, llmOverridesResponse, defaultLLMResponse] = + await Promise.all([ + fetchSS("/manage/document-set"), + fetchSS("/persona-utils/list-available-models"), + fetchSS("/persona-utils/default-model"), + ]); if (!documentSetsResponse.ok) { return ( @@ -18,9 +23,28 @@ export default async function Page() { /> ); } - const documentSets = (await documentSetsResponse.json()) as DocumentSet[]; + if (!llmOverridesResponse.ok) { + return ( + + ); + } + const llmOverrideOptions = (await llmOverridesResponse.json()) as string[]; + + if (!defaultLLMResponse.ok) { + return ( + + ); + } + const defaultLLM = (await defaultLLMResponse.json()) as string; + return (
@@ -31,7 +55,11 @@ export default async function Page() { /> - +
);