refined search + question answering as sub-graphs

joachim-danswer 2025-01-26 12:03:28 -08:00 committed by Evan Lohn
parent 4baf3dc484
commit 8c9577aa95
13 changed files with 391 additions and 125 deletions

View File

@@ -0,0 +1,55 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionInput,
)
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search_a.initial_search_sq_subgraph.states import (
SearchSQState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
def parallelize_initial_sub_question_answering(
state: SearchSQState,
) -> list[Send | Hashable]:
now_start = datetime.now()
if len(state.initial_decomp_questions) > 0:
# sub_question_record_ids = [subq_record.id for subq_record in state["sub_question_records"]]
# if len(state["sub_question_records"]) == 0:
# if state["config"].use_persistence:
# raise ValueError("No sub-questions found for initial decomposed questions")
# else:
# # in this case, we are doing retrieval on the original question.
# # to make all the logic consistent, we create a new sub-question
# # with the same content as the original question
# sub_question_record_ids = [1] * len(state["initial_decomp_questions"])
return [
Send(
"answer_query_subgraph",
AnswerQuestionInput(
question=question,
question_id=make_question_id(0, question_nr + 1),
log_messages=[
f"{now_start} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_nr, question in enumerate(state.initial_decomp_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]
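
For readers unfamiliar with LangGraph's Send-based fan-out, the edge function above follows the standard map pattern: a conditional edge returns one Send per sub-question, and each Send dispatches its own input payload to the target node. A minimal, self-contained sketch of that pattern (the toy state, node, and payloads below are illustrative, not the Onyx types):

from operator import add
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class ToyState(TypedDict):
    questions: list[str]
    answers: Annotated[list[str], add]  # reducer concatenates outputs from parallel branches


def fan_out(state: ToyState) -> list[Send]:
    # One Send per question; each carries its own branch input.
    return [
        Send("answer_one", {"questions": [q], "answers": []})
        for q in state["questions"]
    ]


def answer_one(state: ToyState) -> dict:
    return {"answers": [f"answered: {state['questions'][0]}"]}


builder = StateGraph(ToyState)
builder.add_node("answer_one", answer_one)
builder.add_conditional_edges(START, fan_out, ["answer_one"])
builder.add_edge("answer_one", END)
graph = builder.compile()
# graph.invoke({"questions": ["q1", "q2"], "answers": []}) collects answers from both branches.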

View File

@@ -0,0 +1,170 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.graph_builder import (
answer_query_graph_builder,
)
from onyx.agents.agent_search.deep_search_a.base_raw_search.graph_builder import (
base_raw_search_graph_builder,
)
from onyx.agents.agent_search.deep_search_a.initial_search_sq_subgraph.edges import (
parallelize_initial_sub_question_answering,
)
from onyx.agents.agent_search.deep_search_a.initial_search_sq_subgraph.nodes.generate_initial_answer import (
generate_initial_answer,
)
from onyx.agents.agent_search.deep_search_a.initial_search_sq_subgraph.nodes.ingest_initial_base_retrieval import (
ingest_initial_base_retrieval,
)
from onyx.agents.agent_search.deep_search_a.initial_search_sq_subgraph.nodes.ingest_initial_sub_question_answers import (
ingest_initial_sub_question_answers,
)
from onyx.agents.agent_search.deep_search_a.initial_search_sq_subgraph.nodes.initial_answer_quality_check import (
initial_answer_quality_check,
)
from onyx.agents.agent_search.deep_search_a.initial_search_sq_subgraph.nodes.initial_sub_question_creation import (
initial_sub_question_creation,
)
from onyx.agents.agent_search.deep_search_a.initial_search_sq_subgraph.nodes.retrieval_consolidation import (
retrieval_consolidation,
)
from onyx.agents.agent_search.deep_search_a.initial_search_sq_subgraph.states import (
SearchSQInput,
)
from onyx.agents.agent_search.deep_search_a.initial_search_sq_subgraph.states import (
SearchSQState,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def initial_search_sq_subgraph_builder(test_mode: bool = False) -> StateGraph:
graph = StateGraph(
state_schema=SearchSQState,
input=SearchSQInput,
)
graph.add_node(
node="initial_sub_question_creation",
action=initial_sub_question_creation,
)
answer_query_subgraph = answer_query_graph_builder().compile()
graph.add_node(
node="answer_query_subgraph",
action=answer_query_subgraph,
)
base_raw_search_subgraph = base_raw_search_graph_builder().compile()
graph.add_node(
node="base_raw_search_subgraph",
action=base_raw_search_subgraph,
)
graph.add_node(
node="ingest_initial_retrieval",
action=ingest_initial_base_retrieval,
)
graph.add_node(
node="retrieval_consolidation",
action=retrieval_consolidation,
)
graph.add_node(
node="ingest_initial_sub_question_answers",
action=ingest_initial_sub_question_answers,
)
graph.add_node(
node="generate_initial_answer",
action=generate_initial_answer,
)
graph.add_node(
node="initial_answer_quality_check",
action=initial_answer_quality_check,
)
### Add edges ###
# graph.add_edge(start_key=START, end_key="base_raw_search_subgraph")
graph.add_edge(
start_key=START,
end_key="base_raw_search_subgraph",
)
# graph.add_edge(
# start_key="agent_search_start",
# end_key="entity_term_extraction_llm",
# )
graph.add_edge(
start_key=START,
end_key="initial_sub_question_creation",
)
graph.add_edge(
start_key="base_raw_search_subgraph",
end_key="ingest_initial_retrieval",
)
graph.add_edge(
start_key=["ingest_initial_retrieval", "ingest_initial_sub_question_answers"],
end_key="retrieval_consolidation",
)
graph.add_edge(
start_key="retrieval_consolidation",
end_key="generate_initial_answer",
)
# graph.add_edge(
# start_key="LLM",
# end_key=END,
# )
# graph.add_edge(
# start_key=START,
# end_key="initial_sub_question_creation",
# )
graph.add_conditional_edges(
source="initial_sub_question_creation",
path=parallelize_initial_sub_question_answering,
path_map=["answer_query_subgraph"],
)
graph.add_edge(
start_key="answer_query_subgraph",
end_key="ingest_initial_sub_question_answers",
)
graph.add_edge(
start_key="retrieval_consolidation",
end_key="generate_initial_answer",
)
graph.add_edge(
start_key="generate_initial_answer",
end_key="initial_answer_quality_check",
)
graph.add_edge(
start_key="initial_answer_quality_check",
end_key=END,
)
# graph.add_edge(
# start_key="generate_refined_answer",
# end_key="check_refined_answer",
# )
# graph.add_edge(
# start_key="check_refined_answer",
# end_key=END,
# )
return graph
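
The subgraph builders above are compiled and attached directly as node actions. A minimal, self-contained illustration of that "compiled graph as a node" pattern, with toy state and node names rather than the Onyx types:

from typing import TypedDict

from langgraph.graph import END, START, StateGraph


class SharedState(TypedDict):
    text: str


def inner_step(state: SharedState) -> dict:
    return {"text": state["text"] + " -> inner"}


inner = StateGraph(SharedState)
inner.add_node("inner_step", inner_step)
inner.add_edge(START, "inner_step")
inner.add_edge("inner_step", END)
inner_compiled = inner.compile()

outer = StateGraph(SharedState)
outer.add_node("inner_subgraph", inner_compiled)  # a compiled graph is a valid node action
outer.add_edge(START, "inner_subgraph")
outer.add_edge("inner_subgraph", END)
# outer.compile().invoke({"text": "start"}) returns {"text": "start -> inner"}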

View File

@@ -7,6 +7,9 @@ from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.initial_search_sq_subgraph.states import (
SearchSQState,
)
from onyx.agents.agent_search.deep_search_a.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search_a.main.operations import (
calculate_initial_agent_stats,
@@ -17,7 +20,6 @@ from onyx.agents.agent_search.deep_search_a.main.operations import (
remove_document_citations,
)
from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
@@ -56,7 +58,7 @@ from onyx.tools.tool_implementations.search.search_tool import yield_search_resp
def generate_initial_answer(
state: MainState, config: RunnableConfig
state: SearchSQState, config: RunnableConfig
) -> InitialAnswerUpdate:
now_start = datetime.now()

View File

@@ -1,13 +1,15 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search_a.initial_search_sq_subgraph.states import (
SearchSQState,
)
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import (
InitialAnswerQualityUpdate,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainState
def initial_answer_quality_check(state: MainState) -> InitialAnswerQualityUpdate:
def initial_answer_quality_check(state: SearchSQState) -> InitialAnswerQualityUpdate:
"""
Check whether the final output satisfies the original user question

View File

@@ -6,11 +6,13 @@ from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.initial_search_sq_subgraph.states import (
SearchSQState,
)
from onyx.agents.agent_search.deep_search_a.main.models import AgentRefinedMetrics
from onyx.agents.agent_search.deep_search_a.main.operations import dispatch_subquestion
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import BaseDecompUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
@@ -29,7 +31,7 @@ from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
def initial_sub_question_creation(
state: MainState, config: RunnableConfig
state: SearchSQState, config: RunnableConfig
) -> BaseDecompUpdate:
now_start = datetime.now()

View File

@@ -1,11 +1,13 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search_a.initial_search_sq_subgraph.states import (
SearchSQState,
)
from onyx.agents.agent_search.deep_search_a.main.states import LoggerUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
def retrieval_consolidation(
state: MainState,
state: SearchSQState,
) -> LoggerUpdate:
now_start = datetime.now()

View File

@@ -0,0 +1,46 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search_a.expanded_retrieval.models import (
ExpandedRetrievalResult,
)
from onyx.agents.agent_search.deep_search_a.main.states import BaseDecompUpdate
from onyx.agents.agent_search.deep_search_a.main.states import DecompAnswersUpdate
from onyx.agents.agent_search.deep_search_a.main.states import ExpandedRetrievalUpdate
from onyx.agents.agent_search.deep_search_a.main.states import ExploratorySearchUpdate
from onyx.agents.agent_search.deep_search_a.main.states import (
InitialAnswerQualityUpdate,
)
from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerUpdate
### States ###
class SearchSQInput(CoreState):
pass
## Graph State
class SearchSQState(
# This includes the core state
SearchSQInput,
BaseDecompUpdate,
InitialAnswerUpdate,
DecompAnswersUpdate,
ExpandedRetrievalUpdate,
InitialAnswerQualityUpdate,
ExploratorySearchUpdate,
):
# expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add]
## Graph Output State - presently not used
class SearchSQOutput(TypedDict):
log_messages: list[str]
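
A minimal illustration of how this mixin-style state composition behaves (toy classes, not the Onyx update types): every parent TypedDict contributes its keys to one flat graph state, and Annotated reducers decide how writes to the same key are combined.

from operator import add
from typing import Annotated, TypedDict


class AUpdate(TypedDict):
    a_field: str


class BUpdate(TypedDict):
    b_results: Annotated[list[int], add]


class CombinedState(AUpdate, BUpdate):
    pass


# With StateGraph(CombinedState), a node returning {"b_results": [1]} and a parallel node
# returning {"b_results": [2]} have their lists concatenated by the `add` reducer, while a
# plain key like a_field just takes the latest write.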

View File

@@ -2,21 +2,15 @@ from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search_a.answer_initial_sub_question.graph_builder import (
answer_query_graph_builder,
)
from onyx.agents.agent_search.deep_search_a.answer_refinement_sub_question.graph_builder import (
answer_refined_query_graph_builder,
)
from onyx.agents.agent_search.deep_search_a.base_raw_search.graph_builder import (
base_raw_search_graph_builder,
from onyx.agents.agent_search.deep_search_a.initial_search_sq_subgraph.graph_builder import (
initial_search_sq_subgraph_builder,
)
from onyx.agents.agent_search.deep_search_a.main.edges import (
continue_to_refined_answer_or_end,
)
from onyx.agents.agent_search.deep_search_a.main.edges import (
parallelize_initial_sub_question_answering,
)
from onyx.agents.agent_search.deep_search_a.main.edges import (
parallelize_refined_sub_question_answering,
)
@@ -32,44 +26,21 @@ from onyx.agents.agent_search.deep_search_a.main.nodes.agent_search_start import
from onyx.agents.agent_search.deep_search_a.main.nodes.answer_comparison import (
answer_comparison,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.entity_term_extraction_llm import (
entity_term_extraction_llm,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.direct_llm_handling import (
direct_llm_handling,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.generate_initial_answer import (
generate_initial_answer,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer import (
generate_refined_answer,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_initial_base_retrieval import (
ingest_initial_base_retrieval,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_initial_sub_question_answers import (
ingest_initial_sub_question_answers,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.ingest_refined_answers import (
ingest_refined_answers,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.initial_answer_quality_check import (
initial_answer_quality_check,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.initial_sub_question_creation import (
initial_sub_question_creation,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.refined_answer_decision import (
refined_answer_decision,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.refined_sub_question_creation import (
refined_sub_question_creation,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.retrieval_consolidation import (
retrieval_consolidation,
)
from onyx.agents.agent_search.deep_search_a.main.states import MainInput
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
@@ -130,21 +101,28 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
action=agent_search_start,
)
# graph.add_node(
# node="initial_sub_question_creation",
# action=initial_sub_question_creation,
# )
initial_search_sq_subgraph = initial_search_sq_subgraph_builder().compile()
graph.add_node(
node="initial_sub_question_creation",
action=initial_sub_question_creation,
)
answer_query_subgraph = answer_query_graph_builder().compile()
graph.add_node(
node="answer_query_subgraph",
action=answer_query_subgraph,
node="initial_search_sq_subgraph",
action=initial_search_sq_subgraph,
)
base_raw_search_subgraph = base_raw_search_graph_builder().compile()
graph.add_node(
node="base_raw_search_subgraph",
action=base_raw_search_subgraph,
)
# answer_query_subgraph = answer_query_graph_builder().compile()
# graph.add_node(
# node="answer_query_subgraph",
# action=answer_query_subgraph,
# )
# base_raw_search_subgraph = base_raw_search_graph_builder().compile()
# graph.add_node(
# node="base_raw_search_subgraph",
# action=base_raw_search_subgraph,
# )
# refined_answer_subgraph = refined_answers_graph_builder().compile()
# graph.add_node(
@@ -178,34 +156,34 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
# action=check_refined_answer,
# )
graph.add_node(
node="ingest_initial_retrieval",
action=ingest_initial_base_retrieval,
)
graph.add_node(
node="retrieval_consolidation",
action=retrieval_consolidation,
)
graph.add_node(
node="ingest_initial_sub_question_answers",
action=ingest_initial_sub_question_answers,
)
graph.add_node(
node="generate_initial_answer",
action=generate_initial_answer,
)
graph.add_node(
node="initial_answer_quality_check",
action=initial_answer_quality_check,
)
# graph.add_node(
# node="ingest_initial_retrieval",
# action=ingest_initial_base_retrieval,
# )
# graph.add_node(
# node="entity_term_extraction_llm",
# action=entity_term_extraction_llm,
# node="retrieval_consolidation",
# action=retrieval_consolidation,
# )
# graph.add_node(
# node="ingest_initial_sub_question_answers",
# action=ingest_initial_sub_question_answers,
# )
# graph.add_node(
# node="generate_initial_answer",
# action=generate_initial_answer,
# )
# graph.add_node(
# node="initial_answer_quality_check",
# action=initial_answer_quality_check,
# )
graph.add_node(
node="entity_term_extraction_llm",
action=entity_term_extraction_llm,
)
graph.add_node(
node="refined_answer_decision",
action=refined_answer_decision,
@@ -261,33 +239,37 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
graph.add_edge(
start_key="agent_search_start",
end_key="base_raw_search_subgraph",
end_key="initial_search_sq_subgraph",
)
# graph.add_edge(
# start_key="agent_search_start",
# end_key="entity_term_extraction_llm",
# end_key="base_raw_search_subgraph",
# )
graph.add_edge(
start_key="agent_search_start",
end_key="initial_sub_question_creation",
end_key="entity_term_extraction_llm",
)
graph.add_edge(
start_key="base_raw_search_subgraph",
end_key="ingest_initial_retrieval",
)
# graph.add_edge(
# start_key="agent_search_start",
# end_key="initial_sub_question_creation",
# )
graph.add_edge(
start_key=["ingest_initial_retrieval", "ingest_initial_sub_question_answers"],
end_key="retrieval_consolidation",
)
# graph.add_edge(
# start_key="base_raw_search_subgraph",
# end_key="ingest_initial_retrieval",
# )
graph.add_edge(
start_key="retrieval_consolidation",
end_key="generate_initial_answer",
)
# graph.add_edge(
# start_key=["ingest_initial_retrieval", "ingest_initial_sub_question_answers"],
# end_key="retrieval_consolidation",
# )
# graph.add_edge(
# start_key="retrieval_consolidation",
# end_key="generate_initial_answer",
# )
# graph.add_edge(
# start_key="LLM",
@@ -299,37 +281,42 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
# end_key="initial_sub_question_creation",
# )
graph.add_conditional_edges(
source="initial_sub_question_creation",
path=parallelize_initial_sub_question_answering,
path_map=["answer_query_subgraph"],
)
graph.add_edge(
start_key="answer_query_subgraph",
end_key="ingest_initial_sub_question_answers",
)
# graph.add_conditional_edges(
# source="initial_sub_question_creation",
# path=parallelize_initial_sub_question_answering,
# path_map=["answer_query_subgraph"],
# )
# graph.add_edge(
# start_key="answer_query_subgraph",
# end_key="ingest_initial_sub_question_answers",
# )
graph.add_edge(
start_key="retrieval_consolidation",
end_key="generate_initial_answer",
)
# graph.add_edge(
# start_key="retrieval_consolidation",
# end_key="generate_initial_answer",
# )
# graph.add_edge(
# start_key="generate_initial_answer",
# end_key="entity_term_extraction_llm",
# )
graph.add_edge(
start_key="generate_initial_answer",
end_key="initial_answer_quality_check",
)
# graph.add_edge(
# start_key="generate_initial_answer",
# end_key="initial_answer_quality_check",
# )
# graph.add_edge(
# start_key=["initial_answer_quality_check", "entity_term_extraction_llm"],
# end_key="refined_answer_decision",
# )
# graph.add_edge(
# start_key="initial_answer_quality_check",
# end_key="refined_answer_decision",
# )
graph.add_edge(
start_key="initial_answer_quality_check",
start_key=["initial_search_sq_subgraph", "entity_term_extraction_llm"],
end_key="refined_answer_decision",
)
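
The list-valued start_key used here (and in the subgraph builder) is LangGraph's join semantics: the end node runs only after every listed source node has completed. A small self-contained sketch with toy names:

from operator import add
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph


class JoinState(TypedDict):
    log: Annotated[list[str], add]


def branch_a(state: JoinState) -> dict:
    return {"log": ["a done"]}


def branch_b(state: JoinState) -> dict:
    return {"log": ["b done"]}


def joined(state: JoinState) -> dict:
    # Runs exactly once, after both branches have finished.
    return {"log": ["joined"]}


g = StateGraph(JoinState)
g.add_node("branch_a", branch_a)
g.add_node("branch_b", branch_b)
g.add_node("joined", joined)
g.add_edge(START, "branch_a")
g.add_edge(START, "branch_b")
g.add_edge(["branch_a", "branch_b"], "joined")  # wait for both before running "joined"
g.add_edge("joined", END)
# g.compile().invoke({"log": []})["log"] contains "a done", "b done", and finally "joined".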

View File

@@ -17,11 +17,14 @@ from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
DEEP_DECOMPOSE_PROMPT_WITH_ENTITIES,
)
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
format_entity_term_extraction,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION
from onyx.tools.models import ToolCallKickoff
@@ -51,14 +54,10 @@ def refined_sub_question_creation(
base_answer = state.initial_answer
history = build_history_prompt(agent_a_config.prompt_builder)
# get the entity term extraction dict and properly format it
# entity_retlation_term_extractions = state.entity_relation_term_extractions
entity_retlation_term_extractions = state.entity_relation_term_extractions
# entity_term_extraction_str = format_entity_term_extraction(
# entity_retlation_term_extractions
# )
docs_str = format_docs(
state.all_original_question_documents[:AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION]
entity_term_extraction_str = format_entity_term_extraction(
entity_retlation_term_extractions
)
initial_question_answers = state.decomp_answer_results
@@ -73,10 +72,10 @@ def refined_sub_question_creation(
msg = [
HumanMessage(
content=DEEP_DECOMPOSE_PROMPT.format(
content=DEEP_DECOMPOSE_PROMPT_WITH_ENTITIES.format(
question=question,
history=history,
docs_str=docs_str,
entity_term_extraction_str=entity_term_extraction_str,
base_answer=base_answer,
answered_sub_questions="\n - ".join(addressed_question_list),
failed_sub_questions="\n - ".join(failed_question_list),

View File

@@ -213,7 +213,8 @@ if __name__ == "__main__":
# query="What are the temperatures in Munich, Hawaii, and New York?",
# query="When was Washington born?",
# query="What is Onyx?",
query="What is the difference between astronomy and astrology?",
# query="What is the difference between astronomy and astrology?",
query="Do a search to tell me what is the difference between astronomy and astrology?",
)
# Joachim custom persona

View File

@@ -207,7 +207,7 @@ MODIFIED_RAG_PROMPT = (
Answer:"""
)
ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
ERT_INFORMED_DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
@@ -284,7 +284,7 @@ ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
DEEP_DECOMPOSE_PROMPT = """ \n
DOC_INFORMED_DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they