Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-09 20:39:29 +02:00)
Add option to add citations to Personas + allow for more chunks if an LLM model override is specified
This commit is contained in: parent cf4ede2130, commit 824677ca75
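
The "more chunks" half of this change works by sizing the prompt budget to the Persona's model override: if the Persona pins a specific model, litellm is asked for that model's max token count and that value replaces the global GEN_AI_MAX_INPUT_TOKENS default. A minimal sketch of that decision, assuming the hypothetical helper name max_input_tokens_for_persona and the import path for GEN_AI_MAX_INPUT_TOKENS (the actual commit inlines this logic, as the hunks below show):

    from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS  # assumed import path
    from danswer.db.models import Persona
    from danswer.llm.utils import get_llm_max_tokens  # added by this commit

    def max_input_tokens_for_persona(persona: Persona) -> int:
        # Hypothetical helper mirroring the inline logic in generate_ai_chat_response below
        if persona.llm_model_version_override:
            override_limit = get_llm_max_tokens(persona.llm_model_version_override)
            if override_limit is not None:
                # Use the overridden model's context window as reported by litellm
                return override_limit
        # Otherwise fall back to the globally configured input-token limit
        return GEN_AI_MAX_INPUT_TOKENS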
@@ -382,10 +382,13 @@ def drop_messages_history_overflow(
     history_token_counts: list[int],
     final_msg: BaseMessage,
     final_msg_token_count: int,
+    max_allowed_tokens: int | None,
 ) -> list[BaseMessage]:
     """As message history grows, messages need to be dropped starting from the furthest in the past.
     The System message should be kept if at all possible and the latest user input which is inserted in the
     prompt template must be included"""
+    if max_allowed_tokens is None:
+        max_allowed_tokens = GEN_AI_MAX_INPUT_TOKENS

     if len(history_msgs) != len(history_token_counts):
         # This should never happen

@@ -395,7 +398,9 @@ def drop_messages_history_overflow(

     # Start dropping from the history if necessary
     all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
-    ind_prev_msg_start = find_last_index(all_tokens)
+    ind_prev_msg_start = find_last_index(
+        all_tokens, max_prompt_tokens=max_allowed_tokens
+    )

     if system_msg and ind_prev_msg_start <= len(history_msgs):
         prompt.append(system_msg)
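
find_last_index itself is not part of this diff; the new max_prompt_tokens parameter caps how much chat history survives the overflow check. A plausible sketch of what such a helper does, purely as an illustration (hypothetical implementation, not the repository's actual code):

    def find_last_index(token_counts: list[int], max_prompt_tokens: int) -> int:
        # Walk backwards from the newest message, keeping as many recent messages
        # as fit in the budget; return the index of the earliest message to keep
        # (0 means everything fits).
        remaining = max_prompt_tokens
        for ind in range(len(token_counts), 0, -1):
            remaining -= token_counts[ind - 1]
            if remaining < 0:
                return ind
        return 0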
@@ -33,6 +33,7 @@ from danswer.db.chat import get_or_create_root_message
 from danswer.db.chat import translate_db_message_to_chat_message_detail
 from danswer.db.chat import translate_db_search_doc_to_server_search_doc
 from danswer.db.models import ChatMessage
 from danswer.db.models import Persona
 from danswer.db.models import SearchDoc as DbSearchDoc
 from danswer.db.models import User
 from danswer.document_index.document_index_utils import get_index_name

@@ -42,6 +43,7 @@ from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import get_default_llm_token_encode
+from danswer.llm.utils import get_llm_max_tokens
 from danswer.llm.utils import translate_history_to_basemessages
 from danswer.search.models import OptionalSearchSetting
 from danswer.search.models import RetrievalDetails

@@ -62,6 +64,7 @@ logger = setup_logger()
 def generate_ai_chat_response(
     query_message: ChatMessage,
     history: list[ChatMessage],
     persona: Persona,
     context_docs: list[LlmDoc],
     doc_id_to_rank_map: dict[str, int],
     llm: LLM | None,

@@ -109,6 +112,9 @@ def generate_ai_chat_response(
         history_token_counts=history_token_counts,
         final_msg=user_message,
         final_msg_token_count=user_tokens,
+        max_allowed_tokens=get_llm_max_tokens(persona.llm_model_version_override)
+        if persona.llm_model_version_override
+        else None,
     )

     # Good Debug/Breakpoint

@@ -183,7 +189,9 @@ def stream_chat_message(
     )

     try:
-        llm = get_default_llm()
+        llm = get_default_llm(
+            gen_ai_model_version_override=persona.llm_model_version_override
+        )
     except GenAIDisabledException:
         llm = None

@@ -408,6 +416,7 @@ def stream_chat_message(
     response_packets = generate_ai_chat_response(
         query_message=final_msg,
         history=history_msgs,
         persona=persona,
         context_docs=llm_docs,
         doc_id_to_rank_map=doc_id_to_rank_map,
         llm=llm,
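
Putting the two stream_chat_message changes together: the LLM is now constructed with the Persona's model override (falling back to llm = None when generative AI is disabled), and a model-specific token budget is passed down to generate_ai_chat_response. Every call below appears in the hunks above; only the condensed, standalone framing is new, and persona is assumed to be the already-loaded Persona row:

    from danswer.llm.exceptions import GenAIDisabledException
    from danswer.llm.factory import get_default_llm
    from danswer.llm.utils import get_llm_max_tokens

    try:
        # Honor the Persona-level model override when constructing the LLM
        llm = get_default_llm(
            gen_ai_model_version_override=persona.llm_model_version_override
        )
    except GenAIDisabledException:
        llm = None

    # None means "use the configured default"; a model override with a larger
    # context window lets more history and more retrieved chunks into the prompt.
    max_allowed_tokens = (
        get_llm_max_tokens(persona.llm_model_version_override)
        if persona.llm_model_version_override
        else None
    )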
@@ -14,6 +14,7 @@ from langchain.schema.messages import BaseMessage
 from langchain.schema.messages import BaseMessageChunk
 from langchain.schema.messages import HumanMessage
 from langchain.schema.messages import SystemMessage
+from litellm import get_max_tokens  # type: ignore
 from tiktoken.core import Encoding

 from danswer.configs.app_configs import LOG_LEVEL

@@ -188,3 +189,11 @@ def test_llm(llm: LLM) -> bool:
         logger.warning(f"GenAI API key failed for the following reason: {e}")

     return False
+
+
+def get_llm_max_tokens(model_name: str) -> int | None:
+    """Best effort attempt to get the max tokens for the LLM"""
+    try:
+        return get_max_tokens(model_name)
+    except Exception:
+        return None
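
get_llm_max_tokens is deliberately best-effort: litellm only knows the context windows of models in its model map, so anything it cannot resolve comes back as None and callers fall back to the configured GEN_AI_MAX_INPUT_TOKENS. A quick usage sketch (the concrete number depends on the installed litellm version's model map):

    from danswer.llm.utils import get_llm_max_tokens

    known = get_llm_max_tokens("gpt-3.5-turbo")        # an int such as 4096, per litellm's model map
    unknown = get_llm_max_tokens("my-internal-model")  # None -> caller falls back to the configured default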
@@ -62,7 +62,7 @@ def create_update_prompt(


 @basic_router.post("")
-def create_persona(
+def create_prompt(
     create_prompt_request: CreatePromptRequest,
     user: User | None = Depends(current_admin_user),
     db_session: Session = Depends(get_session),
@@ -92,6 +92,8 @@ export function PersonaEditor({
             (documentSet) => documentSet.id
           ) ?? ([] as number[]),
         num_chunks: existingPersona?.num_chunks ?? null,
+        include_citations:
+          existingPersona?.prompts[0]?.include_citations ?? true,
         llm_relevance_filter: existingPersona?.llm_relevance_filter ?? false,
         llm_model_version_override:
           existingPersona?.llm_model_version_override ?? null,

@@ -107,6 +109,7 @@ export function PersonaEditor({
         disable_retrieval: Yup.boolean().required(),
         document_set_ids: Yup.array().of(Yup.number()),
         num_chunks: Yup.number().max(20).nullable(),
+        include_citations: Yup.boolean().required(),
         llm_relevance_filter: Yup.boolean().required(),
         llm_model_version_override: Yup.string().nullable(),
       })

@@ -240,6 +243,18 @@ export function PersonaEditor({
               error={finalPromptError}
             />

+            {!values.disable_retrieval && (
+              <BooleanFormField
+                name="include_citations"
+                label="Include Citations"
+                subtext={`
+                If set, the response will include bracket citations ([1], [2], etc.)
+                for each document used by the LLM to help inform the response. This is
+                the same technique used by the default Personas. In general, we recommend
+                to leave this enabled in order to increase trust in the LLM answer.`}
+              />
+            )}
+
             <BooleanFormField
               name="disable_retrieval"
               label="Disable Retrieval"
@@ -76,10 +76,12 @@ export function PersonasTable({ personas }: { personas: Persona[] }) {
         id: persona.id.toString(),
         cells: [
           <div key="name" className="flex">
-            <FiEdit
-              className="mr-1 my-auto cursor-pointer"
-              onClick={() => router.push(`/admin/personas/${persona.id}`)}
-            />
+            {!persona.default_persona && (
+              <FiEdit
+                className="mr-1 my-auto cursor-pointer"
+                onClick={() => router.push(`/admin/personas/${persona.id}`)}
+              />
+            )}
             <p className="text font-medium whitespace-normal break-none">
               {persona.name}
             </p>

@@ -129,9 +131,10 @@ export function PersonasTable({ personas }: { personas: Persona[] }) {
             </div>
           </div>,
           <div key="edit" className="flex">
-            <div className="mx-auto my-auto hover:bg-hover rounded p-1 cursor-pointer">
+            <div className="mx-auto my-auto">
             {!persona.default_persona ? (
               <div
                 className="hover:bg-hover rounded p-1 cursor-pointer"
                 onClick={async () => {
                   const response = await deletePersona(persona.id);
                   if (response.ok) {
@@ -7,6 +7,7 @@ interface PersonaCreationRequest {
   task_prompt: string;
   document_set_ids: number[];
   num_chunks: number | null;
+  include_citations: boolean;
   llm_relevance_filter: boolean | null;
   llm_model_version_override: string | null;
 }

@@ -20,6 +21,7 @@ interface PersonaUpdateRequest {
   task_prompt: string;
   document_set_ids: number[];
   num_chunks: number | null;
+  include_citations: boolean;
   llm_relevance_filter: boolean | null;
   llm_model_version_override: string | null;
 }

@@ -32,10 +34,12 @@ function createPrompt({
   personaName,
   systemPrompt,
   taskPrompt,
+  includeCitations,
 }: {
   personaName: string;
   systemPrompt: string;
   taskPrompt: string;
+  includeCitations: boolean;
 }) {
   return fetch("/api/prompt", {
     method: "POST",

@@ -48,6 +52,7 @@ function createPrompt({
       shared: true,
       system_prompt: systemPrompt,
       task_prompt: taskPrompt,
+      include_citations: includeCitations,
     }),
   });
 }
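
For reference, the shape of the body createPrompt now sends, with the new include_citations flag. A hypothetical illustration using Python's requests, where BASE_URL, the admin session cookies, and the name field are assumptions not shown in this hunk:

    import requests

    payload = {
        "name": "My Persona",  # assumption: the prompt is named after the persona, per personaName above
        "shared": True,
        "system_prompt": "You are a helpful assistant.",
        "task_prompt": "Answer using only the retrieved documents.",
        "include_citations": True,  # new field added by this commit
    }
    # BASE_URL and admin_cookies are placeholders for however your deployment reaches and authenticates the API
    resp = requests.post(f"{BASE_URL}/api/prompt", json=payload, cookies=admin_cookies)
    prompt_id = resp.json()["id"] if resp.ok else None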
@@ -57,11 +62,13 @@ function updatePrompt({
   personaName,
   systemPrompt,
   taskPrompt,
+  includeCitations,
 }: {
   promptId: number;
   personaName: string;
   systemPrompt: string;
   taskPrompt: string;
+  includeCitations: boolean;
 }) {
   return fetch(`/api/prompt/${promptId}`, {
     method: "PATCH",

@@ -74,6 +81,7 @@ function updatePrompt({
       shared: true,
       system_prompt: systemPrompt,
       task_prompt: taskPrompt,
+      include_citations: includeCitations,
     }),
   });
 }

@@ -112,6 +120,7 @@ export async function createPersona(
     personaName: personaCreationRequest.name,
     systemPrompt: personaCreationRequest.system_prompt,
     taskPrompt: personaCreationRequest.task_prompt,
+    includeCitations: personaCreationRequest.include_citations,
   });
   const promptId = createPromptResponse.ok
     ? (await createPromptResponse.json()).id

@@ -147,6 +156,7 @@ export async function updatePersona(
      personaName: personaUpdateRequest.name,
      systemPrompt: personaUpdateRequest.system_prompt,
      taskPrompt: personaUpdateRequest.task_prompt,
+     includeCitations: personaUpdateRequest.include_citations,
    });
    promptId = existingPromptId;
  } else {

@@ -154,6 +164,7 @@ export async function updatePersona(
      personaName: personaUpdateRequest.name,
      systemPrompt: personaUpdateRequest.system_prompt,
      taskPrompt: personaUpdateRequest.task_prompt,
+     includeCitations: personaUpdateRequest.include_citations,
    });
    promptId = promptResponse.ok ? (await promptResponse.json()).id : null;
  }