renames + fix of refined answer generation prompt

This commit is contained in:
joachim-danswer 2025-02-01 23:02:25 -08:00 committed by Evan Lohn
parent 71304e4228
commit e23dd0a3fa
18 changed files with 477 additions and 967 deletions

View File

@@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
SubQuestionAnswerCheckUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_ANSWER_CHECK_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
@@ -40,7 +40,7 @@ def check_sub_answer(
)
msg = [
HumanMessage(
content=SUB_CHECK_PROMPT.format(
content=SUB_ANSWER_CHECK_PROMPT.format(
question=state.question,
base_answer=state.answer,
)
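For context, the renamed constant is used exactly as before; a minimal sketch of the pattern is below. The template wording and the helper build_check_message are placeholders, since the real SUB_ANSWER_CHECK_PROMPT is defined in shared_graph_utils/prompts.py and is only renamed by this commit.

# Minimal sketch, assuming a placeholder template.
from langchain_core.messages import HumanMessage

SUB_ANSWER_CHECK_PROMPT = (
    "Decide whether the answer below addresses the question.\n"
    "Question: {question}\n"
    "Proposed answer: {base_answer}\n"
    "Reply 'yes' or 'no'."
)  # placeholder wording, not the actual onyx prompt

def build_check_message(question: str, answer: str) -> list[HumanMessage]:
    # Mirrors the pattern in check_sub_answer: format the prompt, wrap it in a message.
    return [
        HumanMessage(
            content=SUB_ANSWER_CHECK_PROMPT.format(question=question, base_answer=answer)
        )
    ]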

View File

@@ -30,9 +30,11 @@ from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResul
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS,
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
SUB_QUESTION_ANSWER_TEMPLATE,
@@ -90,7 +92,12 @@ def generate_initial_answer(
consolidated_context_docs, consolidated_context_docs
)
decomp_questions = []
sub_questions: list[str] = []
streamed_documents = (
relevant_docs
if len(relevant_docs) > 0
else state.orig_question_retrieved_documents[:15]
)
# Use the query info from the base document retrieval
query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
@@ -102,8 +109,8 @@ def generate_initial_answer(
relevance_list = relevance_from_docs(relevant_docs)
for tool_response in yield_search_responses(
query=question,
reranked_sections=relevant_docs,
final_context_sections=relevant_docs,
reranked_sections=streamed_documents,
final_context_sections=streamed_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -140,35 +147,44 @@ def generate_initial_answer(
)
else:
decomp_answer_results = state.sub_question_results
sub_question_answer_results = state.sub_question_results
good_qa_list: list[str] = []
answered_sub_questions: list[str] = []
all_sub_questions: list[str] = [] # Separate list for tracking all questions
sub_question_num = 1
for idx, sub_question_answer_result in enumerate(
sub_question_answer_results, start=1
):
all_sub_questions.append(sub_question_answer_result.question)
for decomp_answer_result in decomp_answer_results:
decomp_questions.append(decomp_answer_result.question)
if (
decomp_answer_result.verified_high_quality
and len(decomp_answer_result.answer) > 0
and decomp_answer_result.answer != UNKNOWN_ANSWER
):
good_qa_list.append(
is_valid_answer = (
sub_question_answer_result.verified_high_quality
and sub_question_answer_result.answer
and sub_question_answer_result.answer != UNKNOWN_ANSWER
)
if is_valid_answer:
answered_sub_questions.append(
SUB_QUESTION_ANSWER_TEMPLATE.format(
sub_question=decomp_answer_result.question,
sub_answer=decomp_answer_result.answer,
sub_question_num=sub_question_num,
sub_question=sub_question_answer_result.question,
sub_answer=sub_question_answer_result.answer,
sub_question_num=idx,
)
)
sub_question_num += 1
# Determine which base prompt to use given the sub-question information
if len(good_qa_list) > 0:
sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
base_prompt = INITIAL_RAG_PROMPT
else:
sub_question_answer_str = ""
base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS
# Use list comprehension for joining answers and determine prompt type
sub_question_answer_str = (
"\n\n------\n\n".join(answered_sub_questions)
if answered_sub_questions
else ""
)
base_prompt = (
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
if answered_sub_questions
else INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS
)
sub_questions = all_sub_questions # Replace the original assignment
model = graph_config.tooling.fast_llm
@@ -275,7 +291,7 @@ def generate_initial_answer(
return InitialAnswerUpdate(
initial_answer=answer,
initial_agent_stats=initial_agent_stats,
generated_sub_questions=decomp_questions,
generated_sub_questions=sub_questions,
agent_base_end_time=agent_base_end_time,
agent_base_metrics=agent_base_metrics,
log_messages=[
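For context, the reworked generate_initial_answer hunk does three things: it falls back to the first 15 originally retrieved documents when no relevant documents survive, keeps only verified, non-empty, non-"unknown" sub-answers, and picks the with/without-sub-questions prompt accordingly. A condensed, self-contained sketch of that control flow follows; the dataclass, the UNKNOWN_ANSWER value, and the template string are stand-ins for the onyx models and prompts.

from dataclasses import dataclass

UNKNOWN_ANSWER = "I do not have enough information to answer this question."  # stand-in value
SUB_QUESTION_ANSWER_TEMPLATE = (
    "Sub-question {sub_question_num}: {sub_question}\nAnswer: {sub_answer}"
)  # stand-in template

@dataclass
class SubQuestionResult:  # stand-in for the onyx sub-question result model
    question: str
    answer: str
    verified_high_quality: bool

def select_context_and_prompt(
    relevant_docs: list[str],
    orig_question_retrieved_documents: list[str],
    sub_question_results: list[SubQuestionResult],
) -> tuple[list[str], list[str], str]:
    # Fall back to the first 15 original-question documents when nothing relevant was kept.
    streamed_documents = (
        relevant_docs if relevant_docs else orig_question_retrieved_documents[:15]
    )
    # Keep only verified, non-empty, non-"unknown" sub-answers, numbered by position.
    answered_sub_questions = [
        SUB_QUESTION_ANSWER_TEMPLATE.format(
            sub_question=result.question,
            sub_answer=result.answer,
            sub_question_num=idx,
        )
        for idx, result in enumerate(sub_question_results, start=1)
        if result.verified_high_quality
        and result.answer
        and result.answer != UNKNOWN_ANSWER
    ]
    # The real node picks INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS vs. ..._WO_SUB_QUESTIONS here.
    prompt_kind = "with_sub_questions" if answered_sub_questions else "without_sub_questions"
    return streamed_documents, answered_sub_questions, prompt_kind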

View File

@@ -23,10 +23,10 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS,
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
INITIAL_QUESTION_DECOMPOSITION_PROMPT,
)
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -79,7 +79,7 @@ def decompose_orig_question(
)
else:
decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS.format(
decomposition_prompt = INITIAL_QUESTION_DECOMPOSITION_PROMPT.format(
question=question, history=history
)

View File

@@ -24,8 +24,9 @@ def format_orig_question_search_output(
sub_question_retrieval_stats = sub_question_retrieval_stats
return OrigQuestionRetrievalUpdate(
orig_question_verified_reranked_documents=state.expanded_retrieval_result.verified_reranked_documents,
orig_question_sub_query_retrieval_results=state.expanded_retrieval_result.expanded_query_results,
orig_question_retrieved_documents=state.expanded_retrieval_result.context_documents,
orig_question_retrieved_documents=state.retrieved_documents,
orig_question_retrieval_stats=sub_question_retrieval_stats,
log_messages=[],
)

View File

@@ -1,11 +1,6 @@
from pydantic import BaseModel
from onyx.agents.agent_search.deep_search.main.states import (
OrigQuestionRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
QuestionRetrievalResult,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
@@ -23,14 +18,14 @@ class BaseRawSearchInput(ExpandedRetrievalInput):
## Graph Output State
class BaseRawSearchOutput(BaseModel):
class BaseRawSearchOutput(OrigQuestionRetrievalUpdate):
"""
This is a list of results even though each call of this subgraph only returns one result.
This is because if we parallelize the answer query subgraph, there will be multiple
results in a list so the add operator is used to add them together.
"""
base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
# base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
## Graph State
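For context, the docstring kept above explains why the output is aggregated even though each subgraph call returns a single result: parallel branches each contribute one entry and a reducer merges them. A minimal sketch of that pattern, assuming the LangGraph-style annotated reducers these state classes use elsewhere (operator.add stands in for reducers like dedup_inference_sections):

import operator
from typing import Annotated, TypedDict

class ParallelSearchState(TypedDict):
    # Each parallel branch contributes a one-element list; the annotated reducer
    # (operator.add in this sketch) concatenates the branches' contributions.
    branch_results: Annotated[list[str], operator.add]

left: ParallelSearchState = {"branch_results": ["result from branch A"]}
right: ParallelSearchState = {"branch_results": ["result from branch B"]}
merged = operator.add(left["branch_results"], right["branch_results"])
print(merged)  # ['result from branch A', 'result from branch B']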

View File

@@ -10,7 +10,9 @@ from onyx.agents.agent_search.deep_search.main.states import (
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
@@ -28,7 +30,7 @@ def compare_answers(
initial_answer = state.initial_answer
refined_answer = state.refined_answer
compare_answers_prompt = ANSWER_COMPARISON_PROMPT.format(
compare_answers_prompt = INITIAL_REFINED_ANSWER_COMPARISON_PROMPT.format(
question=question, initial_answer=initial_answer, refined_answer=refined_answer
)

View File

@@ -21,7 +21,7 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
DEEP_DECOMPOSE_PROMPT_WITH_ENTITIES,
REFINEMENT_QUESTION_DECOMPOSITION_PROMPT,
)
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -78,7 +78,7 @@ def create_refined_sub_questions(
msg = [
HumanMessage(
content=DEEP_DECOMPOSE_PROMPT_WITH_ENTITIES.format(
content=REFINEMENT_QUESTION_DECOMPOSITION_PROMPT.format(
question=question,
history=history,
entity_term_extraction_str=entity_term_extraction_str,

View File

@@ -18,7 +18,14 @@ from onyx.agents.agent_search.shared_graph_utils.models import EntityExtractionR
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT
from onyx.agents.agent_search.shared_graph_utils.models import Relationship
from onyx.agents.agent_search.shared_graph_utils.models import Term
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ENTITY_TERM_EXTRACTION_PROMPT,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
@@ -57,11 +64,15 @@ def extract_entities_terms(
doc_context = format_docs(initial_search_docs)
doc_context = trim_prompt_piece(
graph_config.tooling.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question
graph_config.tooling.fast_llm.config,
doc_context,
ENTITY_TERM_EXTRACTION_PROMPT + question,
)
msg = [
HumanMessage(
content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
content=ENTITY_TERM_EXTRACTION_PROMPT.format(
question=question, context=doc_context
),
)
]
fast_llm = graph_config.tooling.fast_llm
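For context, the hunk keeps the existing trim-then-format pattern, just under the renamed constant: the document context is trimmed so that prompt, question, and context together fit the fast model's window before the message is built. A rough sketch of that pattern; the token budget, the four-characters-per-token heuristic, and the template wording are assumptions rather than onyx's actual trim_prompt_piece implementation.

ENTITY_TERM_EXTRACTION_PROMPT = (
    "Extract entities, relationships and terms.\n"
    "Question: {question}\nContext:\n{context}"
)  # placeholder wording

def trim_prompt_piece(max_input_tokens: int, piece: str, reserved_text: str) -> str:
    # Crude heuristic for the sketch: ~4 characters per token, minus what the
    # rest of the prompt already consumes.
    budget_chars = max_input_tokens * 4 - len(reserved_text)
    return piece[: max(budget_chars, 0)]

question = "What changed in Q3?"
doc_context = trim_prompt_piece(
    max_input_tokens=4096,
    piece="...concatenated formatted documents...",
    reserved_text=ENTITY_TERM_EXTRACTION_PROMPT + question,
)
message_content = ENTITY_TERM_EXTRACTION_PROMPT.format(
    question=question, context=doc_context
)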

View File

@@ -28,12 +28,14 @@ from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import REVISED_RAG_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
REVISED_RAG_PROMPT_NO_SUB_QUESTIONS,
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
SUB_QUESTION_ANSWER_TEMPLATE_REVISED,
REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
SUB_QUESTION_ANSWER_TEMPLATE_REFINED,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -71,12 +73,17 @@ def generate_refined_answer(
verified_reranked_documents = state.verified_reranked_documents
sub_questions_cited_documents = state.cited_documents
all_original_question_documents = state.orig_question_retrieved_documents
original_question_verified_documents = (
state.orig_question_verified_reranked_documents
)
original_question_retrieved_documents = state.orig_question_retrieved_documents
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
counter = 0
for original_doc_number, original_doc in enumerate(all_original_question_documents):
for original_doc_number, original_doc in enumerate(
original_question_verified_documents
):
if original_doc_number not in sub_questions_cited_documents:
if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
@@ -92,16 +99,22 @@ def generate_refined_answer(
consolidated_context_docs, consolidated_context_docs
)
streaming_docs = (
relevant_docs
if len(relevant_docs) > 0
else original_question_retrieved_documents[:15]
)
query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
# stream refined answer docs
# stream refined answer docs, or original question docs if no relevant docs are found
relevance_list = relevance_from_docs(relevant_docs)
for tool_response in yield_search_responses(
query=question,
reranked_sections=relevant_docs,
final_context_sections=relevant_docs,
reranked_sections=streaming_docs,
final_context_sections=streaming_docs,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -124,71 +137,62 @@ def generate_refined_answer(
else:
refined_doc_effectiveness = 10.0
decomp_answer_results = state.sub_question_results
sub_question_answer_results = state.sub_question_results
answered_qa_list: list[str] = []
decomp_questions = []
answered_sub_question_answer_list: list[str] = []
sub_questions: list[str] = []
initial_answered_sub_questions: set[str] = set()
refined_answered_sub_questions: set[str] = set()
initial_good_sub_questions: list[str] = []
new_revised_good_sub_questions: list[str] = []
for i, result in enumerate(sub_question_answer_results, 1):
question_level, _ = parse_question_id(result.question_id)
sub_questions.append(result.question)
sub_question_num = 1
for decomp_answer_result in decomp_answer_results:
question_level, question_num = parse_question_id(
decomp_answer_result.question_id
)
decomp_questions.append(decomp_answer_result.question)
if (
decomp_answer_result.verified_high_quality
and len(decomp_answer_result.answer) > 0
and decomp_answer_result.answer != UNKNOWN_ANSWER
result.verified_high_quality
and result.answer
and result.answer != UNKNOWN_ANSWER
):
if question_level == 0:
initial_good_sub_questions.append(decomp_answer_result.question)
sub_question_type = "initial"
else:
new_revised_good_sub_questions.append(decomp_answer_result.question)
sub_question_type = "refined"
answered_qa_list.append(
SUB_QUESTION_ANSWER_TEMPLATE_REVISED.format(
sub_question=decomp_answer_result.question,
sub_answer=decomp_answer_result.answer,
sub_question_num=sub_question_num,
sub_question_type = "initial" if question_level == 0 else "refined"
question_set = (
initial_answered_sub_questions
if question_level == 0
else refined_answered_sub_questions
)
question_set.add(result.question)
answered_sub_question_answer_list.append(
SUB_QUESTION_ANSWER_TEMPLATE_REFINED.format(
sub_question=result.question,
sub_answer=result.answer,
sub_question_num=i,
sub_question_type=sub_question_type,
)
)
sub_question_num += 1
initial_good_sub_questions = list(set(initial_good_sub_questions))
new_revised_good_sub_questions = list(set(new_revised_good_sub_questions))
total_good_sub_questions = list(
set(initial_good_sub_questions + new_revised_good_sub_questions)
# Calculate efficiency
total_answered_questions = (
initial_answered_sub_questions | refined_answered_sub_questions
)
revision_question_efficiency = (
len(total_answered_questions) / len(initial_answered_sub_questions)
if initial_answered_sub_questions
else 10.0
if refined_answered_sub_questions
else 1.0
)
if len(initial_good_sub_questions) > 0:
revision_question_efficiency: float = len(total_good_sub_questions) / len(
initial_good_sub_questions
)
elif len(new_revised_good_sub_questions) > 0:
revision_question_efficiency = 10.0
else:
revision_question_efficiency = 1.0
sub_question_answer_str = "\n\n------\n\n".join(list(set(answered_qa_list)))
# original answer
sub_question_answer_str = "\n\n------\n\n".join(
set(answered_sub_question_answer_list)
)
initial_answer = state.initial_answer or ""
# Determine which persona-specification prompt to use
# Determine which base prompt to use given the sub-question information
if len(answered_qa_list) > 0:
base_prompt = REVISED_RAG_PROMPT
else:
base_prompt = REVISED_RAG_PROMPT_NO_SUB_QUESTIONS
# Choose appropriate prompt template
base_prompt = (
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS
if answered_sub_question_answer_list
else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
)
model = graph_config.tooling.fast_llm
relevant_docs_str = format_docs(relevant_docs)
@@ -211,7 +215,7 @@ def generate_refined_answer(
answered_sub_questions=remove_document_citations(
sub_question_answer_str
),
relevant_docs=relevant_docs,
relevant_docs=relevant_docs_str,
initial_answer=remove_document_citations(initial_answer)
if initial_answer
else None,
@@ -221,8 +225,6 @@ def generate_refined_answer(
)
]
# Grader
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
for message in model.stream(msg):
@@ -248,7 +250,7 @@ def generate_refined_answer(
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
streamed_tokens.append(content)
logger.info(
logger.debug(
f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
)
dispatch_main_answer_stop_info(1, writer)
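For context, the refactored efficiency metric tracks answered sub-questions as two sets, keyed on whether they were answered in the initial pass (question_level 0) or the refined pass, and rewards refinement that answers questions the initial pass missed. A condensed sketch with made-up example data:

def revision_question_efficiency(
    initial_answered: set[str], refined_answered: set[str]
) -> float:
    # Ratio of all answered sub-questions to those already answered in the initial pass.
    total_answered = initial_answered | refined_answered
    if initial_answered:
        return len(total_answered) / len(initial_answered)
    if refined_answered:
        return 10.0  # refinement answered questions where the initial pass answered none
    return 1.0  # nothing answered in either pass

print(revision_question_efficiency({"q1", "q2"}, {"q2", "q3"}))  # 1.5
print(revision_question_efficiency(set(), {"q3"}))  # 10.0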

View File

@@ -129,6 +129,9 @@ class OrigQuestionRetrievalUpdate(LoggerUpdate):
orig_question_retrieved_documents: Annotated[
list[InferenceSection], dedup_inference_sections
]
orig_question_verified_reranked_documents: Annotated[
list[InferenceSection], dedup_inference_sections
]
orig_question_sub_query_retrieval_results: list[QueryRetrievalResult] = []
orig_question_retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()

View File

@@ -7,6 +7,7 @@ from onyx.context.search.models import InferenceSection
class QuestionRetrievalResult(BaseModel):
expanded_query_results: list[QueryRetrievalResult] = []
retrieved_documents: list[InferenceSection] = []
verified_reranked_documents: list[InferenceSection] = []
context_documents: list[InferenceSection] = []
retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()

View File

@@ -17,7 +17,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import (
REWRITE_PROMPT_MULTI_ORIGINAL,
QUERY_REWRITING_PROMPT,
)
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -47,7 +47,7 @@ def expand_queries(
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
content=QUERY_REWRITING_PROMPT.format(question=question),
)
]

View File

@@ -82,6 +82,7 @@ def format_results(
return ExpandedRetrievalUpdate(
expanded_retrieval_result=QuestionRetrievalResult(
expanded_query_results=state.query_retrieval_results,
retrieved_documents=state.retrieved_documents,
verified_reranked_documents=reranked_documents,
context_documents=state.reranked_documents,
retrieval_stats=sub_question_retrieval_stats,

View File

@@ -13,7 +13,9 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
DOCUMENT_VERIFICATION_PROMPT,
)
def verify_documents(
@@ -38,12 +40,12 @@ def verify_documents(
fast_llm = graph_config.tooling.fast_llm
document_content = trim_prompt_piece(
fast_llm.config, document_content, VERIFIER_PROMPT + question
fast_llm.config, document_content, DOCUMENT_VERIFICATION_PROMPT + question
)
msg = [
HumanMessage(
content=VERIFIER_PROMPT.format(
content=DOCUMENT_VERIFICATION_PROMPT.format(
question=question, document_content=document_content
)
)

View File

@@ -60,6 +60,9 @@ class ExpandedRetrievalUpdate(LoggerUpdate, BaseModel):
class ExpandedRetrievalOutput(LoggerUpdate, BaseModel):
expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
retrieved_documents: Annotated[
list[InferenceSection], dedup_inference_sections
] = []
## Graph State

View File

@@ -7,8 +7,8 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import (
AgentPromptEnrichmentComponents,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_FRAMING_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_QUESTION_RAG_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_persona_agent_prompt_expressions,
)
@@ -45,10 +45,12 @@ def build_sub_question_answer_prompt(
docs_str = "\n\n".join(docs_format_list)
docs_str = trim_prompt_piece(
config, docs_str, BASE_RAG_PROMPT_v2 + question + original_question + date_str
config,
docs_str,
SUB_QUESTION_RAG_PROMPT + question + original_question + date_str,
)
human_message = HumanMessage(
content=BASE_RAG_PROMPT_v2.format(
content=SUB_QUESTION_RAG_PROMPT.format(
question=question,
original_question=original_question,
context=docs_str,
@@ -118,7 +120,7 @@ def build_history_prompt(config: GraphConfig, question: str) -> str:
if len(history.split()) > AGENT_MAX_STATIC_HISTORY_WORD_LENGTH:
history = summarize_history(history, question, persona_base, model)
return HISTORY_PROMPT.format(history=history) if history else ""
return HISTORY_FRAMING_PROMPT.format(history=history) if history else ""
def get_prompt_enrichment_components(

File diff suppressed because it is too large

View File

@@ -73,7 +73,7 @@ def format_docs(docs: Sequence[InferenceSection]) -> str:
formatted_doc_list = []
for doc_num, doc in enumerate(docs):
formatted_doc_list.append(f"Document D{doc_num + 1}:\n{doc.combined_content}")
formatted_doc_list.append(f"**Document D{doc_num + 1}:\n{doc.combined_content}")
return FORMAT_DOCS_SEPARATOR.join(formatted_doc_list)
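For context, a quick illustration of what format_docs produces after this change; the separator value and the plain-string documents are assumptions (the real function takes InferenceSection objects and joins with FORMAT_DOCS_SEPARATOR from the onyx constants).

FORMAT_DOCS_SEPARATOR = "\n\n"  # assumed value for the sketch

docs = ["First document text.", "Second document text."]
formatted = FORMAT_DOCS_SEPARATOR.join(
    f"**Document D{doc_num + 1}:\n{content}" for doc_num, content in enumerate(docs)
)
print(formatted)
# **Document D1:
# First document text.
#
# **Document D2:
# Second document text.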