Add option to add citations to Personas + allow for more chunks if an LLM model override is specified
@@ -382,10 +382,13 @@ def drop_messages_history_overflow(
     history_token_counts: list[int],
     final_msg: BaseMessage,
     final_msg_token_count: int,
+    max_allowed_tokens: int | None,
 ) -> list[BaseMessage]:
     """As message history grows, messages need to be dropped starting from the furthest in the past.
     The System message should be kept if at all possible and the latest user input which is inserted in the
     prompt template must be included"""
+    if max_allowed_tokens is None:
+        max_allowed_tokens = GEN_AI_MAX_INPUT_TOKENS
+
     if len(history_msgs) != len(history_token_counts):
         # This should never happen
@@ -395,7 +398,9 @@ def drop_messages_history_overflow(
 
     # Start dropping from the history if necessary
     all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
-    ind_prev_msg_start = find_last_index(all_tokens)
+    ind_prev_msg_start = find_last_index(
+        all_tokens, max_prompt_tokens=max_allowed_tokens
+    )
 
     if system_msg and ind_prev_msg_start <= len(history_msgs):
         prompt.append(system_msg)
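The two hunks above implement the token budget: when a Persona overrides the LLM, drop_messages_history_overflow now receives that model's limit instead of unconditionally using GEN_AI_MAX_INPUT_TOKENS, and find_last_index trims the history against it. A minimal sketch of the trimming idea (this find_last_index is illustrative only, not the helper in the danswer codebase):

def find_last_index(token_counts: list[int], max_prompt_tokens: int) -> int:
    # Walk backwards from the most recent message, accumulating tokens,
    # and return the index of the earliest message that still fits.
    cumulative = 0
    for i in range(len(token_counts) - 1, -1, -1):
        cumulative += token_counts[i]
        if cumulative > max_prompt_tokens:
            return i + 1  # messages before i + 1 get dropped
    return 0  # the whole history fits

find_last_index([500, 800, 300, 200], max_prompt_tokens=600)
# -> 2: only the two most recent messages (300 + 200 tokens) fit the budget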
@@ -33,6 +33,7 @@ from danswer.db.chat import get_or_create_root_message
 from danswer.db.chat import translate_db_message_to_chat_message_detail
 from danswer.db.chat import translate_db_search_doc_to_server_search_doc
 from danswer.db.models import ChatMessage
+from danswer.db.models import Persona
 from danswer.db.models import SearchDoc as DbSearchDoc
 from danswer.db.models import User
 from danswer.document_index.document_index_utils import get_index_name
@@ -42,6 +43,7 @@ from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_default_llm
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import get_default_llm_token_encode
+from danswer.llm.utils import get_llm_max_tokens
 from danswer.llm.utils import translate_history_to_basemessages
 from danswer.search.models import OptionalSearchSetting
 from danswer.search.models import RetrievalDetails
@@ -62,6 +64,7 @@ logger = setup_logger()
 def generate_ai_chat_response(
     query_message: ChatMessage,
     history: list[ChatMessage],
+    persona: Persona,
     context_docs: list[LlmDoc],
     doc_id_to_rank_map: dict[str, int],
     llm: LLM | None,
@@ -109,6 +112,9 @@ def generate_ai_chat_response(
         history_token_counts=history_token_counts,
         final_msg=user_message,
         final_msg_token_count=user_tokens,
+        max_allowed_tokens=get_llm_max_tokens(persona.llm_model_version_override)
+        if persona.llm_model_version_override
+        else None,
     )
 
     # Good Debug/Breakpoint
@@ -183,7 +189,9 @@ def stream_chat_message(
     )
 
     try:
-        llm = get_default_llm()
+        llm = get_default_llm(
+            gen_ai_model_version_override=persona.llm_model_version_override
+        )
     except GenAIDisabledException:
         llm = None
 
@@ -408,6 +416,7 @@ def stream_chat_message(
     response_packets = generate_ai_chat_response(
         query_message=final_msg,
         history=history_msgs,
+        persona=persona,
         context_docs=llm_docs,
         doc_id_to_rank_map=doc_id_to_rank_map,
         llm=llm,
@@ -14,6 +14,7 @@ from langchain.schema.messages import BaseMessage
 from langchain.schema.messages import BaseMessageChunk
 from langchain.schema.messages import HumanMessage
 from langchain.schema.messages import SystemMessage
+from litellm import get_max_tokens  # type: ignore
 from tiktoken.core import Encoding
 
 from danswer.configs.app_configs import LOG_LEVEL
@@ -188,3 +189,11 @@ def test_llm(llm: LLM) -> bool:
         logger.warning(f"GenAI API key failed for the following reason: {e}")
 
     return False
+
+
+def get_llm_max_tokens(model_name: str) -> int | None:
+    """Best effort attempt to get the max tokens for the LLM"""
+    try:
+        return get_max_tokens(model_name)
+    except Exception:
+        return None
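get_max_tokens comes from litellm, which maintains a map of model names to context-window sizes; the wrapper swallows lookup failures so callers can fall back to GEN_AI_MAX_INPUT_TOKENS (see the drop_messages_history_overflow hunk above). A hedged usage sketch, with model names and numbers purely illustrative and dependent on the installed litellm version:

print(get_llm_max_tokens("gpt-4"))          # e.g. 8192 if litellm knows the model
print(get_llm_max_tokens("my-custom-llm"))  # None: lookup fails, wrapper returns None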
@@ -62,7 +62,7 @@ def create_update_prompt(
 
 
 @basic_router.post("")
-def create_persona(
+def create_prompt(
     create_prompt_request: CreatePromptRequest,
     user: User | None = Depends(current_admin_user),
     db_session: Session = Depends(get_session),
@@ -92,6 +92,8 @@ export function PersonaEditor({
             (documentSet) => documentSet.id
           ) ?? ([] as number[]),
         num_chunks: existingPersona?.num_chunks ?? null,
+        include_citations:
+          existingPersona?.prompts[0]?.include_citations ?? true,
         llm_relevance_filter: existingPersona?.llm_relevance_filter ?? false,
         llm_model_version_override:
           existingPersona?.llm_model_version_override ?? null,
@@ -107,6 +109,7 @@ export function PersonaEditor({
         disable_retrieval: Yup.boolean().required(),
         document_set_ids: Yup.array().of(Yup.number()),
         num_chunks: Yup.number().max(20).nullable(),
+        include_citations: Yup.boolean().required(),
         llm_relevance_filter: Yup.boolean().required(),
         llm_model_version_override: Yup.string().nullable(),
       })
@@ -240,6 +243,18 @@ export function PersonaEditor({
                 error={finalPromptError}
               />
+
+              {!values.disable_retrieval && (
+                <BooleanFormField
+                  name="include_citations"
+                  label="Include Citations"
+                  subtext={`
+                If set, the response will include bracket citations ([1], [2], etc.)
+                for each document used by the LLM to help inform the response. This is
+                the same technique used by the default Personas. In general, we recommend
+                to leave this enabled in order to increase trust in the LLM answer.`}
+                />
+              )}
 
               <BooleanFormField
                 name="disable_retrieval"
                 label="Disable Retrieval"
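The subtext above describes the bracket-citation format ([1], [2], ...) that include_citations turns on. As a rough sketch of what consuming that format looks like downstream (illustrative only, not danswer's actual citation parser):

import re

def extract_citation_ranks(answer: str) -> list[int]:
    # Collect the [n] markers from an LLM answer, in order of appearance.
    return [int(rank) for rank in re.findall(r"\[(\d+)\]", answer)]

extract_citation_ranks("See the design doc [1]; the limits are configurable [2].")
# -> [1, 2]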
@@ -76,10 +76,12 @@ export function PersonasTable({ personas }: { personas: Persona[] }) {
       id: persona.id.toString(),
       cells: [
         <div key="name" className="flex">
-          <FiEdit
-            className="mr-1 my-auto cursor-pointer"
-            onClick={() => router.push(`/admin/personas/${persona.id}`)}
-          />
+          {!persona.default_persona && (
+            <FiEdit
+              className="mr-1 my-auto cursor-pointer"
+              onClick={() => router.push(`/admin/personas/${persona.id}`)}
+            />
+          )}
           <p className="text font-medium whitespace-normal break-none">
             {persona.name}
           </p>
@@ -129,9 +131,10 @@ export function PersonasTable({ personas }: { personas: Persona[] }) {
           </div>
         </div>,
         <div key="edit" className="flex">
-          <div className="mx-auto my-auto hover:bg-hover rounded p-1 cursor-pointer">
+          <div className="mx-auto my-auto">
             {!persona.default_persona ? (
               <div
+                className="hover:bg-hover rounded p-1 cursor-pointer"
                 onClick={async () => {
                   const response = await deletePersona(persona.id);
                   if (response.ok) {
@@ -7,6 +7,7 @@ interface PersonaCreationRequest {
   task_prompt: string;
   document_set_ids: number[];
   num_chunks: number | null;
+  include_citations: boolean;
   llm_relevance_filter: boolean | null;
   llm_model_version_override: string | null;
 }
@@ -20,6 +21,7 @@ interface PersonaUpdateRequest {
   task_prompt: string;
   document_set_ids: number[];
   num_chunks: number | null;
+  include_citations: boolean;
   llm_relevance_filter: boolean | null;
   llm_model_version_override: string | null;
 }
@@ -32,10 +34,12 @@ function createPrompt({
   personaName,
   systemPrompt,
   taskPrompt,
+  includeCitations,
 }: {
   personaName: string;
   systemPrompt: string;
   taskPrompt: string;
+  includeCitations: boolean;
 }) {
   return fetch("/api/prompt", {
     method: "POST",
@@ -48,6 +52,7 @@ function createPrompt({
       shared: true,
       system_prompt: systemPrompt,
       task_prompt: taskPrompt,
+      include_citations: includeCitations,
     }),
   });
 }
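For reference, the JSON body that createPrompt now sends to POST /api/prompt has roughly this shape (sketched in Python; only the fields visible in this diff are shown, and the values are illustrative):

import json

prompt_body = {
    "shared": True,
    "system_prompt": "You are a helpful assistant.",
    "task_prompt": "Answer using the retrieved documents.",
    "include_citations": True,  # the new field added by this commit
}
print(json.dumps(prompt_body, indent=2))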
@@ -57,11 +62,13 @@ function updatePrompt({
   personaName,
   systemPrompt,
   taskPrompt,
+  includeCitations,
 }: {
   promptId: number;
   personaName: string;
   systemPrompt: string;
   taskPrompt: string;
+  includeCitations: boolean;
 }) {
   return fetch(`/api/prompt/${promptId}`, {
     method: "PATCH",
@@ -74,6 +81,7 @@ function updatePrompt({
       shared: true,
       system_prompt: systemPrompt,
       task_prompt: taskPrompt,
+      include_citations: includeCitations,
     }),
   });
 }
@@ -112,6 +120,7 @@ export async function createPersona(
     personaName: personaCreationRequest.name,
     systemPrompt: personaCreationRequest.system_prompt,
     taskPrompt: personaCreationRequest.task_prompt,
+    includeCitations: personaCreationRequest.include_citations,
   });
   const promptId = createPromptResponse.ok
     ? (await createPromptResponse.json()).id
@@ -147,6 +156,7 @@ export async function updatePersona(
     personaName: personaUpdateRequest.name,
     systemPrompt: personaUpdateRequest.system_prompt,
     taskPrompt: personaUpdateRequest.task_prompt,
+    includeCitations: personaUpdateRequest.include_citations,
   });
   promptId = existingPromptId;
 } else {
@@ -154,6 +164,7 @@ export async function updatePersona(
     personaName: personaUpdateRequest.name,
     systemPrompt: personaUpdateRequest.system_prompt,
     taskPrompt: personaUpdateRequest.task_prompt,
+    includeCitations: personaUpdateRequest.include_citations,
   });
   promptId = promptResponse.ok ? (await promptResponse.json()).id : null;
 }