Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-12 13:59:35 +02:00)

Commit e23dd0a3fa: renames + fix of refined answer generation prompt
Parent: 71304e4228
@@ -12,7 +12,7 @@ from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer
     SubQuestionAnswerCheckUpdate,
 )
 from onyx.agents.agent_search.models import GraphConfig
-from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_ANSWER_CHECK_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
@@ -40,7 +40,7 @@ def check_sub_answer(
     )
     msg = [
         HumanMessage(
-            content=SUB_CHECK_PROMPT.format(
+            content=SUB_ANSWER_CHECK_PROMPT.format(
                 question=state.question,
                 base_answer=state.answer,
             )
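The two hunks above only swap the prompt constant's name (SUB_CHECK_PROMPT becomes SUB_ANSWER_CHECK_PROMPT) at its import and call site. A minimal sketch of how such a rename can be staged without breaking other importers; the prompt text and the alias below are hypothetical and not part of this commit:

```python
# prompts.py (hypothetical sketch, not from this commit)
SUB_ANSWER_CHECK_PROMPT = (
    "Determine whether the answer addresses the question.\n"
    "Question: {question}\nAnswer: {base_answer}"
)

# Deprecated alias so call sites can migrate one at a time;
# delete once nothing imports the old name.
SUB_CHECK_PROMPT = SUB_ANSWER_CHECK_PROMPT
```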
@@ -30,9 +30,11 @@ from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResul
 from onyx.agents.agent_search.shared_graph_utils.operators import (
     dedup_inference_sections,
 )
-from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS,
+    INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS,
+)
+from onyx.agents.agent_search.shared_graph_utils.prompts import (
+    INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
     SUB_QUESTION_ANSWER_TEMPLATE,
@@ -90,7 +92,12 @@ def generate_initial_answer(
         consolidated_context_docs, consolidated_context_docs
     )

-    decomp_questions = []
+    sub_questions: list[str] = []
+    streamed_documents = (
+        relevant_docs
+        if len(relevant_docs) > 0
+        else state.orig_question_retrieved_documents[:15]
+    )

     # Use the query info from the base document retrieval
     query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
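The new streamed_documents value falls back to the first 15 documents retrieved for the original question whenever no sub-question documents were kept as relevant. A small self-contained sketch of that selection rule (the variable names mirror the diff; the data is illustrative stand-ins for InferenceSection objects):

```python
# Illustrative only: strings stand in for InferenceSection objects.
relevant_docs: list[str] = []
orig_question_retrieved_documents = [f"doc-{i}" for i in range(40)]

# Same shape as the diff: prefer sub-question-relevant docs,
# otherwise cap the original-question retrieval at 15 sections.
streamed_documents = (
    relevant_docs
    if len(relevant_docs) > 0
    else orig_question_retrieved_documents[:15]
)

assert len(streamed_documents) == 15
```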
@@ -102,8 +109,8 @@ def generate_initial_answer(
     relevance_list = relevance_from_docs(relevant_docs)
     for tool_response in yield_search_responses(
         query=question,
-        reranked_sections=relevant_docs,
-        final_context_sections=relevant_docs,
+        reranked_sections=streamed_documents,
+        final_context_sections=streamed_documents,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,
         search_tool=graph_config.tooling.search_tool,
@@ -140,35 +147,44 @@ def generate_initial_answer(
         )

     else:
-        decomp_answer_results = state.sub_question_results
+        sub_question_answer_results = state.sub_question_results

-        good_qa_list: list[str] = []
+        answered_sub_questions: list[str] = []
+        all_sub_questions: list[str] = []  # Separate list for tracking all questions

-        sub_question_num = 1
+        for idx, sub_question_answer_result in enumerate(
+            sub_question_answer_results, start=1
+        ):
+            all_sub_questions.append(sub_question_answer_result.question)

-        for decomp_answer_result in decomp_answer_results:
-            decomp_questions.append(decomp_answer_result.question)
-            if (
-                decomp_answer_result.verified_high_quality
-                and len(decomp_answer_result.answer) > 0
-                and decomp_answer_result.answer != UNKNOWN_ANSWER
-            ):
-                good_qa_list.append(
+            is_valid_answer = (
+                sub_question_answer_result.verified_high_quality
+                and sub_question_answer_result.answer
+                and sub_question_answer_result.answer != UNKNOWN_ANSWER
+            )
+
+            if is_valid_answer:
+                answered_sub_questions.append(
                     SUB_QUESTION_ANSWER_TEMPLATE.format(
-                        sub_question=decomp_answer_result.question,
-                        sub_answer=decomp_answer_result.answer,
-                        sub_question_num=sub_question_num,
+                        sub_question=sub_question_answer_result.question,
+                        sub_answer=sub_question_answer_result.answer,
+                        sub_question_num=idx,
                     )
                 )
-            sub_question_num += 1

-        # Determine which base prompt to use given the sub-question information
-        if len(good_qa_list) > 0:
-            sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
-            base_prompt = INITIAL_RAG_PROMPT
-        else:
-            sub_question_answer_str = ""
-            base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS
+        # Use list comprehension for joining answers and determine prompt type
+        sub_question_answer_str = (
+            "\n\n------\n\n".join(answered_sub_questions)
+            if answered_sub_questions
+            else ""
+        )
+        base_prompt = (
+            INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
+            if answered_sub_questions
+            else INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS
+        )
+
+        sub_questions = all_sub_questions  # Replace the original assignment

         model = graph_config.tooling.fast_llm

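The refactor above replaces the manual sub_question_num counter and the nested if with enumerate plus an explicit is_valid_answer flag, and then picks the base prompt with a conditional expression. A compact sketch of the same control flow, using hypothetical stand-ins for the result objects, template, and prompt constants:

```python
from dataclasses import dataclass


# Hypothetical stand-ins for the real state objects and prompt constants.
@dataclass
class SubAnswer:
    question: str
    answer: str
    verified_high_quality: bool


UNKNOWN = "I do not have enough information to answer this question."
TEMPLATE = "Sub-question {num}: {q}\nAnswer: {a}"
PROMPT_WITH_SUBS, PROMPT_WITHOUT_SUBS = "with-subs", "without-subs"

results = [
    SubAnswer("Who uses it?", "Teams A and B.", True),
    SubAnswer("Since when?", UNKNOWN, False),
]

answered: list[str] = []
for idx, res in enumerate(results, start=1):
    # Keep only verified, non-empty answers that are not the "unknown" sentinel.
    is_valid = res.verified_high_quality and res.answer and res.answer != UNKNOWN
    if is_valid:
        answered.append(TEMPLATE.format(num=idx, q=res.question, a=res.answer))

sub_question_answer_str = "\n\n------\n\n".join(answered) if answered else ""
base_prompt = PROMPT_WITH_SUBS if answered else PROMPT_WITHOUT_SUBS
```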
@@ -275,7 +291,7 @@ def generate_initial_answer(
     return InitialAnswerUpdate(
         initial_answer=answer,
         initial_agent_stats=initial_agent_stats,
-        generated_sub_questions=decomp_questions,
+        generated_sub_questions=sub_questions,
         agent_base_end_time=agent_base_end_time,
         agent_base_metrics=agent_base_metrics,
         log_messages=[
@@ -23,10 +23,10 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     build_history_prompt,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    INITIAL_DECOMPOSITION_PROMPT_QUESTIONS,
+    INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
+    INITIAL_QUESTION_DECOMPOSITION_PROMPT,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
 from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -79,7 +79,7 @@ def decompose_orig_question(
         )

     else:
-        decomposition_prompt = INITIAL_DECOMPOSITION_PROMPT_QUESTIONS.format(
+        decomposition_prompt = INITIAL_QUESTION_DECOMPOSITION_PROMPT.format(
             question=question, history=history
         )

@@ -24,8 +24,9 @@ def format_orig_question_search_output(
     sub_question_retrieval_stats = sub_question_retrieval_stats

     return OrigQuestionRetrievalUpdate(
+        orig_question_verified_reranked_documents=state.expanded_retrieval_result.verified_reranked_documents,
         orig_question_sub_query_retrieval_results=state.expanded_retrieval_result.expanded_query_results,
-        orig_question_retrieved_documents=state.expanded_retrieval_result.context_documents,
+        orig_question_retrieved_documents=state.retrieved_documents,
         orig_question_retrieval_stats=sub_question_retrieval_stats,
         log_messages=[],
     )
@@ -1,11 +1,6 @@
-from pydantic import BaseModel
-
 from onyx.agents.agent_search.deep_search.main.states import (
     OrigQuestionRetrievalUpdate,
 )
-from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
-    QuestionRetrievalResult,
-)
 from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
     ExpandedRetrievalInput,
 )
@@ -23,14 +18,14 @@ class BaseRawSearchInput(ExpandedRetrievalInput):
 ## Graph Output State


-class BaseRawSearchOutput(BaseModel):
+class BaseRawSearchOutput(OrigQuestionRetrievalUpdate):
     """
     This is a list of results even though each call of this subgraph only returns one result.
     This is because if we parallelize the answer query subgraph, there will be multiple
     results in a list so the add operator is used to add them together.
     """

-    base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
+    # base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()


 ## Graph State
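Re-parenting BaseRawSearchOutput from BaseModel to OrigQuestionRetrievalUpdate means the subgraph's output already carries the orig_question_* fields the main graph state expects, rather than wrapping them in a separate result object. A minimal Pydantic v2 sketch of the pattern; the field names are abbreviated and purely illustrative:

```python
from pydantic import BaseModel


# Illustrative stand-in for the parent-graph update model.
class OrigQuestionRetrievalUpdate(BaseModel):
    orig_question_retrieved_documents: list[str] = []
    orig_question_retrieval_stats: dict = {}


# After the change: the subgraph output *is* a parent-graph update,
# so its fields merge into the main state without an extra unwrap step.
class BaseRawSearchOutput(OrigQuestionRetrievalUpdate):
    pass


update = BaseRawSearchOutput(orig_question_retrieved_documents=["doc-1"])
print(update.model_dump())
```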
@@ -10,7 +10,9 @@ from onyx.agents.agent_search.deep_search.main.states import (
 )
 from onyx.agents.agent_search.deep_search.main.states import MainState
 from onyx.agents.agent_search.models import GraphConfig
-from onyx.agents.agent_search.shared_graph_utils.prompts import ANSWER_COMPARISON_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.prompts import (
+    INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
+)
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
@@ -28,7 +30,7 @@ def compare_answers(
     initial_answer = state.initial_answer
     refined_answer = state.refined_answer

-    compare_answers_prompt = ANSWER_COMPARISON_PROMPT.format(
+    compare_answers_prompt = INITIAL_REFINED_ANSWER_COMPARISON_PROMPT.format(
         question=question, initial_answer=initial_answer, refined_answer=refined_answer
     )

@@ -21,7 +21,7 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     build_history_prompt,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    DEEP_DECOMPOSE_PROMPT_WITH_ENTITIES,
+    REFINEMENT_QUESTION_DECOMPOSITION_PROMPT,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
 from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -78,7 +78,7 @@ def create_refined_sub_questions(

     msg = [
         HumanMessage(
-            content=DEEP_DECOMPOSE_PROMPT_WITH_ENTITIES.format(
+            content=REFINEMENT_QUESTION_DECOMPOSITION_PROMPT.format(
                 question=question,
                 history=history,
                 entity_term_extraction_str=entity_term_extraction_str,
@@ -18,7 +18,14 @@ from onyx.agents.agent_search.shared_graph_utils.models import EntityExtractionR
 from onyx.agents.agent_search.shared_graph_utils.models import (
     EntityRelationshipTermExtraction,
 )
-from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT
+
+
+from onyx.agents.agent_search.shared_graph_utils.models import Relationship
+from onyx.agents.agent_search.shared_graph_utils.models import Term
+from onyx.agents.agent_search.shared_graph_utils.prompts import (
+    ENTITY_TERM_EXTRACTION_PROMPT,
+)
+
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
@@ -57,11 +64,15 @@ def extract_entities_terms(
     doc_context = format_docs(initial_search_docs)

     doc_context = trim_prompt_piece(
-        graph_config.tooling.fast_llm.config, doc_context, ENTITY_TERM_PROMPT + question
+        graph_config.tooling.fast_llm.config,
+        doc_context,
+        ENTITY_TERM_EXTRACTION_PROMPT + question,
     )
     msg = [
         HumanMessage(
-            content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
+            content=ENTITY_TERM_EXTRACTION_PROMPT.format(
+                question=question, context=doc_context
+            ),
         )
     ]
     fast_llm = graph_config.tooling.fast_llm
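In this call site, as in the others that trim context, the document text is trimmed against the prompt plus the question so that the static prompt is budgeted before the variable-length context. A rough sketch of that reservation idea with a hypothetical character budget; the real trim_prompt_piece works from the LLM config rather than a fixed number:

```python
# Hypothetical helper: keep prompt + context under a fixed character budget.
def trim_piece(budget_chars: int, piece: str, reserved: str) -> str:
    remaining = max(budget_chars - len(reserved), 0)
    return piece[:remaining]


ENTITY_TERM_EXTRACTION_PROMPT = (
    "Extract entities and terms.\nQuestion: {question}\nDocs:\n{context}"
)
question = "What changed in the refined answer flow?"
doc_context = "x" * 10_000  # stand-in for formatted documents

doc_context = trim_piece(
    budget_chars=4_000,
    piece=doc_context,
    reserved=ENTITY_TERM_EXTRACTION_PROMPT + question,
)
print(len(doc_context))  # context shrunk so the full prompt stays within budget
```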
@@ -28,12 +28,14 @@ from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
 from onyx.agents.agent_search.shared_graph_utils.operators import (
     dedup_inference_sections,
 )
-from onyx.agents.agent_search.shared_graph_utils.prompts import REVISED_RAG_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    REVISED_RAG_PROMPT_NO_SUB_QUESTIONS,
+    REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    SUB_QUESTION_ANSWER_TEMPLATE_REVISED,
+    REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS,
+)
+from onyx.agents.agent_search.shared_graph_utils.prompts import (
+    SUB_QUESTION_ANSWER_TEMPLATE_REFINED,
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
 from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -71,12 +73,17 @@ def generate_refined_answer(

     verified_reranked_documents = state.verified_reranked_documents
     sub_questions_cited_documents = state.cited_documents
-    all_original_question_documents = state.orig_question_retrieved_documents
+    original_question_verified_documents = (
+        state.orig_question_verified_reranked_documents
+    )
+    original_question_retrieved_documents = state.orig_question_retrieved_documents

     consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents

     counter = 0
-    for original_doc_number, original_doc in enumerate(all_original_question_documents):
+    for original_doc_number, original_doc in enumerate(
+        original_question_verified_documents
+    ):
         if original_doc_number not in sub_questions_cited_documents:
             if (
                 counter <= AGENT_MIN_ORIG_QUESTION_DOCS
@@ -92,16 +99,22 @@ def generate_refined_answer(
         consolidated_context_docs, consolidated_context_docs
     )

+    streaming_docs = (
+        relevant_docs
+        if len(relevant_docs) > 0
+        else original_question_retrieved_documents[:15]
+    )
+
     query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
     assert (
         graph_config.tooling.search_tool
     ), "search_tool must be provided for agentic search"
-    # stream refined answer docs
+    # stream refined answer docs, or original question docs if no relevant docs are found
     relevance_list = relevance_from_docs(relevant_docs)
     for tool_response in yield_search_responses(
         query=question,
-        reranked_sections=relevant_docs,
-        final_context_sections=relevant_docs,
+        reranked_sections=streaming_docs,
+        final_context_sections=streaming_docs,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,
         search_tool=graph_config.tooling.search_tool,
@@ -124,71 +137,62 @@ def generate_refined_answer(
     else:
         refined_doc_effectiveness = 10.0

-    decomp_answer_results = state.sub_question_results
+    sub_question_answer_results = state.sub_question_results

-    answered_qa_list: list[str] = []
-    decomp_questions = []
+    answered_sub_question_answer_list: list[str] = []
+    sub_questions: list[str] = []
+    initial_answered_sub_questions: set[str] = set()
+    refined_answered_sub_questions: set[str] = set()

-    initial_good_sub_questions: list[str] = []
-    new_revised_good_sub_questions: list[str] = []
+    for i, result in enumerate(sub_question_answer_results, 1):
+        question_level, _ = parse_question_id(result.question_id)
+        sub_questions.append(result.question)

-    sub_question_num = 1
-
-    for decomp_answer_result in decomp_answer_results:
-        question_level, question_num = parse_question_id(
-            decomp_answer_result.question_id
-        )
-
-        decomp_questions.append(decomp_answer_result.question)
         if (
-            decomp_answer_result.verified_high_quality
-            and len(decomp_answer_result.answer) > 0
-            and decomp_answer_result.answer != UNKNOWN_ANSWER
+            result.verified_high_quality
+            and result.answer
+            and result.answer != UNKNOWN_ANSWER
         ):
-            if question_level == 0:
-                initial_good_sub_questions.append(decomp_answer_result.question)
-                sub_question_type = "initial"
-            else:
-                new_revised_good_sub_questions.append(decomp_answer_result.question)
-                sub_question_type = "refined"
-            answered_qa_list.append(
-                SUB_QUESTION_ANSWER_TEMPLATE_REVISED.format(
-                    sub_question=decomp_answer_result.question,
-                    sub_answer=decomp_answer_result.answer,
-                    sub_question_num=sub_question_num,
+            sub_question_type = "initial" if question_level == 0 else "refined"
+            question_set = (
+                initial_answered_sub_questions
+                if question_level == 0
+                else refined_answered_sub_questions
+            )
+            question_set.add(result.question)
+
+            answered_sub_question_answer_list.append(
+                SUB_QUESTION_ANSWER_TEMPLATE_REFINED.format(
+                    sub_question=result.question,
+                    sub_answer=result.answer,
+                    sub_question_num=i,
                     sub_question_type=sub_question_type,
                 )
             )

-            sub_question_num += 1
-
-    initial_good_sub_questions = list(set(initial_good_sub_questions))
-    new_revised_good_sub_questions = list(set(new_revised_good_sub_questions))
-    total_good_sub_questions = list(
-        set(initial_good_sub_questions + new_revised_good_sub_questions)
+    # Calculate efficiency
+    total_answered_questions = (
+        initial_answered_sub_questions | refined_answered_sub_questions
     )
+    revision_question_efficiency = (
+        len(total_answered_questions) / len(initial_answered_sub_questions)
+        if initial_answered_sub_questions
+        else 10.0
+        if refined_answered_sub_questions
+        else 1.0
+    )
-    if len(initial_good_sub_questions) > 0:
-        revision_question_efficiency: float = len(total_good_sub_questions) / len(
-            initial_good_sub_questions
-        )
-    elif len(new_revised_good_sub_questions) > 0:
-        revision_question_efficiency = 10.0
-    else:
-        revision_question_efficiency = 1.0

-    sub_question_answer_str = "\n\n------\n\n".join(list(set(answered_qa_list)))
-
-    # original answer
-
+    sub_question_answer_str = "\n\n------\n\n".join(
+        set(answered_sub_question_answer_list)
+    )
     initial_answer = state.initial_answer or ""

-    # Determine which persona-specification prompt to use
-
-    # Determine which base prompt to use given the sub-question information
-    if len(answered_qa_list) > 0:
-        base_prompt = REVISED_RAG_PROMPT
-    else:
-        base_prompt = REVISED_RAG_PROMPT_NO_SUB_QUESTIONS
+    # Choose appropriate prompt template
+    base_prompt = (
+        REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS
+        if answered_sub_question_answer_list
+        else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
+    )

     model = graph_config.tooling.fast_llm
+    relevant_docs_str = format_docs(relevant_docs)
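The rewritten efficiency metric collects answered sub-questions into two sets (initial pass vs. refined pass) and then uses a chained conditional expression: the ratio of all answered questions to initially answered ones when the initial pass answered anything, 10.0 when only the refined pass did, and 1.0 when neither did. A worked sketch with illustrative data:

```python
# Illustrative data: which sub-questions were answered in each pass.
initial_answered_sub_questions = {"Who uses it?", "Since when?"}
refined_answered_sub_questions = {"Who uses it?", "What are the limits?"}

total_answered_questions = (
    initial_answered_sub_questions | refined_answered_sub_questions
)
revision_question_efficiency = (
    len(total_answered_questions) / len(initial_answered_sub_questions)
    if initial_answered_sub_questions
    else 10.0
    if refined_answered_sub_questions
    else 1.0
)

print(revision_question_efficiency)  # 3 unique answers / 2 initial answers = 1.5
```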
@@ -211,7 +215,7 @@ def generate_refined_answer(
             answered_sub_questions=remove_document_citations(
                 sub_question_answer_str
             ),
-            relevant_docs=relevant_docs,
+            relevant_docs=relevant_docs_str,
             initial_answer=remove_document_citations(initial_answer)
             if initial_answer
             else None,
@@ -221,8 +225,6 @@ def generate_refined_answer(
         )
     ]

-    # Grader
-
     streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
     dispatch_timings: list[float] = []
     for message in model.stream(msg):
@@ -248,7 +250,7 @@ def generate_refined_answer(
         dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
         streamed_tokens.append(content)

-    logger.info(
+    logger.debug(
         f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
     )
     dispatch_main_answer_stop_info(1, writer)
@@ -129,6 +129,9 @@ class OrigQuestionRetrievalUpdate(LoggerUpdate):
     orig_question_retrieved_documents: Annotated[
         list[InferenceSection], dedup_inference_sections
     ]
+    orig_question_verified_reranked_documents: Annotated[
+        list[InferenceSection], dedup_inference_sections
+    ]
     orig_question_sub_query_retrieval_results: list[QueryRetrievalResult] = []
     orig_question_retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()

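The new orig_question_verified_reranked_documents field follows the same pattern as the existing one: an Annotated list whose metadata names the reducer LangGraph applies when several nodes write to the same state key. A minimal sketch of how such an annotated reducer behaves, with a toy dedup function standing in for dedup_inference_sections:

```python
from typing import Annotated


# Toy stand-in for dedup_inference_sections: merge two lists, drop duplicates,
# keep first-seen order.
def dedup(existing: list[str], new: list[str]) -> list[str]:
    return list(dict.fromkeys(existing + new))


# In a LangGraph state class, the Annotated metadata is used as the reducer
# that merges values returned for this key by different nodes.
DocsField = Annotated[list[str], dedup]

print(dedup(["doc-1", "doc-2"], ["doc-2", "doc-3"]))  # ['doc-1', 'doc-2', 'doc-3']
```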
@@ -7,6 +7,7 @@ from onyx.context.search.models import InferenceSection

 class QuestionRetrievalResult(BaseModel):
     expanded_query_results: list[QueryRetrievalResult] = []
+    retrieved_documents: list[InferenceSection] = []
     verified_reranked_documents: list[InferenceSection] = []
     context_documents: list[InferenceSection] = []
     retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()
@@ -17,7 +17,7 @@ from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states impor
 )
 from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.prompts import (
-    REWRITE_PROMPT_MULTI_ORIGINAL,
+    QUERY_REWRITING_PROMPT,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
 from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -47,7 +47,7 @@ def expand_queries(

     msg = [
         HumanMessage(
-            content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
+            content=QUERY_REWRITING_PROMPT.format(question=question),
         )
     ]

@@ -82,6 +82,7 @@ def format_results(
     return ExpandedRetrievalUpdate(
         expanded_retrieval_result=QuestionRetrievalResult(
             expanded_query_results=state.query_retrieval_results,
+            retrieved_documents=state.retrieved_documents,
             verified_reranked_documents=reranked_documents,
             context_documents=state.reranked_documents,
             retrieval_stats=sub_question_retrieval_stats,
@@ -13,7 +13,9 @@ from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     trim_prompt_piece,
 )
-from onyx.agents.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
+from onyx.agents.agent_search.shared_graph_utils.prompts import (
+    DOCUMENT_VERIFICATION_PROMPT,
+)


 def verify_documents(
@@ -38,12 +40,12 @@ def verify_documents(
     fast_llm = graph_config.tooling.fast_llm

     document_content = trim_prompt_piece(
-        fast_llm.config, document_content, VERIFIER_PROMPT + question
+        fast_llm.config, document_content, DOCUMENT_VERIFICATION_PROMPT + question
     )

     msg = [
         HumanMessage(
-            content=VERIFIER_PROMPT.format(
+            content=DOCUMENT_VERIFICATION_PROMPT.format(
                 question=question, document_content=document_content
             )
         )
@@ -60,6 +60,9 @@ class ExpandedRetrievalUpdate(LoggerUpdate, BaseModel):
 class ExpandedRetrievalOutput(LoggerUpdate, BaseModel):
     expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
     base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
+    retrieved_documents: Annotated[
+        list[InferenceSection], dedup_inference_sections
+    ] = []


 ## Graph State
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import (
|
||||
AgentPromptEnrichmentComponents,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_FRAMING_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_QUESTION_RAG_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_persona_agent_prompt_expressions,
|
||||
)
|
||||
@ -45,10 +45,12 @@ def build_sub_question_answer_prompt(
|
||||
docs_str = "\n\n".join(docs_format_list)
|
||||
|
||||
docs_str = trim_prompt_piece(
|
||||
config, docs_str, BASE_RAG_PROMPT_v2 + question + original_question + date_str
|
||||
config,
|
||||
docs_str,
|
||||
SUB_QUESTION_RAG_PROMPT + question + original_question + date_str,
|
||||
)
|
||||
human_message = HumanMessage(
|
||||
content=BASE_RAG_PROMPT_v2.format(
|
||||
content=SUB_QUESTION_RAG_PROMPT.format(
|
||||
question=question,
|
||||
original_question=original_question,
|
||||
context=docs_str,
|
||||
@ -118,7 +120,7 @@ def build_history_prompt(config: GraphConfig, question: str) -> str:
|
||||
if len(history.split()) > AGENT_MAX_STATIC_HISTORY_WORD_LENGTH:
|
||||
history = summarize_history(history, question, persona_base, model)
|
||||
|
||||
return HISTORY_PROMPT.format(history=history) if history else ""
|
||||
return HISTORY_FRAMING_PROMPT.format(history=history) if history else ""
|
||||
|
||||
|
||||
def get_prompt_enrichment_components(
|
||||
|
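build_history_prompt keeps its behavior: when the chat history exceeds the static word limit it is first summarized, and the (possibly summarized) history is then wrapped in the framing prompt, now named HISTORY_FRAMING_PROMPT. A condensed sketch of that flow with hypothetical limits and a stub summarizer in place of the LLM call:

```python
# Hypothetical constants and stubs; the real code summarizes with an LLM.
AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = 8
HISTORY_FRAMING_PROMPT = "Here is the conversation so far:\n{history}\n"


def summarize_history(history: str) -> str:
    # Stub: keep the first few words instead of asking a model to summarize.
    return " ".join(history.split()[:AGENT_MAX_STATIC_HISTORY_WORD_LENGTH]) + " ..."


def build_history_prompt(history: str) -> str:
    if len(history.split()) > AGENT_MAX_STATIC_HISTORY_WORD_LENGTH:
        history = summarize_history(history)
    return HISTORY_FRAMING_PROMPT.format(history=history) if history else ""


print(build_history_prompt("user asked about renaming prompts across the agent search graph nodes"))
```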
File diff suppressed because it is too large
@@ -73,7 +73,7 @@ def format_docs(docs: Sequence[InferenceSection]) -> str:
     formatted_doc_list = []

     for doc_num, doc in enumerate(docs):
-        formatted_doc_list.append(f"Document D{doc_num + 1}:\n{doc.combined_content}")
+        formatted_doc_list.append(f"**Document D{doc_num + 1}:\n{doc.combined_content}")

     return FORMAT_DOCS_SEPARATOR.join(formatted_doc_list)
