ensure consistency of answers + update llm relevance prompting (#2045)

This commit is contained in:
pablodanswer 2024-08-05 08:27:15 -07:00 committed by GitHub
parent 66e4dded91
commit a3ea217f40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 100 additions and 57 deletions

View File

@@ -11,7 +11,13 @@ it must contain information that is USEFUL for answering the query.
 If the section contains ANY useful information, that is good enough, \
 it does not need to fully answer every part of the user query.
+Title: {{title}}
+{{optional_metadata}}
 Reference Section:
 ```
 {{chunk_text}}
 ```
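
For context, a minimal sketch of how the reworked prompt is assembled once `.format()` fills the new placeholders. The abbreviated template and example values below are illustrative, not the repo's exact `SECTION_FILTER_PROMPT`; `optional_metadata` collapses to an empty string when a section carries no metadata, so the prompt is unchanged in that case.

```python
# Illustrative only: an abbreviated stand-in for SECTION_FILTER_PROMPT,
# showing how the new title/optional_metadata placeholders are filled.
TEMPLATE = (
    "Title: {title}\n"
    "{optional_metadata}\n"
    "Reference Section:\n"
    "{chunk_text}\n\n"
    "User Query:\n"
    "{user_query}"
)

# Hypothetical example values.
prompt = TEMPLATE.format(
    title="Quarterly Planning Notes",
    optional_metadata="\n\nMetadata:\nauthor - Alice\ntags - planning, q3\n\nContent:",
    chunk_text="We agreed to ship the relevance filter update in August.",
    user_query="When is the relevance filter update shipping?",
)
print(prompt)
```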

View File

@@ -212,11 +212,17 @@ def filter_sections(
         section.center_chunk.content if use_chunk else section.combined_content
         for section in sections_to_filter
     ]
+    metadata_list = [section.center_chunk.metadata for section in sections_to_filter]
+    titles = [
+        section.center_chunk.semantic_identifier for section in sections_to_filter
+    ]
     llm_chunk_selection = llm_batch_eval_sections(
         query=query.query,
         section_contents=contents,
         llm=llm,
+        titles=titles,
+        metadata_list=metadata_list,
     )
     return [
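
A small self-contained sketch of the new call shape in `filter_sections`: contents, titles, and per-section metadata are built as three index-aligned lists from the same `sections_to_filter` order. The `Chunk`/`Section` dataclasses below are simplified stand-ins for Danswer's real inference types.

```python
# Simplified stand-ins for Danswer's inference types; the point is the
# three index-aligned lists handed to llm_batch_eval_sections.
from dataclasses import dataclass, field


@dataclass
class Chunk:
    content: str
    semantic_identifier: str
    metadata: dict = field(default_factory=dict)


@dataclass
class Section:
    center_chunk: Chunk


sections_to_filter = [
    Section(Chunk("...chunk text...", "Planning Notes", {"author": "Alice"})),
    Section(Chunk("...other text...", "Launch Checklist", {"tags": ["q3"]})),
]

contents = [s.center_chunk.content for s in sections_to_filter]
titles = [s.center_chunk.semantic_identifier for s in sections_to_filter]
metadata_list = [s.center_chunk.metadata for s in sections_to_filter]

# The lists must stay the same length and order so zip() in
# llm_batch_eval_sections pairs each chunk with its own title/metadata.
assert len(contents) == len(titles) == len(metadata_list)
```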

View File

@@ -12,17 +12,33 @@ from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

 logger = setup_logger()

-def llm_eval_section(query: str, section_content: str, llm: LLM) -> bool:
+def llm_eval_section(
+    query: str,
+    section_content: str,
+    llm: LLM,
+    title: str,
+    metadata: dict[str, str | list[str]],
+) -> bool:
+    def _get_metadata_str(metadata: dict[str, str | list[str]]) -> str:
+        metadata_str = "\n\nMetadata:\n"
+        for key, value in metadata.items():
+            value_str = ", ".join(value) if isinstance(value, list) else value
+            metadata_str += f"{key} - {value_str}\n"
+        return metadata_str + "\nContent:"
+
     def _get_usefulness_messages() -> list[dict[str, str]]:
+        metadata_str = _get_metadata_str(metadata) if metadata else ""
+
         messages = [
             {
                 "role": "user",
                 "content": SECTION_FILTER_PROMPT.format(
-                    chunk_text=section_content, user_query=query
+                    title=title,
+                    chunk_text=section_content,
+                    user_query=query,
+                    optional_metadata=metadata_str,
                 ),
             },
         ]
         return messages

     def _extract_usefulness(model_output: str) -> bool:
@@ -34,9 +50,6 @@ def llm_eval_section(query: str, section_content: str, llm: LLM) -> bool:
     messages = _get_usefulness_messages()
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)

-    # When running in a batch, it takes as long as the longest thread
-    # And when running a large batch, one may fail and take the whole timeout
-    # instead cap it to 5 seconds
     model_output = message_to_string(llm.invoke(filled_llm_prompt))

     logger.debug(model_output)
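
For reference, restating the `_get_metadata_str` helper from the hunk above as a standalone function shows exactly what gets spliced into the `optional_metadata` slot: list values are comma-joined, and the block ends with a `Content:` label so the chunk text reads naturally after it.

```python
# Standalone restatement of the _get_metadata_str logic from the hunk
# above, to show the exact string injected into the prompt.
def get_metadata_str(metadata: dict[str, str | list[str]]) -> str:
    metadata_str = "\n\nMetadata:\n"
    for key, value in metadata.items():
        value_str = ", ".join(value) if isinstance(value, list) else value
        metadata_str += f"{key} - {value_str}\n"
    return metadata_str + "\nContent:"


print(get_metadata_str({"author": "Alice", "tags": ["planning", "q3"]}))
# Metadata:
# author - Alice
# tags - planning, q3
#
# Content:
```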
@@ -44,7 +57,12 @@ def llm_eval_section(query: str, section_content: str, llm: LLM) -> bool:

 def llm_batch_eval_sections(
-    query: str, section_contents: list[str], llm: LLM, use_threads: bool = True
+    query: str,
+    section_contents: list[str],
+    llm: LLM,
+    titles: list[str],
+    metadata_list: list[dict[str, str | list[str]]],
+    use_threads: bool = True,
 ) -> list[bool]:
     if DISABLE_LLM_DOC_RELEVANCE:
         raise RuntimeError(
@@ -54,8 +72,10 @@ def llm_batch_eval_sections(
     if use_threads:
         functions_with_args: list[tuple[Callable, tuple]] = [
-            (llm_eval_section, (query, section_content, llm))
-            for section_content in section_contents
+            (llm_eval_section, (query, section_content, llm, title, metadata))
+            for section_content, title, metadata in zip(
+                section_contents, titles, metadata_list
+            )
         ]

         logger.debug(
@@ -70,6 +90,8 @@ def llm_batch_eval_sections(
     else:
         return [
-            llm_eval_section(query, section_content, llm)
-            for section_content in section_contents
+            llm_eval_section(query, section_content, llm, title, metadata)
+            for section_content, title, metadata in zip(
+                section_contents, titles, metadata_list
+            )
         ]
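
The threaded branch fans out one `(function, args)` tuple per section, zipping the three aligned lists. Below is a rough sketch of the same pattern using `concurrent.futures` in place of Danswer's own `run_functions_tuples_in_parallel` helper, with a stub standing in for the real LLM call in `llm_eval_section`.

```python
# Rough sketch of the threaded branch's fan-out pattern, with
# concurrent.futures standing in for run_functions_tuples_in_parallel.
from concurrent.futures import ThreadPoolExecutor


def eval_section(query: str, content: str, title: str, metadata: dict) -> bool:
    # Stub: the real function prompts the LLM and parses its verdict.
    return bool(content.strip())


query = "When is the relevance filter update shipping?"
section_contents = ["...chunk text...", "...other text..."]
titles = ["Planning Notes", "Launch Checklist"]
metadata_list = [{"author": "Alice"}, {"tags": ["q3"]}]

functions_with_args = [
    (eval_section, (query, content, title, metadata))
    for content, title, metadata in zip(section_contents, titles, metadata_list)
]

# One task per section; results come back in input order.
with ThreadPoolExecutor() as pool:
    futures = [pool.submit(fn, *args) for fn, args in functions_with_args]
    results = [f.result() for f in futures]

print(results)  # [True, True]
```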

View File

@@ -14,6 +14,7 @@ import { useContext, useEffect, useState } from "react";
 import { Tooltip } from "../tooltip/Tooltip";
 import KeyboardSymbol from "@/lib/browserUtilities";
 import { SettingsContext } from "../settings/SettingsProvider";
+import { DISABLE_LLM_DOC_RELEVANCE } from "@/lib/constants";

 const getSelectedDocumentIds = (
   documents: SearchDanswerDocument[],
@@ -140,6 +141,7 @@ export const SearchResultsDisplay = ({
       showAll ||
       (searchResponse &&
         searchResponse.additional_relevance &&
+        searchResponse.additional_relevance[doc.document_id] &&
         searchResponse.additional_relevance[doc.document_id].relevant) ||
       doc.is_relevant
   );
@@ -175,46 +177,48 @@
         <div className="mt-4">
           <div className="font-bold flex justify-between text-emphasis border-b mb-3 pb-1 border-border text-lg">
             <p>Results</p>
-            {(contentEnriched || searchResponse.additional_relevance) && (
-              <Tooltip delayDuration={1000} content={`${commandSymbol}O`}>
-                <button
-                  onClick={() => {
-                    performSweep();
-                    if (agenticResults) {
-                      setShowAll((showAll) => !showAll);
-                    }
-                  }}
-                  className={`flex items-center justify-center animate-fade-in-up rounded-lg p-1 text-xs transition-all duration-300 w-20 h-8 ${
-                    !sweep
-                      ? "bg-green-500 text-text-800"
-                      : "bg-rose-700 text-text-100"
-                  }`}
-                  style={{
-                    transform: sweep ? "rotateZ(180deg)" : "rotateZ(0deg)",
-                  }}
-                >
-                  <div
-                    className={`flex items-center ${sweep ? "rotate-180" : ""}`}
-                  >
-                    <span></span>
-                    {!sweep
-                      ? agenticResults
-                        ? "Show All"
-                        : "Focus"
-                      : agenticResults
-                        ? "Focus"
-                        : "Show All"}
-                    <span className="ml-1">
-                      {!sweep ? (
-                        <BroomIcon className="h-4 w-4" />
-                      ) : (
-                        <UndoIcon className="h-4 w-4" />
-                      )}
-                    </span>
-                  </div>
-                </button>
-              </Tooltip>
-            )}
+            {!DISABLE_LLM_DOC_RELEVANCE &&
+              (contentEnriched || searchResponse.additional_relevance) && (
+                <Tooltip delayDuration={1000} content={`${commandSymbol}O`}>
+                  <button
+                    onClick={() => {
+                      performSweep();
+                      if (agenticResults) {
+                        setShowAll((showAll) => !showAll);
+                      }
+                    }}
+                    className={`flex items-center justify-center animate-fade-in-up rounded-lg p-1 text-xs transition-all duration-300 w-20 h-8 ${
+                      !sweep
+                        ? "bg-green-500 text-text-800"
+                        : "bg-rose-700 text-text-100"
+                    }`}
+                    style={{
+                      transform: sweep ? "rotateZ(180deg)" : "rotateZ(0deg)",
+                    }}
+                  >
+                    <div
+                      className={`flex items-center ${sweep ? "rotate-180" : ""}`}
+                    >
+                      <span></span>
+                      {!sweep
+                        ? agenticResults
+                          ? "Show All"
+                          : "Focus"
+                        : agenticResults
+                          ? "Focus"
+                          : "Show All"}
+                      <span className="ml-1">
+                        {!sweep ? (
+                          <BroomIcon className="h-4 w-4" />
+                        ) : (
+                          <UndoIcon className="h-4 w-4" />
+                        )}
+                      </span>
+                    </div>
+                  </button>
+                </Tooltip>
+              )}
         </div>
         {agenticResults &&

View File

@@ -170,11 +170,12 @@ export const SearchSection = ({
       if (existingSearchIdRaw == null) {
         return;
       }

-      function extractFirstUserMessage(
-        chatSession: SearchSession
+      function extractFirstMessageByType(
+        chatSession: SearchSession,
+        messageType: "user" | "assistant"
       ): string | null {
         const userMessage = chatSession?.messages.find(
-          (msg) => msg.message_type === "user"
+          (msg) => msg.message_type === messageType
         );
         return userMessage ? userMessage.message : null;
       }
@@ -184,14 +185,18 @@ export const SearchSection = ({
         `/api/query/search-session/${existingSearchessionId}`
       );
       const searchSession = (await response.json()) as SearchSession;
-      const message = extractFirstUserMessage(searchSession);
+      const userMessage = extractFirstMessageByType(searchSession, "user");
+      const assistantMessage = extractFirstMessageByType(
+        searchSession,
+        "assistant"
+      );

-      if (message) {
-        setQuery(message);
+      if (userMessage) {
+        setQuery(userMessage);
         const danswerDocs: SearchResponse = {
           documents: searchSession.documents,
           suggestedSearchType: null,
-          answer: null,
+          answer: assistantMessage || "Search response not found",
           quotes: null,
           selectedDocIndices: null,
           error: null,