Add option to add citations to Personas + allow for more chunks if an LLM model override is specified

Weves 2024-01-27 00:41:05 -08:00 committed by Chris Weaver
parent cf4ede2130
commit 824677ca75
7 changed files with 60 additions and 8 deletions

View File

@@ -382,10 +382,13 @@ def drop_messages_history_overflow(
    history_token_counts: list[int],
    final_msg: BaseMessage,
    final_msg_token_count: int,
    max_allowed_tokens: int | None,
) -> list[BaseMessage]:
    """As message history grows, messages need to be dropped starting from the furthest in the past.
    The System message should be kept if at all possible and the latest user input which is inserted in the
    prompt template must be included"""
    if max_allowed_tokens is None:
        max_allowed_tokens = GEN_AI_MAX_INPUT_TOKENS

    if len(history_msgs) != len(history_token_counts):
        # This should never happen
@@ -395,7 +398,9 @@ def drop_messages_history_overflow(
    # Start dropping from the history if necessary
    all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
    ind_prev_msg_start = find_last_index(all_tokens)
    ind_prev_msg_start = find_last_index(
        all_tokens, max_prompt_tokens=max_allowed_tokens
    )

    if system_msg and ind_prev_msg_start <= len(history_msgs):
        prompt.append(system_msg)
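This hunk feeds a model-specific budget (max_prompt_tokens) into the truncation helper instead of always relying on the global GEN_AI_MAX_INPUT_TOKENS. A minimal sketch of the idea, assuming find_last_index walks the per-message token counts from the newest message backwards and returns the index of the oldest message that still fits the budget (an illustration, not the repository's implementation):

def find_last_index(token_counts: list[int], max_prompt_tokens: int) -> int:
    """Index of the oldest message that still fits when summing token counts
    from the newest message backwards; everything before it gets dropped."""
    running_sum = 0
    last_ind = len(token_counts)
    for ind in range(len(token_counts) - 1, -1, -1):
        running_sum += token_counts[ind]
        if running_sum > max_prompt_tokens:
            break
        last_ind = ind
    return last_ind

# Example: with a 10-token budget only the two newest messages survive.
print(find_last_index([6, 5, 4, 3], max_prompt_tokens=10))  # -> 2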

View File

@@ -33,6 +33,7 @@ from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import User
from danswer.document_index.document_index_utils import get_index_name
@@ -42,6 +43,7 @@ from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_token_encode
from danswer.llm.utils import get_llm_max_tokens
from danswer.llm.utils import translate_history_to_basemessages
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
@@ -62,6 +64,7 @@ logger = setup_logger()
def generate_ai_chat_response(
    query_message: ChatMessage,
    history: list[ChatMessage],
    persona: Persona,
    context_docs: list[LlmDoc],
    doc_id_to_rank_map: dict[str, int],
    llm: LLM | None,
@@ -109,6 +112,9 @@ def generate_ai_chat_response(
        history_token_counts=history_token_counts,
        final_msg=user_message,
        final_msg_token_count=user_tokens,
        max_allowed_tokens=get_llm_max_tokens(persona.llm_model_version_override)
        if persona.llm_model_version_override
        else None,
    )

    # Good Debug/Breakpoint
@@ -183,7 +189,9 @@ def stream_chat_message(
    )

    try:
        llm = get_default_llm()
        llm = get_default_llm(
            gen_ai_model_version_override=persona.llm_model_version_override
        )
    except GenAIDisabledException:
        llm = None
@@ -408,6 +416,7 @@ def stream_chat_message(
        response_packets = generate_ai_chat_response(
            query_message=final_msg,
            history=history_msgs,
            persona=persona,
            context_docs=llm_docs,
            doc_id_to_rank_map=doc_id_to_rank_map,
            llm=llm,
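Taken together, these hunks wire persona.llm_model_version_override into both the LLM factory and the prompt-token budget. A self-contained sketch of the selection pattern (the constant value and the model table below are made up for illustration):

GEN_AI_MAX_INPUT_TOKENS = 4096  # illustrative stand-in for the config default

def choose_prompt_budget(model_override: str | None) -> int:
    """Use a model-specific limit when a persona overrides the model,
    otherwise fall back to the configured default budget."""
    known_limits = {"gpt-4": 8192}  # hypothetical stand-in for litellm's model map
    max_allowed = known_limits.get(model_override) if model_override else None
    return max_allowed if max_allowed is not None else GEN_AI_MAX_INPUT_TOKENS

print(choose_prompt_budget("gpt-4"))  # 8192: override with a larger context window
print(choose_prompt_budget(None))     # 4096: no override, default budget applies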

View File

@@ -14,6 +14,7 @@ from langchain.schema.messages import BaseMessage
from langchain.schema.messages import BaseMessageChunk
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from litellm import get_max_tokens # type: ignore
from tiktoken.core import Encoding
from danswer.configs.app_configs import LOG_LEVEL
@@ -188,3 +189,11 @@ def test_llm(llm: LLM) -> bool:
        logger.warning(f"GenAI API key failed for the following reason: {e}")

    return False


def get_llm_max_tokens(model_name: str) -> int | None:
    """Best effort attempt to get the max tokens for the LLM"""
    try:
        return get_max_tokens(model_name)
    except Exception:
        return None
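A hedged usage example of the new helper (the model name is illustrative, and the result depends on the installed litellm version's model map; unknown or self-hosted model names make litellm.get_max_tokens raise, which is why the helper swallows the exception and returns None):

from danswer.llm.utils import get_llm_max_tokens

limit = get_llm_max_tokens("gpt-3.5-turbo")
if limit is None:
    # Model not in litellm's map: callers fall back to GEN_AI_MAX_INPUT_TOKENS.
    print("unknown model; using the configured default prompt budget")
else:
    print(f"prompt budget for this model: {limit} tokens")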

View File

@@ -62,7 +62,7 @@ def create_update_prompt(
@basic_router.post("")
def create_persona(
def create_prompt(
    create_prompt_request: CreatePromptRequest,
    user: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),

View File

@@ -92,6 +92,8 @@ export function PersonaEditor({
            (documentSet) => documentSet.id
          ) ?? ([] as number[]),
          num_chunks: existingPersona?.num_chunks ?? null,
          include_citations:
            existingPersona?.prompts[0]?.include_citations ?? true,
          llm_relevance_filter: existingPersona?.llm_relevance_filter ?? false,
          llm_model_version_override:
            existingPersona?.llm_model_version_override ?? null,
@@ -107,6 +109,7 @@ export function PersonaEditor({
          disable_retrieval: Yup.boolean().required(),
          document_set_ids: Yup.array().of(Yup.number()),
          num_chunks: Yup.number().max(20).nullable(),
          include_citations: Yup.boolean().required(),
          llm_relevance_filter: Yup.boolean().required(),
          llm_model_version_override: Yup.string().nullable(),
        })
@@ -240,6 +243,18 @@ export function PersonaEditor({
                error={finalPromptError}
              />

              {!values.disable_retrieval && (
                <BooleanFormField
                  name="include_citations"
                  label="Include Citations"
                  subtext={`
                    If set, the response will include bracket citations ([1], [2], etc.)
                    for each document used by the LLM to help inform the response. This is
                    the same technique used by the default Personas. In general, we recommend
                    to leave this enabled in order to increase trust in the LLM answer.`}
                />
              )}

              <BooleanFormField
                name="disable_retrieval"
                label="Disable Retrieval"

View File

@@ -76,10 +76,12 @@ export function PersonasTable({ personas }: { personas: Persona[] }) {
        id: persona.id.toString(),
        cells: [
          <div key="name" className="flex">
            <FiEdit
              className="mr-1 my-auto cursor-pointer"
              onClick={() => router.push(`/admin/personas/${persona.id}`)}
            />
            {!persona.default_persona && (
              <FiEdit
                className="mr-1 my-auto cursor-pointer"
                onClick={() => router.push(`/admin/personas/${persona.id}`)}
              />
            )}
            <p className="text font-medium whitespace-normal break-none">
              {persona.name}
            </p>
@@ -129,9 +131,10 @@ export function PersonasTable({ personas }: { personas: Persona[] }) {
            </div>
          </div>,
          <div key="edit" className="flex">
            <div className="mx-auto my-auto hover:bg-hover rounded p-1 cursor-pointer">
            <div className="mx-auto my-auto">
              {!persona.default_persona ? (
                <div
                  className="hover:bg-hover rounded p-1 cursor-pointer"
                  onClick={async () => {
                    const response = await deletePersona(persona.id);
                    if (response.ok) {

View File

@@ -7,6 +7,7 @@ interface PersonaCreationRequest {
  task_prompt: string;
  document_set_ids: number[];
  num_chunks: number | null;
  include_citations: boolean;
  llm_relevance_filter: boolean | null;
  llm_model_version_override: string | null;
}
@@ -20,6 +21,7 @@ interface PersonaUpdateRequest {
  task_prompt: string;
  document_set_ids: number[];
  num_chunks: number | null;
  include_citations: boolean;
  llm_relevance_filter: boolean | null;
  llm_model_version_override: string | null;
}
@@ -32,10 +34,12 @@ function createPrompt({
  personaName,
  systemPrompt,
  taskPrompt,
  includeCitations,
}: {
  personaName: string;
  systemPrompt: string;
  taskPrompt: string;
  includeCitations: boolean;
}) {
  return fetch("/api/prompt", {
    method: "POST",
@@ -48,6 +52,7 @@ function createPrompt({
      shared: true,
      system_prompt: systemPrompt,
      task_prompt: taskPrompt,
      include_citations: includeCitations,
    }),
  });
}
@@ -57,11 +62,13 @@ function updatePrompt({
  personaName,
  systemPrompt,
  taskPrompt,
  includeCitations,
}: {
  promptId: number;
  personaName: string;
  systemPrompt: string;
  taskPrompt: string;
  includeCitations: boolean;
}) {
  return fetch(`/api/prompt/${promptId}`, {
    method: "PATCH",
@@ -74,6 +81,7 @@ function updatePrompt({
      shared: true,
      system_prompt: systemPrompt,
      task_prompt: taskPrompt,
      include_citations: includeCitations,
    }),
  });
}
@@ -112,6 +120,7 @@ export async function createPersona(
    personaName: personaCreationRequest.name,
    systemPrompt: personaCreationRequest.system_prompt,
    taskPrompt: personaCreationRequest.task_prompt,
    includeCitations: personaCreationRequest.include_citations,
  });
  const promptId = createPromptResponse.ok
    ? (await createPromptResponse.json()).id
@@ -147,6 +156,7 @@ export async function updatePersona(
      personaName: personaUpdateRequest.name,
      systemPrompt: personaUpdateRequest.system_prompt,
      taskPrompt: personaUpdateRequest.task_prompt,
      includeCitations: personaUpdateRequest.include_citations,
    });
    promptId = existingPromptId;
  } else {
@@ -154,6 +164,7 @@ export async function updatePersona(
      personaName: personaUpdateRequest.name,
      systemPrompt: personaUpdateRequest.system_prompt,
      taskPrompt: personaUpdateRequest.task_prompt,
      includeCitations: personaUpdateRequest.include_citations,
    });
    promptId = promptResponse.ok ? (await promptResponse.json()).id : null;
  }
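createPersona and updatePersona now forward include_citations to the prompt endpoints. For reference, a hypothetical direct call from Python mirroring the createPrompt body (the base URL is an example, and only the fields visible in this diff are included, so a real request likely needs additional fields such as the prompt's name):

import requests

resp = requests.post(
    "http://localhost:8080/api/prompt",  # example base URL
    json={
        "shared": True,
        "system_prompt": "You are a helpful assistant.",
        "task_prompt": "Answer using the provided documents.",
        "include_citations": True,
    },
)
print(resp.status_code)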