renames + fix of refined answer generation prompt

joachim-danswer 2025-02-01 23:02:25 -08:00 committed by Evan Lohn
parent 71304e4228
commit e23dd0a3fa
18 changed files with 477 additions and 967 deletions

View File

@@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
     SubQuestionAnswerCheckUpdate,
 )
 from onyx.agents.agent_search.models import GraphConfig
-from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_ANSWER_CHECK_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
@@ -40,7 +40,7 @@ def check_sub_answer(
     )
     msg = [
         HumanMessage(
-            content=SUB_CHECK_PROMPT.format(
+            content=SUB_ANSWER_CHECK_PROMPT.format(
                 question=state.question,
                 base_answer=state.answer,
             )

View File

@@ -30,9 +30,11 @@ from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResul
 from onyx.agents.agent_search.shared_graph_utils.operators import (
     dedup_inference_sections,
 )
-from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS,
+    INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS,
+)
+from onyx.agents.agent_search.shared_graph_utils.prompts import (
+    INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
     SUB_QUESTION_ANSWER_TEMPLATE,
@@ -90,7 +92,12 @@ def generate_initial_answer(
         consolidated_context_docs, consolidated_context_docs
     )

-    decomp_questions = []
+    sub_questions: list[str] = []
+    streamed_documents = (
+        relevant_docs
+        if len(relevant_docs) > 0
+        else state.orig_question_retrieved_documents[:15]
+    )

     # Use the query info from the base document retrieval
     query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
@@ -102,8 +109,8 @@ def generate_initial_answer(
     relevance_list = relevance_from_docs(relevant_docs)
     for tool_response in yield_search_responses(
         query=question,
-        reranked_sections=relevant_docs,
-        final_context_sections=relevant_docs,
+        reranked_sections=streamed_documents,
+        final_context_sections=streamed_documents,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,
         search_tool=graph_config.tooling.search_tool,
@@ -140,35 +147,44 @@ def generate_initial_answer(
         )
     else:
-        decomp_answer_results = state.sub_question_results
+        sub_question_answer_results = state.sub_question_results

-    good_qa_list: list[str] = []
+    answered_sub_questions: list[str] = []
+    all_sub_questions: list[str] = []  # Separate list for tracking all questions

-    sub_question_num = 1
-
-    for decomp_answer_result in decomp_answer_results:
-        decomp_questions.append(decomp_answer_result.question)
-        if (
-            decomp_answer_result.verified_high_quality
-            and len(decomp_answer_result.answer) > 0
-            and decomp_answer_result.answer != UNKNOWN_ANSWER
-        ):
-            good_qa_list.append(
+    for idx, sub_question_answer_result in enumerate(
+        sub_question_answer_results, start=1
+    ):
+        all_sub_questions.append(sub_question_answer_result.question)
+
+        is_valid_answer = (
+            sub_question_answer_result.verified_high_quality
+            and sub_question_answer_result.answer
+            and sub_question_answer_result.answer != UNKNOWN_ANSWER
+        )
+
+        if is_valid_answer:
+            answered_sub_questions.append(
                 SUB_QUESTION_ANSWER_TEMPLATE.format(
-                    sub_question=decomp_answer_result.question,
-                    sub_answer=decomp_answer_result.answer,
-                    sub_question_num=sub_question_num,
+                    sub_question=sub_question_answer_result.question,
+                    sub_answer=sub_question_answer_result.answer,
+                    sub_question_num=idx,
                 )
             )
-        sub_question_num += 1

-    # Determine which base prompt to use given the sub-question information
-    if len(good_qa_list) > 0:
-        sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
-        base_prompt = INITIAL_RAG_PROMPT
-    else:
-        sub_question_answer_str = ""
-        base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS
+    # Use list comprehension for joining answers and determine prompt type
+    sub_question_answer_str = (
+        "\n\n------\n\n".join(answered_sub_questions)
+        if answered_sub_questions
+        else ""
+    )
+    base_prompt = (
+        INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
+        if answered_sub_questions
+        else INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS
+    )
+
+    sub_questions = all_sub_questions  # Replace the original assignment

     model = graph_config.tooling.fast_llm
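Note on the new validity check in this hunk: the `and` chain in `is_valid_answer` evaluates to its last operand (or the first falsy one) rather than a strict bool, which is still sufficient for the `if is_valid_answer:` test. An explicitly boolean spelling, shown for illustration only and not part of this commit, would be:

    is_valid_answer = bool(
        sub_question_answer_result.verified_high_quality
        and sub_question_answer_result.answer
        and sub_question_answer_result.answer != UNKNOWN_ANSWER
    )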
@@ -275,7 +291,7 @@ def generate_initial_answer(
     return InitialAnswerUpdate(
         initial_answer=answer,
         initial_agent_stats=initial_agent_stats,
-        generated_sub_questions=decomp_questions,
+        generated_sub_questions=sub_questions,
         agent_base_end_time=agent_base_end_time,
         agent_base_metrics=agent_base_metrics,
         log_messages=[

View File

@@ -23,10 +23,10 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     build_history_prompt,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    INITIAL_DECOMPOSITION_PROMPT_QUESTIONS,
+    INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
+    INITIAL_QUESTION_DECOMPOSITION_PROMPT,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
 from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import ( from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -79,7 +79,7 @@ def decompose_orig_question(
         )
     else:
-        decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS.format(
+        decomposition_prompt = INITIAL_QUESTION_DECOMPOSITION_PROMPT.format(
             question=question, history=history
         )

View File

@@ -24,8 +24,9 @@ def format_orig_question_search_output(
     sub_question_retrieval_stats = sub_question_retrieval_stats

     return OrigQuestionRetrievalUpdate(
+        orig_question_verified_reranked_documents=state.expanded_retrieval_result.verified_reranked_documents,
         orig_question_sub_query_retrieval_results=state.expanded_retrieval_result.expanded_query_results,
-        orig_question_retrieved_documents=state.expanded_retrieval_result.context_documents,
+        orig_question_retrieved_documents=state.retrieved_documents,
         orig_question_retrieval_stats=sub_question_retrieval_stats,
         log_messages=[],
     )

View File

@@ -1,11 +1,6 @@
-from pydantic import BaseModel
-
 from onyx.agents.agent_search.deep_search.main.states import (
     OrigQuestionRetrievalUpdate,
 )
-from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
-    QuestionRetrievalResult,
-)
 from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
     ExpandedRetrievalInput,
 )
@@ -23,14 +18,14 @@ class BaseRawSearchInput(ExpandedRetrievalInput):

 ## Graph Output State
-class BaseRawSearchOutput(BaseModel):
+class BaseRawSearchOutput(OrigQuestionRetrievalUpdate):
     """
     This is a list of results even though each call of this subgraph only returns one result.
     This is because if we parallelize the answer query subgraph, there will be multiple
     results in a list so the add operator is used to add them together.
     """

-    base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
+    # base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()


 ## Graph State

View File

@@ -10,7 +10,9 @@ from onyx.agents.agent_search.deep_search.main.states import (
 )
 from onyx.agents.agent_search.deep_search.main.states import MainState
 from onyx.agents.agent_search.models import GraphConfig
-from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.prompts import (
+    INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
+)
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
@@ -28,7 +30,7 @@ def compare_answers(
     initial_answer = state.initial_answer
     refined_answer = state.refined_answer

-    compare_answers_prompt = ANSWER_COMPARISON_PROMPT.format(
+    compare_answers_prompt = INITIAL_REFINED_ANSWER_COMPARISON_PROMPT.format(
         question=question, initial_answer=initial_answer, refined_answer=refined_answer
     )

View File

@@ -21,7 +21,7 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     build_history_prompt,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    DEEP_DECOMPOSE_PROMPT_WITH_ENTITIES,
+    REFINEMENT_QUESTION_DECOMPOSITION_PROMPT,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
 from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -78,7 +78,7 @@ def create_refined_sub_questions(
     msg = [
         HumanMessage(
-            content=DEEP_DECOMPOSE_PROMPT_WITH_ENTITIES.format(
+            content=REFINEMENT_QUESTION_DECOMPOSITION_PROMPT.format(
                 question=question,
                 history=history,
                 entity_term_extraction_str=entity_term_extraction_str,

View File

@@ -18,7 +18,14 @@ from onyx.agents.agent_search.shared_graph_utils.models import EntityExtractionR
 from onyx.agents.agent_search.shared_graph_utils.models import (
     EntityRelationshipTermExtraction,
 )
-from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.models import Relationship
+from onyx.agents.agent_search.shared_graph_utils.models import Term
+from onyx.agents.agent_search.shared_graph_utils.prompts import (
+    ENTITY_TERM_EXTRACTION_PROMPT,
+)
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
@@ -57,11 +64,15 @@ def extract_entities_terms(
     doc_context = format_docs(initial_search_docs)
     doc_context = trim_prompt_piece(
-        graph_config.tooling.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question
+        graph_config.tooling.fast_llm.config,
+        doc_context,
+        ENTITY_TERM_EXTRACTION_PROMPT + question,
     )
     msg = [
         HumanMessage(
-            content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
+            content=ENTITY_TERM_EXTRACTION_PROMPT.format(
+                question=question, context=doc_context
+            ),
         )
     ]
     fast_llm = graph_config.tooling.fast_llm

View File

@@ -28,12 +28,14 @@ from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
 from onyx.agents.agent_search.shared_graph_utils.operators import (
     dedup_inference_sections,
 )
-from onyx.agents.agent_search.shared_graph_utils.prompts import REVISED_RAG_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    REVISED_RAG_PROMPT_NO_SUB_QUESTIONS,
+    REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
+)
+from onyx.agents.agent_search.shared_graph_utils.prompts import (
+    REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    SUB_QUESTION_ANSWER_TEMPLATE_REVISED,
+    SUB_QUESTION_ANSWER_TEMPLATE_REFINED,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
 from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -71,12 +73,17 @@ def generate_refined_answer(
     verified_reranked_documents = state.verified_reranked_documents
     sub_questions_cited_documents = state.cited_documents
-    all_original_question_documents = state.orig_question_retrieved_documents
+    original_question_verified_documents = (
+        state.orig_question_verified_reranked_documents
+    )
+    original_question_retrieved_documents = state.orig_question_retrieved_documents

     consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
     counter = 0
-    for original_doc_number, original_doc in enumerate(all_original_question_documents):
+    for original_doc_number, original_doc in enumerate(
+        original_question_verified_documents
+    ):
         if original_doc_number not in sub_questions_cited_documents:
             if (
                 counter <= AGENT_MIN_ORIG_QUESTION_DOCS
@@ -92,16 +99,22 @@ def generate_refined_answer(
         consolidated_context_docs, consolidated_context_docs
     )

+    streaming_docs = (
+        relevant_docs
+        if len(relevant_docs) > 0
+        else original_question_retrieved_documents[:15]
+    )

     query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
     assert (
         graph_config.tooling.search_tool
     ), "search_tool must be provided for agentic search"

-    # stream refined answer docs
+    # stream refined answer docs, or original question docs if no relevant docs are found
     relevance_list = relevance_from_docs(relevant_docs)
     for tool_response in yield_search_responses(
         query=question,
-        reranked_sections=relevant_docs,
-        final_context_sections=relevant_docs,
+        reranked_sections=streaming_docs,
+        final_context_sections=streaming_docs,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,
         search_tool=graph_config.tooling.search_tool,
@@ -124,71 +137,62 @@ def generate_refined_answer(
     else:
         refined_doc_effectiveness = 10.0

-    decomp_answer_results = state.sub_question_results
+    sub_question_answer_results = state.sub_question_results

-    answered_qa_list: list[str] = []
-    decomp_questions = []
+    answered_sub_question_answer_list: list[str] = []
+    sub_questions: list[str] = []
+    initial_answered_sub_questions: set[str] = set()
+    refined_answered_sub_questions: set[str] = set()

-    initial_good_sub_questions: list[str] = []
-    new_revised_good_sub_questions: list[str] = []
-
-    sub_question_num = 1
-    for decomp_answer_result in decomp_answer_results:
-        question_level, question_num = parse_question_id(
-            decomp_answer_result.question_id
-        )
-        decomp_questions.append(decomp_answer_result.question)
+    for i, result in enumerate(sub_question_answer_results, 1):
+        question_level, _ = parse_question_id(result.question_id)
+        sub_questions.append(result.question)
+
         if (
-            decomp_answer_result.verified_high_quality
-            and len(decomp_answer_result.answer) > 0
-            and decomp_answer_result.answer != UNKNOWN_ANSWER
+            result.verified_high_quality
+            and result.answer
+            and result.answer != UNKNOWN_ANSWER
         ):
-            if question_level == 0:
-                initial_good_sub_questions.append(decomp_answer_result.question)
-                sub_question_type = "initial"
-            else:
-                new_revised_good_sub_questions.append(decomp_answer_result.question)
-                sub_question_type = "refined"
-            answered_qa_list.append(
-                SUB_QUESTION_ANSWER_TEMPLATE_REVISED.format(
-                    sub_question=decomp_answer_result.question,
-                    sub_answer=decomp_answer_result.answer,
-                    sub_question_num=sub_question_num,
+            sub_question_type = "initial" if question_level == 0 else "refined"
+            question_set = (
+                initial_answered_sub_questions
+                if question_level == 0
+                else refined_answered_sub_questions
+            )
+            question_set.add(result.question)
+
+            answered_sub_question_answer_list.append(
+                SUB_QUESTION_ANSWER_TEMPLATE_REFINED.format(
+                    sub_question=result.question,
+                    sub_answer=result.answer,
+                    sub_question_num=i,
                     sub_question_type=sub_question_type,
                 )
             )
-        sub_question_num += 1

-    initial_good_sub_questions = list(set(initial_good_sub_questions))
-    new_revised_good_sub_questions = list(set(new_revised_good_sub_questions))
-    total_good_sub_questions = list(
-        set(initial_good_sub_questions + new_revised_good_sub_questions)
+    # Calculate efficiency
+    total_answered_questions = (
+        initial_answered_sub_questions | refined_answered_sub_questions
     )
-    if len(initial_good_sub_questions) > 0:
-        revision_question_efficiency: float = len(total_good_sub_questions) / len(
-            initial_good_sub_questions
-        )
-    elif len(new_revised_good_sub_questions) > 0:
-        revision_question_efficiency = 10.0
-    else:
-        revision_question_efficiency = 1.0
+    revision_question_efficiency = (
+        len(total_answered_questions) / len(initial_answered_sub_questions)
+        if initial_answered_sub_questions
+        else 10.0
+        if refined_answered_sub_questions
+        else 1.0
+    )

-    sub_question_answer_str = "\n\n------\n\n".join(list(set(answered_qa_list)))
-
-    # original answer
+    sub_question_answer_str = "\n\n------\n\n".join(
+        set(answered_sub_question_answer_list)
+    )
     initial_answer = state.initial_answer or ""

-    # Determine which persona-specification prompt to use
-
-    # Determine which base prompt to use given the sub-question information
-    if len(answered_qa_list) > 0:
-        base_prompt = REVISED_RAG_PROMPT
-    else:
-        base_prompt = REVISED_RAG_PROMPT_NO_SUB_QUESTIONS
+    # Choose appropriate prompt template
+    base_prompt = (
+        REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS
+        if answered_sub_question_answer_list
+        else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
+    )

     model = graph_config.tooling.fast_llm
     relevant_docs_str = format_docs(relevant_docs)
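The new revision_question_efficiency assignment above is a single chained conditional expression, which Python groups from the right. An equivalent if/elif/else form, shown for illustration only and not part of this commit, is:

    if initial_answered_sub_questions:
        revision_question_efficiency = len(total_answered_questions) / len(
            initial_answered_sub_questions
        )
    elif refined_answered_sub_questions:
        revision_question_efficiency = 10.0
    else:
        revision_question_efficiency = 1.0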
@@ -211,7 +215,7 @@ def generate_refined_answer(
                 answered_sub_questions=remove_document_citations(
                     sub_question_answer_str
                 ),
-                relevant_docs=relevant_docs,
+                relevant_docs=relevant_docs_str,
                 initial_answer=remove_document_citations(initial_answer)
                 if initial_answer
                 else None,
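This change appears to be the prompt fix named in the commit title: the template placeholder now receives the pre-formatted string from format_docs(relevant_docs) rather than the raw list of InferenceSection objects, whose Python repr would otherwise be interpolated into the prompt. A minimal sketch of the difference, with a made-up placeholder name and stand-in values for illustration:

    prompt_template = "Context documents:\n{relevant_docs}"
    sections = ["Document D1: ...", "Document D2: ..."]  # stand-ins for InferenceSection objects
    prompt_template.format(relevant_docs=sections)               # interpolates the list repr
    prompt_template.format(relevant_docs="\n\n".join(sections))  # interpolates readable text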
@@ -221,8 +225,6 @@ def generate_refined_answer(
         )
     ]

-    # Grader
-
     streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
     dispatch_timings: list[float] = []
     for message in model.stream(msg):
@@ -248,7 +250,7 @@ def generate_refined_answer(
         dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
         streamed_tokens.append(content)

-    logger.info(
+    logger.debug(
         f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
     )
     dispatch_main_answer_stop_info(1, writer)

View File

@@ -129,6 +129,9 @@ class OrigQuestionRetrievalUpdate(LoggerUpdate):
     orig_question_retrieved_documents: Annotated[
         list[InferenceSection], dedup_inference_sections
     ]
+    orig_question_verified_reranked_documents: Annotated[
+        list[InferenceSection], dedup_inference_sections
+    ]
     orig_question_sub_query_retrieval_results: list[QueryRetrievalResult] = []
     orig_question_retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()

View File

@@ -7,6 +7,7 @@ from onyx.context.search.models import InferenceSection
 class QuestionRetrievalResult(BaseModel):
     expanded_query_results: list[QueryRetrievalResult] = []
+    retrieved_documents: list[InferenceSection] = []
     verified_reranked_documents: list[InferenceSection] = []
     context_documents: list[InferenceSection] = []
     retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()

View File

@@ -17,7 +17,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
 )
 from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    REWRITE_PROMPT_MULTI_ORIGINAL,
+    QUERY_REWRITING_PROMPT,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
 from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -47,7 +47,7 @@ def expand_queries(
     msg = [
         HumanMessage(
-            content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
+            content=QUERY_REWRITING_PROMPT.format(question=question),
         )
     ]

View File

@@ -82,6 +82,7 @@ def format_results(
     return ExpandedRetrievalUpdate(
         expanded_retrieval_result=QuestionRetrievalResult(
             expanded_query_results=state.query_retrieval_results,
+            retrieved_documents=state.retrieved_documents,
             verified_reranked_documents=reranked_documents,
             context_documents=state.reranked_documents,
             retrieval_stats=sub_question_retrieval_stats,

View File

@@ -13,7 +13,9 @@ from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     trim_prompt_piece,
 )
-from onyx.agents.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.prompts import (
+    DOCUMENT_VERIFICATION_PROMPT,
+)


 def verify_documents(
@@ -38,12 +40,12 @@ def verify_documents(
     fast_llm = graph_config.tooling.fast_llm

     document_content = trim_prompt_piece(
-        fast_llm.config, document_content, VERIFIER_PROMPT + question
+        fast_llm.config, document_content, DOCUMENT_VERIFICATION_PROMPT + question
     )

     msg = [
         HumanMessage(
-            content=VERIFIER_PROMPT.format(
+            content=DOCUMENT_VERIFICATION_PROMPT.format(
                 question=question, document_content=document_content
             )
         )

View File

@@ -60,6 +60,9 @@ class ExpandedRetrievalUpdate(LoggerUpdate, BaseModel):
 class ExpandedRetrievalOutput(LoggerUpdate, BaseModel):
     expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
     base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
+    retrieved_documents: Annotated[
+        list[InferenceSection], dedup_inference_sections
+    ] = []


 ## Graph State

View File

@@ -7,8 +7,8 @@ from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.models import (
     AgentPromptEnrichmentComponents,
 )
-from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
-from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_FRAMING_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_QUESTION_RAG_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_persona_agent_prompt_expressions,
 )
@@ -45,10 +45,12 @@ def build_sub_question_answer_prompt(
     docs_str = "\n\n".join(docs_format_list)

     docs_str = trim_prompt_piece(
-        config, docs_str, BASE_RAG_PROMPT_v2 + question + original_question + date_str
+        config,
+        docs_str,
+        SUB_QUESTION_RAG_PROMPT + question + original_question + date_str,
     )
     human_message = HumanMessage(
-        content=BASE_RAG_PROMPT_v2.format(
+        content=SUB_QUESTION_RAG_PROMPT.format(
             question=question,
             original_question=original_question,
             context=docs_str,
@@ -118,7 +120,7 @@ def build_history_prompt(config: GraphConfig, question: str) -> str:
     if len(history.split()) > AGENT_MAX_STATIC_HISTORY_WORD_LENGTH:
         history = summarize_history(history, question, persona_base, model)

-    return HISTORY_PROMPT.format(history=history) if history else ""
+    return HISTORY_FRAMING_PROMPT.format(history=history) if history else ""


 def get_prompt_enrichment_components(

File diff suppressed because it is too large.

View File

@@ -73,7 +73,7 @@ def format_docs(docs: Sequence[InferenceSection]) -> str:
     formatted_doc_list = []

     for doc_num, doc in enumerate(docs):
-        formatted_doc_list.append(f"Document D{doc_num + 1}:\n{doc.combined_content}")
+        formatted_doc_list.append(f"**Document D{doc_num + 1}:\n{doc.combined_content}")

     return FORMAT_DOCS_SEPARATOR.join(formatted_doc_list)