mirror of https://github.com/danswer-ai/danswer.git (synced 2025-06-15 02:20:52 +02:00)

Commit e23dd0a3fa (parent 71304e4228): renames + fix of refined answer generation prompt
@@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
     SubQuestionAnswerCheckUpdate,
 )
 from onyx.agents.agent_search.models import GraphConfig
-from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_ANSWER_CHECK_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
@@ -40,7 +40,7 @@ def check_sub_answer(
     )
     msg = [
         HumanMessage(
-            content=SUB_CHECK_PROMPT.format(
+            content=SUB_ANSWER_CHECK_PROMPT.format(
                 question=state.question,
                 base_answer=state.answer,
             )
@@ -30,9 +30,11 @@ from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResul
 from onyx.agents.agent_search.shared_graph_utils.operators import (
     dedup_inference_sections,
 )
-from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS,
+    INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS,
+)
+from onyx.agents.agent_search.shared_graph_utils.prompts import (
+    INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
     SUB_QUESTION_ANSWER_TEMPLATE,
@@ -90,7 +92,12 @@ def generate_initial_answer(
         consolidated_context_docs, consolidated_context_docs
     )

-    decomp_questions = []
+    sub_questions: list[str] = []
+    streamed_documents = (
+        relevant_docs
+        if len(relevant_docs) > 0
+        else state.orig_question_retrieved_documents[:15]
+    )

     # Use the query info from the base document retrieval
     query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
@@ -102,8 +109,8 @@ def generate_initial_answer(
     relevance_list = relevance_from_docs(relevant_docs)
     for tool_response in yield_search_responses(
         query=question,
-        reranked_sections=relevant_docs,
-        final_context_sections=relevant_docs,
+        reranked_sections=streamed_documents,
+        final_context_sections=streamed_documents,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,
         search_tool=graph_config.tooling.search_tool,
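Note on the streamed_documents fallback introduced above: when no relevant documents survive verification, the node now streams the first 15 documents retrieved for the original question instead of an empty list. A minimal standalone sketch of that selection (the helper name is illustrative; only the fallback rule and the 15-document cap come from the diff):

def pick_documents_to_stream(relevant_docs: list, orig_question_docs: list, cap: int = 15) -> list:
    # Prefer documents judged relevant to the sub-questions; otherwise fall back
    # to the documents retrieved for the original question, capped for streaming.
    return relevant_docs if len(relevant_docs) > 0 else orig_question_docs[:cap]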
@@ -140,35 +147,44 @@ def generate_initial_answer(
         )

     else:
-        decomp_answer_results = state.sub_question_results
+        sub_question_answer_results = state.sub_question_results

-        good_qa_list: list[str] = []
+        answered_sub_questions: list[str] = []
+        all_sub_questions: list[str] = []  # Separate list for tracking all questions

-        sub_question_num = 1
-
-        for decomp_answer_result in decomp_answer_results:
-            decomp_questions.append(decomp_answer_result.question)
-            if (
-                decomp_answer_result.verified_high_quality
-                and len(decomp_answer_result.answer) > 0
-                and decomp_answer_result.answer != UNKNOWN_ANSWER
-            ):
-                good_qa_list.append(
+        for idx, sub_question_answer_result in enumerate(
+            sub_question_answer_results, start=1
+        ):
+            all_sub_questions.append(sub_question_answer_result.question)
+
+            is_valid_answer = (
+                sub_question_answer_result.verified_high_quality
+                and sub_question_answer_result.answer
+                and sub_question_answer_result.answer != UNKNOWN_ANSWER
+            )
+
+            if is_valid_answer:
+                answered_sub_questions.append(
                     SUB_QUESTION_ANSWER_TEMPLATE.format(
-                        sub_question=decomp_answer_result.question,
-                        sub_answer=decomp_answer_result.answer,
-                        sub_question_num=sub_question_num,
+                        sub_question=sub_question_answer_result.question,
+                        sub_answer=sub_question_answer_result.answer,
+                        sub_question_num=idx,
                     )
                 )
-                sub_question_num += 1

-        # Determine which base prompt to use given the sub-question information
-        if len(good_qa_list) > 0:
-            sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
-            base_prompt = INITIAL_RAG_PROMPT
-        else:
-            sub_question_answer_str = ""
-            base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS
+        # Use list comprehension for joining answers and determine prompt type
+        sub_question_answer_str = (
+            "\n\n------\n\n".join(answered_sub_questions)
+            if answered_sub_questions
+            else ""
+        )
+        base_prompt = (
+            INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
+            if answered_sub_questions
+            else INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS
+        )
+
+        sub_questions = all_sub_questions  # Replace the original assignment

         model = graph_config.tooling.fast_llm
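The hunk above replaces the accumulate-and-count loop with an explicit is_valid_answer check and conditional expressions for both the joined answer string and the prompt choice. A rough, self-contained sketch of the same selection logic (the SubQuestionResult dataclass, the placeholder prompt strings, and the inline answer template are illustrative stand-ins for the imported objects):

from dataclasses import dataclass

UNKNOWN_ANSWER = "unknown"  # placeholder for the real UNKNOWN_ANSWER sentinel
PROMPT_WITH_SUB_QUESTIONS = "with-sub-questions prompt"  # placeholder template
PROMPT_WITHOUT_SUB_QUESTIONS = "without-sub-questions prompt"  # placeholder template

@dataclass
class SubQuestionResult:
    question: str
    answer: str
    verified_high_quality: bool

def select_prompt_and_answer_str(results: list[SubQuestionResult]) -> tuple[str, str]:
    # Keep only sub-answers that are verified, non-empty, and not the "unknown" sentinel.
    answered = [
        f"Sub-question {i}: {r.question}\nSub-answer: {r.answer}"
        for i, r in enumerate(results, start=1)
        if r.verified_high_quality and r.answer and r.answer != UNKNOWN_ANSWER
    ]
    sub_question_answer_str = "\n\n------\n\n".join(answered) if answered else ""
    base_prompt = (
        PROMPT_WITH_SUB_QUESTIONS if answered else PROMPT_WITHOUT_SUB_QUESTIONS
    )
    return base_prompt, sub_question_answer_str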
@@ -275,7 +291,7 @@ def generate_initial_answer(
     return InitialAnswerUpdate(
         initial_answer=answer,
         initial_agent_stats=initial_agent_stats,
-        generated_sub_questions=decomp_questions,
+        generated_sub_questions=sub_questions,
         agent_base_end_time=agent_base_end_time,
         agent_base_metrics=agent_base_metrics,
         log_messages=[
@@ -23,10 +23,10 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     build_history_prompt,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    INITIAL_DECOMPOSITION_PROMPT_QUESTIONS,
+    INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
+    INITIAL_QUESTION_DECOMPOSITION_PROMPT,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
 from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -79,7 +79,7 @@ def decompose_orig_question(
         )

     else:
-        decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS.format(
+        decomposition_prompt = INITIAL_QUESTION_DECOMPOSITION_PROMPT.format(
             question=question, history=history
         )

@@ -24,8 +24,9 @@ def format_orig_question_search_output(
     sub_question_retrieval_stats = sub_question_retrieval_stats

     return OrigQuestionRetrievalUpdate(
+        orig_question_verified_reranked_documents=state.expanded_retrieval_result.verified_reranked_documents,
         orig_question_sub_query_retrieval_results=state.expanded_retrieval_result.expanded_query_results,
-        orig_question_retrieved_documents=state.expanded_retrieval_result.context_documents,
+        orig_question_retrieved_documents=state.retrieved_documents,
         orig_question_retrieval_stats=sub_question_retrieval_stats,
         log_messages=[],
     )
@@ -1,11 +1,6 @@
-from pydantic import BaseModel
-
 from onyx.agents.agent_search.deep_search.main.states import (
     OrigQuestionRetrievalUpdate,
 )
-from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
-    QuestionRetrievalResult,
-)
 from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
     ExpandedRetrievalInput,
 )
@@ -23,14 +18,14 @@ class BaseRawSearchInput(ExpandedRetrievalInput):
 ## Graph Output State


-class BaseRawSearchOutput(BaseModel):
+class BaseRawSearchOutput(OrigQuestionRetrievalUpdate):
     """
     This is a list of results even though each call of this subgraph only returns one result.
     This is because if we parallelize the answer query subgraph, there will be multiple
     results in a list so the add operator is used to add them together.
     """

-    base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
+    # base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()


 ## Graph State
@@ -10,7 +10,9 @@ from onyx.agents.agent_search.deep_search.main.states import (
 )
 from onyx.agents.agent_search.deep_search.main.states import MainState
 from onyx.agents.agent_search.models import GraphConfig
-from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.prompts import (
+    INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
+)
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
@@ -28,7 +30,7 @@ def compare_answers(
     initial_answer = state.initial_answer
     refined_answer = state.refined_answer

-    compare_answers_prompt = ANSWER_COMPARISON_PROMPT.format(
+    compare_answers_prompt = INITIAL_REFINED_ANSWER_COMPARISON_PROMPT.format(
         question=question, initial_answer=initial_answer, refined_answer=refined_answer
     )

@@ -21,7 +21,7 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     build_history_prompt,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    DEEP_DECOMPOSE_PROMPT_WITH_ENTITIES,
+    REFINEMENT_QUESTION_DECOMPOSITION_PROMPT,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
 from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -78,7 +78,7 @@ def create_refined_sub_questions(

     msg = [
         HumanMessage(
-            content=DEEP_DECOMPOSE_PROMPT_WITH_ENTITIES.format(
+            content=REFINEMENT_QUESTION_DECOMPOSITION_PROMPT.format(
                 question=question,
                 history=history,
                 entity_term_extraction_str=entity_term_extraction_str,
@@ -18,7 +18,14 @@ from onyx.agents.agent_search.shared_graph_utils.models import EntityExtractionR
 from onyx.agents.agent_search.shared_graph_utils.models import (
     EntityRelationshipTermExtraction,
 )
-from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.models import Relationship
+from onyx.agents.agent_search.shared_graph_utils.models import Term
+from onyx.agents.agent_search.shared_graph_utils.prompts import (
+    ENTITY_TERM_EXTRACTION_PROMPT,
+)
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
@@ -57,11 +64,15 @@ def extract_entities_terms(
     doc_context = format_docs(initial_search_docs)

     doc_context = trim_prompt_piece(
-        graph_config.tooling.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question
+        graph_config.tooling.fast_llm.config,
+        doc_context,
+        ENTITY_TERM_EXTRACTION_PROMPT + question,
     )
     msg = [
         HumanMessage(
-            content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
+            content=ENTITY_TERM_EXTRACTION_PROMPT.format(
+                question=question, context=doc_context
+            ),
         )
     ]
     fast_llm = graph_config.tooling.fast_llm
@@ -28,12 +28,14 @@ from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
 from onyx.agents.agent_search.shared_graph_utils.operators import (
     dedup_inference_sections,
 )
-from onyx.agents.agent_search.shared_graph_utils.prompts import REVISED_RAG_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    REVISED_RAG_PROMPT_NO_SUB_QUESTIONS,
+    REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    SUB_QUESTION_ANSWER_TEMPLATE_REVISED,
+    REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS,
+)
+from onyx.agents.agent_search.shared_graph_utils.prompts import (
+    SUB_QUESTION_ANSWER_TEMPLATE_REFINED,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
 from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -71,12 +73,17 @@ def generate_refined_answer(

     verified_reranked_documents = state.verified_reranked_documents
     sub_questions_cited_documents = state.cited_documents
-    all_original_question_documents = state.orig_question_retrieved_documents
+    original_question_verified_documents = (
+        state.orig_question_verified_reranked_documents
+    )
+    original_question_retrieved_documents = state.orig_question_retrieved_documents

     consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents

     counter = 0
-    for original_doc_number, original_doc in enumerate(all_original_question_documents):
+    for original_doc_number, original_doc in enumerate(
+        original_question_verified_documents
+    ):
         if original_doc_number not in sub_questions_cited_documents:
             if (
                 counter <= AGENT_MIN_ORIG_QUESTION_DOCS
@@ -92,16 +99,22 @@ def generate_refined_answer(
         consolidated_context_docs, consolidated_context_docs
     )

+    streaming_docs = (
+        relevant_docs
+        if len(relevant_docs) > 0
+        else original_question_retrieved_documents[:15]
+    )
+
     query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
     assert (
         graph_config.tooling.search_tool
     ), "search_tool must be provided for agentic search"
-    # stream refined answer docs
+    # stream refined answer docs, or original question docs if no relevant docs are found
     relevance_list = relevance_from_docs(relevant_docs)
     for tool_response in yield_search_responses(
         query=question,
-        reranked_sections=relevant_docs,
-        final_context_sections=relevant_docs,
+        reranked_sections=streaming_docs,
+        final_context_sections=streaming_docs,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,
         search_tool=graph_config.tooling.search_tool,
@@ -124,71 +137,62 @@ def generate_refined_answer(
     else:
         refined_doc_effectiveness = 10.0

-    decomp_answer_results = state.sub_question_results
+    sub_question_answer_results = state.sub_question_results

-    answered_qa_list: list[str] = []
-    decomp_questions = []
+    answered_sub_question_answer_list: list[str] = []
+    sub_questions: list[str] = []
+    initial_answered_sub_questions: set[str] = set()
+    refined_answered_sub_questions: set[str] = set()

-    initial_good_sub_questions: list[str] = []
-    new_revised_good_sub_questions: list[str] = []
-
-    sub_question_num = 1
-
-    for decomp_answer_result in decomp_answer_results:
-        question_level, question_num = parse_question_id(
-            decomp_answer_result.question_id
-        )
-
-        decomp_questions.append(decomp_answer_result.question)
+    for i, result in enumerate(sub_question_answer_results, 1):
+        question_level, _ = parse_question_id(result.question_id)
+        sub_questions.append(result.question)
+
         if (
-            decomp_answer_result.verified_high_quality
-            and len(decomp_answer_result.answer) > 0
-            and decomp_answer_result.answer != UNKNOWN_ANSWER
+            result.verified_high_quality
+            and result.answer
+            and result.answer != UNKNOWN_ANSWER
         ):
-            if question_level == 0:
-                initial_good_sub_questions.append(decomp_answer_result.question)
-                sub_question_type = "initial"
-            else:
-                new_revised_good_sub_questions.append(decomp_answer_result.question)
-                sub_question_type = "refined"
-            answered_qa_list.append(
-                SUB_QUESTION_ANSWER_TEMPLATE_REVISED.format(
-                    sub_question=decomp_answer_result.question,
-                    sub_answer=decomp_answer_result.answer,
-                    sub_question_num=sub_question_num,
+            sub_question_type = "initial" if question_level == 0 else "refined"
+            question_set = (
+                initial_answered_sub_questions
+                if question_level == 0
+                else refined_answered_sub_questions
+            )
+            question_set.add(result.question)
+
+            answered_sub_question_answer_list.append(
+                SUB_QUESTION_ANSWER_TEMPLATE_REFINED.format(
+                    sub_question=result.question,
+                    sub_answer=result.answer,
+                    sub_question_num=i,
                     sub_question_type=sub_question_type,
                 )
             )

-            sub_question_num += 1
-
-    initial_good_sub_questions = list(set(initial_good_sub_questions))
-    new_revised_good_sub_questions = list(set(new_revised_good_sub_questions))
-    total_good_sub_questions = list(
-        set(initial_good_sub_questions + new_revised_good_sub_questions)
+    # Calculate efficiency
+    total_answered_questions = (
+        initial_answered_sub_questions | refined_answered_sub_questions
+    )
+    revision_question_efficiency = (
+        len(total_answered_questions) / len(initial_answered_sub_questions)
+        if initial_answered_sub_questions
+        else 10.0
+        if refined_answered_sub_questions
+        else 1.0
     )
-    if len(initial_good_sub_questions) > 0:
-        revision_question_efficiency: float = len(total_good_sub_questions) / len(
-            initial_good_sub_questions
-        )
-    elif len(new_revised_good_sub_questions) > 0:
-        revision_question_efficiency = 10.0
-    else:
-        revision_question_efficiency = 1.0

-    sub_question_answer_str = "\n\n------\n\n".join(list(set(answered_qa_list)))
-
-    # original answer
-
+    sub_question_answer_str = "\n\n------\n\n".join(
+        set(answered_sub_question_answer_list)
+    )
     initial_answer = state.initial_answer or ""

-    # Determine which persona-specification prompt to use
-
-    # Determine which base prompt to use given the sub-question information
-    if len(answered_qa_list) > 0:
-        base_prompt = REVISED_RAG_PROMPT
-    else:
-        base_prompt = REVISED_RAG_PROMPT_NO_SUB_QUESTIONS
+    # Choose appropriate prompt template
+    base_prompt = (
+        REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS
+        if answered_sub_question_answer_list
+        else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
+    )

     model = graph_config.tooling.fast_llm
     relevant_docs_str = format_docs(relevant_docs)
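The refactored efficiency computation above folds the old if/elif/else into a single conditional expression; note that Python groups "a if p else b if q else c" as "a if p else (b if q else c)", so the behavior matches the previous branches. A small sketch of the metric with shortened names (illustrative only):

def revision_question_efficiency(initial_answered: set[str], refined_answered: set[str]) -> float:
    # Ratio of all answered sub-questions to those answered in the initial pass;
    # 10.0 if only the refined pass produced answers, 1.0 if neither pass did.
    total_answered = initial_answered | refined_answered
    return (
        len(total_answered) / len(initial_answered)
        if initial_answered
        else 10.0
        if refined_answered
        else 1.0
    )

# Example: 3 initial + 2 new refined answered sub-questions -> 5 / 3 ≈ 1.67
print(revision_question_efficiency({"q1", "q2", "q3"}, {"q4", "q5"}))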
@@ -211,7 +215,7 @@ def generate_refined_answer(
                 answered_sub_questions=remove_document_citations(
                     sub_question_answer_str
                 ),
-                relevant_docs=relevant_docs,
+                relevant_docs=relevant_docs_str,
                 initial_answer=remove_document_citations(initial_answer)
                 if initial_answer
                 else None,
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Grader
|
|
||||||
|
|
||||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||||
dispatch_timings: list[float] = []
|
dispatch_timings: list[float] = []
|
||||||
for message in model.stream(msg):
|
for message in model.stream(msg):
|
||||||
@@ -248,7 +250,7 @@ def generate_refined_answer(
         dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
         streamed_tokens.append(content)

-    logger.info(
+    logger.debug(
         f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
     )
     dispatch_main_answer_stop_info(1, writer)
@@ -129,6 +129,9 @@ class OrigQuestionRetrievalUpdate(LoggerUpdate):
     orig_question_retrieved_documents: Annotated[
         list[InferenceSection], dedup_inference_sections
     ]
+    orig_question_verified_reranked_documents: Annotated[
+        list[InferenceSection], dedup_inference_sections
+    ]
     orig_question_sub_query_retrieval_results: list[QueryRetrievalResult] = []
     orig_question_retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()

@@ -7,6 +7,7 @@ from onyx.context.search.models import InferenceSection

 class QuestionRetrievalResult(BaseModel):
     expanded_query_results: list[QueryRetrievalResult] = []
+    retrieved_documents: list[InferenceSection] = []
     verified_reranked_documents: list[InferenceSection] = []
     context_documents: list[InferenceSection] = []
     retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()
@@ -17,7 +17,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
 )
 from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    REWRITE_PROMPT_MULTI_ORIGINAL,
+    QUERY_REWRITING_PROMPT,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
 from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -47,7 +47,7 @@ def expand_queries(

     msg = [
         HumanMessage(
-            content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
+            content=QUERY_REWRITING_PROMPT.format(question=question),
         )
     ]

@@ -82,6 +82,7 @@ def format_results(
     return ExpandedRetrievalUpdate(
         expanded_retrieval_result=QuestionRetrievalResult(
             expanded_query_results=state.query_retrieval_results,
+            retrieved_documents=state.retrieved_documents,
             verified_reranked_documents=reranked_documents,
             context_documents=state.reranked_documents,
             retrieval_stats=sub_question_retrieval_stats,
@@ -13,7 +13,9 @@ from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     trim_prompt_piece,
 )
-from onyx.agents.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.prompts import (
+    DOCUMENT_VERIFICATION_PROMPT,
+)


 def verify_documents(
@@ -38,12 +40,12 @@ def verify_documents(
     fast_llm = graph_config.tooling.fast_llm

     document_content = trim_prompt_piece(
-        fast_llm.config, document_content, VERIFIER_PROMPT + question
+        fast_llm.config, document_content, DOCUMENT_VERIFICATION_PROMPT + question
    )

     msg = [
         HumanMessage(
-            content=VERIFIER_PROMPT.format(
+            content=DOCUMENT_VERIFICATION_PROMPT.format(
                 question=question, document_content=document_content
             )
         )
@@ -60,6 +60,9 @@ class ExpandedRetrievalUpdate(LoggerUpdate, BaseModel):
 class ExpandedRetrievalOutput(LoggerUpdate, BaseModel):
     expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
     base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
+    retrieved_documents: Annotated[
+        list[InferenceSection], dedup_inference_sections
+    ] = []


 ## Graph State
@@ -7,8 +7,8 @@ from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.models import (
     AgentPromptEnrichmentComponents,
 )
-from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
-from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_FRAMING_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_QUESTION_RAG_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_persona_agent_prompt_expressions,
 )
@@ -45,10 +45,12 @@ def build_sub_question_answer_prompt(
     docs_str = "\n\n".join(docs_format_list)

     docs_str = trim_prompt_piece(
-        config, docs_str, BASE_RAG_PROMPT_v2 + question + original_question + date_str
+        config,
+        docs_str,
+        SUB_QUESTION_RAG_PROMPT + question + original_question + date_str,
     )
     human_message = HumanMessage(
-        content=BASE_RAG_PROMPT_v2.format(
+        content=SUB_QUESTION_RAG_PROMPT.format(
             question=question,
             original_question=original_question,
             context=docs_str,
@@ -118,7 +120,7 @@ def build_history_prompt(config: GraphConfig, question: str) -> str:
         if len(history.split()) > AGENT_MAX_STATIC_HISTORY_WORD_LENGTH:
             history = summarize_history(history, question, persona_base, model)

-    return HISTORY_PROMPT.format(history=history) if history else ""
+    return HISTORY_FRAMING_PROMPT.format(history=history) if history else ""


 def get_prompt_enrichment_components(
(File diff suppressed because it is too large.)
@@ -73,7 +73,7 @@ def format_docs(docs: Sequence[InferenceSection]) -> str:
     formatted_doc_list = []

     for doc_num, doc in enumerate(docs):
-        formatted_doc_list.append(f"Document D{doc_num + 1}:\n{doc.combined_content}")
+        formatted_doc_list.append(f"**Document D{doc_num + 1}:\n{doc.combined_content}")

     return FORMAT_DOCS_SEPARATOR.join(formatted_doc_list)