Add model choice to Persona

This commit is contained in:
Weves
2023-12-07 00:11:21 -08:00
committed by Chris Weaver
parent 26e808d2a1
commit 56785e6065
11 changed files with 217 additions and 18 deletions

View File

@@ -0,0 +1,26 @@
"""Add llm_model_version_override to Persona
Revision ID: baf71f781b9e
Revises: 50b683a8295c
Create Date: 2023-12-06 21:56:50.286158
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "baf71f781b9e"
down_revision = "50b683a8295c"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"persona",
sa.Column("llm_model_version_override", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("persona", "llm_model_version_override")

View File

@@ -324,6 +324,7 @@ def upsert_persona(
persona_id: int | None = None, persona_id: int | None = None,
default_persona: bool = False, default_persona: bool = False,
document_sets: list[DocumentSetDBModel] | None = None, document_sets: list[DocumentSetDBModel] | None = None,
llm_model_version_override: str | None = None,
commit: bool = True, commit: bool = True,
overwrite_duplicate_named_persona: bool = False, overwrite_duplicate_named_persona: bool = False,
) -> Persona: ) -> Persona:
@@ -355,6 +356,7 @@ def upsert_persona(
persona.num_chunks = num_chunks persona.num_chunks = num_chunks
persona.apply_llm_relevance_filter = apply_llm_relevance_filter persona.apply_llm_relevance_filter = apply_llm_relevance_filter
persona.default_persona = default_persona persona.default_persona = default_persona
persona.llm_model_version_override = llm_model_version_override
# Do not delete any associations manually added unless # Do not delete any associations manually added unless
# a new updated list is provided # a new updated list is provided
@@ -375,6 +377,7 @@ def upsert_persona(
apply_llm_relevance_filter=apply_llm_relevance_filter, apply_llm_relevance_filter=apply_llm_relevance_filter,
default_persona=default_persona, default_persona=default_persona,
document_sets=document_sets if document_sets else [], document_sets=document_sets if document_sets else [],
llm_model_version_override=llm_model_version_override,
) )
db_session.add(persona) db_session.add(persona)

View File

@@ -603,6 +603,13 @@ class Persona(Base):
apply_llm_relevance_filter: Mapped[bool | None] = mapped_column( apply_llm_relevance_filter: Mapped[bool | None] = mapped_column(
Boolean, nullable=True Boolean, nullable=True
) )
# allows the Persona to specify a different LLM version than is controlled
# globablly via env variables. For flexibility, validity is not currently enforced
# NOTE: only is applied on the actual response generation - is not used for things like
# auto-detected time filters, relevance filters, etc.
llm_model_version_override: Mapped[str | None] = mapped_column(
String, nullable=True
)
# Default personas are configured via backend during deployment # Default personas are configured via backend during deployment
# Treated specially (cannot be user edited etc.) # Treated specially (cannot be user edited etc.)
default_persona: Mapped[bool] = mapped_column(Boolean, default=False) default_persona: Mapped[bool] = mapped_column(Boolean, default=False)

View File

@@ -54,7 +54,11 @@ def get_qa_model_for_persona(
timeout: int = QA_TIMEOUT, timeout: int = QA_TIMEOUT,
) -> QAModel: ) -> QAModel:
return QABlock( return QABlock(
llm=get_default_llm(api_key=api_key, timeout=timeout), llm=get_default_llm(
api_key=api_key,
timeout=timeout,
gen_ai_model_version_override=persona.llm_model_version_override,
),
qa_handler=PersonaBasedQAHandler( qa_handler=PersonaBasedQAHandler(
system_prompt=persona.system_text or "", task_prompt=persona.hint_text or "" system_prompt=persona.system_text or "", task_prompt=persona.hint_text or ""
), ),

View File

@@ -14,10 +14,16 @@ def get_default_llm(
api_key: str | None = None, api_key: str | None = None,
timeout: int = QA_TIMEOUT, timeout: int = QA_TIMEOUT,
use_fast_llm: bool = False, use_fast_llm: bool = False,
gen_ai_model_version_override: str | None = None,
) -> LLM: ) -> LLM:
"""A single place to fetch the configured LLM for Danswer """A single place to fetch the configured LLM for Danswer
Also allows overriding certain LLM defaults""" Also allows overriding certain LLM defaults"""
model_version = FAST_GEN_AI_MODEL_VERSION if use_fast_llm else GEN_AI_MODEL_VERSION if gen_ai_model_version_override:
model_version = gen_ai_model_version_override
else:
model_version = (
FAST_GEN_AI_MODEL_VERSION if use_fast_llm else GEN_AI_MODEL_VERSION
)
if api_key is None: if api_key is None:
api_key = get_gen_ai_api_key() api_key = get_gen_ai_api_key()

View File

@@ -5,6 +5,8 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user from danswer.auth.users import current_user
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.chat import fetch_persona_by_id from danswer.db.chat import fetch_persona_by_id
from danswer.db.chat import fetch_personas from danswer.db.chat import fetch_personas
from danswer.db.chat import mark_persona_as_deleted from danswer.db.chat import mark_persona_as_deleted
@@ -50,6 +52,7 @@ def create_persona(
num_chunks=create_persona_request.num_chunks, num_chunks=create_persona_request.num_chunks,
apply_llm_relevance_filter=create_persona_request.apply_llm_relevance_filter, apply_llm_relevance_filter=create_persona_request.apply_llm_relevance_filter,
document_sets=document_sets, document_sets=document_sets,
llm_model_version_override=create_persona_request.llm_model_version_override,
) )
except ValueError as e: except ValueError as e:
logger.exception("Failed to update persona") logger.exception("Failed to update persona")
@@ -84,6 +87,7 @@ def update_persona(
num_chunks=update_persona_request.num_chunks, num_chunks=update_persona_request.num_chunks,
apply_llm_relevance_filter=update_persona_request.apply_llm_relevance_filter, apply_llm_relevance_filter=update_persona_request.apply_llm_relevance_filter,
document_sets=document_sets, document_sets=document_sets,
llm_model_version_override=update_persona_request.llm_model_version_override,
persona_id=persona_id, persona_id=persona_id,
) )
except ValueError as e: except ValueError as e:
@@ -134,3 +138,47 @@ def build_final_template_prompt(
system_prompt=system_prompt, task_prompt=task_prompt system_prompt=system_prompt, task_prompt=task_prompt
).build_dummy_prompt() ).build_dummy_prompt()
) )
"""Utility endpoints for selecting which model to use for a persona.
Putting here for now, since we have no other flows which use this."""
GPT_4_MODEL_VERSIONS = [
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-32k",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
]
GPT_3_5_TURBO_MODEL_VERSIONS = [
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-0301",
]
@router.get("/persona-utils/list-available-models")
def list_available_model_versions(
_: User | None = Depends(current_admin_user),
) -> list[str]:
# currently only support selecting different models for OpenAI
if GEN_AI_MODEL_PROVIDER != "openai":
return []
return GPT_4_MODEL_VERSIONS + GPT_3_5_TURBO_MODEL_VERSIONS
@router.get("/persona-utils/default-model")
def get_default_model(
_: User | None = Depends(current_admin_user),
) -> str:
# currently only support selecting different models for OpenAI
if GEN_AI_MODEL_PROVIDER != "openai":
return ""
return GEN_AI_MODEL_VERSION

View File

@@ -12,6 +12,7 @@ class CreatePersonaRequest(BaseModel):
task_prompt: str task_prompt: str
num_chunks: int | None = None num_chunks: int | None = None
apply_llm_relevance_filter: bool | None = None apply_llm_relevance_filter: bool | None = None
llm_model_version_override: str | None = None
class PersonaSnapshot(BaseModel): class PersonaSnapshot(BaseModel):
@@ -21,6 +22,7 @@ class PersonaSnapshot(BaseModel):
system_prompt: str system_prompt: str
task_prompt: str task_prompt: str
document_sets: list[DocumentSet] document_sets: list[DocumentSet]
llm_model_version_override: str | None
@classmethod @classmethod
def from_model(cls, persona: Persona) -> "PersonaSnapshot": def from_model(cls, persona: Persona) -> "PersonaSnapshot":
@@ -34,6 +36,7 @@ class PersonaSnapshot(BaseModel):
DocumentSet.from_model(document_set_model) DocumentSet.from_model(document_set_model)
for document_set_model in persona.document_sets for document_set_model in persona.document_sets
], ],
llm_model_version_override=persona.llm_model_version_override,
) )

View File

@@ -1,15 +1,8 @@
"use client"; "use client";
import { DocumentSet } from "@/lib/types"; import { DocumentSet } from "@/lib/types";
import { Button, Divider } from "@tremor/react"; import { Button, Divider, Text } from "@tremor/react";
import { import { ArrayHelpers, FieldArray, Form, Formik } from "formik";
ArrayHelpers,
ErrorMessage,
Field,
FieldArray,
Form,
Formik,
} from "formik";
import * as Yup from "yup"; import * as Yup from "yup";
import { buildFinalPrompt, createPersona, updatePersona } from "./lib"; import { buildFinalPrompt, createPersona, updatePersona } from "./lib";
@@ -20,6 +13,7 @@ import Link from "next/link";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { import {
BooleanFormField, BooleanFormField,
SelectorFormField,
TextFormField, TextFormField,
} from "@/components/admin/connectors/Field"; } from "@/components/admin/connectors/Field";
@@ -40,9 +34,13 @@ function SubLabel({ children }: { children: string | JSX.Element }) {
export function PersonaEditor({ export function PersonaEditor({
existingPersona, existingPersona,
documentSets, documentSets,
llmOverrideOptions,
defaultLLM,
}: { }: {
existingPersona?: Persona | null; existingPersona?: Persona | null;
documentSets: DocumentSet[]; documentSets: DocumentSet[];
llmOverrideOptions: string[];
defaultLLM: string;
}) { }) {
const router = useRouter(); const router = useRouter();
const { popup, setPopup } = usePopup(); const { popup, setPopup } = usePopup();
@@ -74,6 +72,7 @@ export function PersonaEditor({
<div className="dark"> <div className="dark">
{popup} {popup}
<Formik <Formik
enableReinitialize={true}
initialValues={{ initialValues={{
name: existingPersona?.name ?? "", name: existingPersona?.name ?? "",
description: existingPersona?.description ?? "", description: existingPersona?.description ?? "",
@@ -86,6 +85,8 @@ export function PersonaEditor({
num_chunks: existingPersona?.num_chunks ?? null, num_chunks: existingPersona?.num_chunks ?? null,
apply_llm_relevance_filter: apply_llm_relevance_filter:
existingPersona?.apply_llm_relevance_filter ?? false, existingPersona?.apply_llm_relevance_filter ?? false,
llm_model_version_override:
existingPersona?.llm_model_version_override ?? null,
}} }}
validationSchema={Yup.object().shape({ validationSchema={Yup.object().shape({
name: Yup.string().required("Must give the Persona a name!"), name: Yup.string().required("Must give the Persona a name!"),
@@ -101,6 +102,7 @@ export function PersonaEditor({
document_set_ids: Yup.array().of(Yup.number()), document_set_ids: Yup.array().of(Yup.number()),
num_chunks: Yup.number().max(20).nullable(), num_chunks: Yup.number().max(20).nullable(),
apply_llm_relevance_filter: Yup.boolean().required(), apply_llm_relevance_filter: Yup.boolean().required(),
llm_model_version_override: Yup.string().nullable(),
})} })}
onSubmit={async (values, formikHelpers) => { onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true); formikHelpers.setSubmitting(true);
@@ -259,6 +261,41 @@ export function PersonaEditor({
<Divider /> <Divider />
{llmOverrideOptions.length > 0 && defaultLLM && (
<>
<SectionHeader>[Advanced] Model Selection</SectionHeader>
<Text>
Pick which LLM to use for this Persona. If left as Default,
will use <b className="italic">{defaultLLM}</b>.
<br />
<br />
For more information on the different LLMs, checkout the{" "}
<a
href="https://platform.openai.com/docs/models"
target="_blank"
className="text-blue-500"
>
OpenAI docs
</a>
.
</Text>
<SelectorFormField
name="llm_model_version_override"
options={llmOverrideOptions.map((llmOption) => {
return {
name: llmOption,
value: llmOption,
};
})}
includeDefault={true}
/>
</>
)}
<Divider />
<SectionHeader>[Advanced] Retrieval Customization</SectionHeader> <SectionHeader>[Advanced] Retrieval Customization</SectionHeader>
<TextFormField <TextFormField

View File

@@ -6,13 +6,24 @@ import { DocumentSet } from "@/lib/types";
import { BackButton } from "@/components/BackButton"; import { BackButton } from "@/components/BackButton";
import { Card, Title } from "@tremor/react"; import { Card, Title } from "@tremor/react";
import { DeletePersonaButton } from "./DeletePersonaButton"; import { DeletePersonaButton } from "./DeletePersonaButton";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
export default async function Page({ export default async function Page({
params, params,
}: { }: {
params: { personaId: string }; params: { personaId: string };
}) { }) {
const personaResponse = await fetchSS(`/persona/${params.personaId}`); const [
personaResponse,
documentSetsResponse,
llmOverridesResponse,
defaultLLMResponse,
] = await Promise.all([
fetchSS(`/persona/${params.personaId}`),
fetchSS("/manage/document-set"),
fetchSS("/persona-utils/list-available-models"),
fetchSS("/persona-utils/default-model"),
]);
if (!personaResponse.ok) { if (!personaResponse.ok) {
return ( return (
@@ -23,8 +34,6 @@ export default async function Page({
); );
} }
const documentSetsResponse = await fetchSS("/manage/document-set");
if (!documentSetsResponse.ok) { if (!documentSetsResponse.ok) {
return ( return (
<ErrorCallout <ErrorCallout
@@ -34,18 +43,45 @@ export default async function Page({
); );
} }
if (!llmOverridesResponse.ok) {
return (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Failed to fetch LLM override options - ${await documentSetsResponse.text()}`}
/>
);
}
if (!defaultLLMResponse.ok) {
return (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Failed to fetch default LLM - ${await documentSetsResponse.text()}`}
/>
);
}
const documentSets = (await documentSetsResponse.json()) as DocumentSet[]; const documentSets = (await documentSetsResponse.json()) as DocumentSet[];
const persona = (await personaResponse.json()) as Persona; const persona = (await personaResponse.json()) as Persona;
const llmOverrideOptions = (await llmOverridesResponse.json()) as string[];
const defaultLLM = (await defaultLLMResponse.json()) as string;
return ( return (
<div className="dark"> <div className="dark">
<InstantSSRAutoRefresh />
<BackButton /> <BackButton />
<div className="pb-2 mb-4 flex"> <div className="pb-2 mb-4 flex">
<h1 className="text-3xl font-bold pl-2">Edit Persona</h1> <h1 className="text-3xl font-bold pl-2">Edit Persona</h1>
</div> </div>
<Card> <Card>
<PersonaEditor existingPersona={persona} documentSets={documentSets} /> <PersonaEditor
existingPersona={persona}
documentSets={documentSets}
llmOverrideOptions={llmOverrideOptions}
defaultLLM={defaultLLM}
/>
</Card> </Card>
<div className="mt-12"> <div className="mt-12">

View File

@@ -9,4 +9,5 @@ export interface Persona {
document_sets: DocumentSet[]; document_sets: DocumentSet[];
num_chunks?: number; num_chunks?: number;
apply_llm_relevance_filter?: boolean; apply_llm_relevance_filter?: boolean;
llm_model_version_override?: string;
} }

View File

@@ -8,7 +8,12 @@ import { Card } from "@tremor/react";
import { AdminPageTitle } from "@/components/admin/Title"; import { AdminPageTitle } from "@/components/admin/Title";
export default async function Page() { export default async function Page() {
const documentSetsResponse = await fetchSS("/manage/document-set"); const [documentSetsResponse, llmOverridesResponse, defaultLLMResponse] =
await Promise.all([
fetchSS("/manage/document-set"),
fetchSS("/persona-utils/list-available-models"),
fetchSS("/persona-utils/default-model"),
]);
if (!documentSetsResponse.ok) { if (!documentSetsResponse.ok) {
return ( return (
@@ -18,9 +23,28 @@ export default async function Page() {
/> />
); );
} }
const documentSets = (await documentSetsResponse.json()) as DocumentSet[]; const documentSets = (await documentSetsResponse.json()) as DocumentSet[];
if (!llmOverridesResponse.ok) {
return (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Failed to fetch LLM override options - ${await documentSetsResponse.text()}`}
/>
);
}
const llmOverrideOptions = (await llmOverridesResponse.json()) as string[];
if (!defaultLLMResponse.ok) {
return (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Failed to fetch default LLM - ${await documentSetsResponse.text()}`}
/>
);
}
const defaultLLM = (await defaultLLMResponse.json()) as string;
return ( return (
<div className="dark"> <div className="dark">
<BackButton /> <BackButton />
@@ -31,7 +55,11 @@ export default async function Page() {
/> />
<Card> <Card>
<PersonaEditor documentSets={documentSets} /> <PersonaEditor
documentSets={documentSets}
llmOverrideOptions={llmOverrideOptions}
defaultLLM={defaultLLM}
/>
</Card> </Card>
</div> </div>
); );